diff --git a/.gitignore b/.gitignore
index c6e843c4d..40ce4a299 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,5 +160,8 @@ docs/conf.py
 /temp
 /temp*.py
 
+# Serena
+/.serena
+
 # determinism test snapshots
 /test/resources/determinism/
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f65ba55e..4f2b7c8ef 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,265 @@
-# Changelog
+# Change Log
+
+## Upcoming Release 2.0.0
+
+This major release of Tianshou is a big step towards cleaner design and improved usability.
+
+Given the large extent of the changes, it was not possible to maintain compatibility with the previous version.
+  * Persisted agents that were created with earlier versions cannot be loaded in v2.
+  * Source code from v1 can, however, be migrated to v2 with minimal effort.
+    See migration information below. For concrete examples, you may use git to diff individual
+    example scripts with the corresponding ones in `v1.2.0`.
+
+This release is brought to you by [Applied AI Institute gGmbH](https://www.appliedai-institute.de).
+
+Developers:
+  * Dr. Dominik Jain (@opcode81)
+  * Michael Panchenko (@MischaPanch)
+
+### Trainer Abstraction
+
+* The trainer logic and configuration are now properly separated between the three cases of on-policy, off-policy
+  and offline learning: The base class is no longer a "God" class (formerly `BaseTrainer`) which does it all; logic and functionality have moved
+  to the respective subclasses (`OnPolicyTrainer`, `OffPolicyTrainer` and `OfflineTrainer`, with `OnlineTrainer`
+  being introduced as a base class for the two former specialisations).
+
+* The trainers now use configuration objects with central documentation (which has been greatly improved to enhance
+  clarity and usability in general); every type of trainer now has a dedicated configuration class which provides
+  precisely the options that are applicable.
+
+* The interface has been streamlined with improved naming of functions/parameters and by limiting the public interface to
+  only the methods and attributes a user should reasonably access.
+
+* Further changes potentially affecting usage:
+  * We dropped the iterator semantics: Method `__next__` has been replaced by `execute_epoch`. #913
+  * We no longer report outdated statistics (e.g. on rewards/returns when a training step does not collect any full
+    episodes)
+  * See also "Issues resolved" below (as issue resolution can result in usage changes)
+  * The default value for `test_in_train` was changed from True to False (updating all usage sites to explicitly
+    set the parameter), because False is the more natural default, which does not make assumptions about
+    returns/score values computed for the data from a collection step being at all meaningful for early stopping.
+  * The management of epsilon-greedy exploration for discrete Q-learning algorithms has been simplified:
+    * All respective Policy implementations (e.g. `DQNPolicy`, `C51Policy`, etc.) now accept two parameters
+      `eps_training` and `eps_inference`, which allows the training and test collection cases to be sufficiently
+      differentiated and makes the use of callback functions (`train_fn`, `test_fn`) unnecessary if only
+      constants are to be set.
+    * The setter method `set_eps` has been replaced with `set_eps_training` and `set_eps_inference` accordingly.
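+    * For illustration, a minimal sketch of the new epsilon handling (a sketch only: an existing network `net`
+      and environment `env` are assumed, and the values shown are placeholders, not recommendations):
+
+      ```python
+      # v1: constant epsilon values had to be set via trainer callbacks, e.g.
+      #   train_fn=lambda epoch, env_step: policy.set_eps(0.1)
+      #   test_fn=lambda epoch, env_step: policy.set_eps(0.05)
+
+      # v2: constants are passed directly to the policy ...
+      policy = DiscreteQLearningPolicy(
+          model=net,
+          action_space=env.action_space,
+          eps_training=0.1,    # epsilon used while collecting training data
+          eps_inference=0.05,  # epsilon used for test collection/inference
+      )
+      # ... and dynamic schedules remain possible via the new setters:
+      policy.set_eps_training(0.05)
+      ```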
+
+* Further internal changes unlikely to affect usage:
+  * Module `trainer.utils` was removed and the functions therein were moved to class `Trainer`
+  * The two places that collected and evaluated test episodes (`_test_in_train` and `_reset`) in addition to
+    `_test_step` were unified to use `_test_step` (with some minor parametrisation) and now log the results
+    of the test step accordingly.
+
+* Issues resolved:
+  * Methods `run` and `reset`: Parameter `reset_prior_to_run` of `run` was never respected if it was set to `False`,
+    because the implementation of `__iter__` (now removed) would call `reset` regardless - and calling `reset`
+    is indeed necessary, because it initializes the training. The parameter was removed and replaced by
+    `reset_collectors` (such that `run` now replicates the parameters of `reset`).
+  * Inconsistent configuration options now raise exceptions rather than silently ignoring the issue in the
+    hope that default behaviour will achieve what the user intended.
+    One condition where `test_in_train` was silently set to `False` was removed and replaced by a warning.
+  * The stop criterion `stop_fn` did not consider scores as computed by `compute_score_fn` but instead always used
+    mean returns (i.e. it was assumed that the default implementation of `compute_score_fn` applies).
+    This inconsistency has been resolved.
+  * The `gradient_step` counter was flawed (as it made assumptions about the underlying algorithms, which were
+    not valid). It has been replaced with an update step counter.
+    Members of `InfoStats` and parameters of `Logger` (and subclasses) were changed accordingly.
+
+* Migration information at a glance:
+  * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments:
+    `OnPolicyTrainerParams`, `OffPolicyTrainerParams`, `OfflineTrainerParams`.
+  * Changed parameter default: Default for `test_in_train` was changed from True to False.
+  * Changed parameter names to improve clarity:
+    * `max_epoch` (`num_epochs` in high-level API) -> `max_epochs`
+    * `step_per_epoch` -> `epoch_num_steps`
+    * `episode_per_test` (`num_test_episodes` in high-level API) -> `test_step_num_episodes`
+    * `step_per_collect` -> `collection_step_num_env_steps`
+    * `episode_per_collect` -> `collection_step_num_episodes`
+    * `update_per_step` -> `update_step_num_gradient_steps_per_sample`
+    * `repeat_per_collect` -> `update_step_num_repetitions`
+  * Trainer classes have been renamed:
+    * `OnpolicyTrainer` -> `OnPolicyTrainer`
+    * `OffpolicyTrainer` -> `OffPolicyTrainer`
+  * Method `run`: The parameter `reset_prior_to_run` was removed and replaced by `reset_collectors` (see above).
+  * Methods `run` and `reset`: The parameter `reset_buffer` was renamed to `reset_collector_buffers` for clarity
+  * Trainers are no longer iterators; manual usage (not using `run`) should simply call `reset` followed by
+    calls of `execute_epoch`.
+
+### Algorithms and Policies
+
+* We now conceptually differentiate between the learning algorithm and the policy being optimised:
+
+  * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`, and the package was renamed
+    from `tianshou.policy` to `tianshou.algorithm`.
+
+  * Migration information: The instantiation of a policy is replaced by the instantiation of an `Algorithm`,
+    which is passed a `Policy`. In most cases, the former policy class `<Name>Policy` is replaced by the
+    algorithm class `<Name>`; exceptions are noted below.
+
+    * `ImitationPolicy` -> `OffPolicyImitationLearning`, `OfflineImitationLearning`
+    * `PGPolicy` -> `Reinforce`
+    * `MultiAgentPolicyManager` -> `MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm`
+    * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm`
+
+    For the subtype of `Policy` to use, see the respective algorithm class' constructor.
+
+* Interface changes/improvements:
+  * Core methods have been renamed (and removed from the public interface; #898):
+    * `process_fn` -> `_preprocess_batch`
+    * `post_process_fn` -> `_postprocess_batch`
+    * `learn` -> `_update_with_batch`
+  * The updating interface has been cleaned up (#949):
+    * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`.
+    * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated.
+  * New method `run_training`: The `Algorithm` abstraction can now directly initiate the learning process via this method.
+  * `Algorithms` no longer require `torch.optim.Optimizer` instances and instead require `OptimizerFactory`
+    instances, which create the actual optimizers internally. #959
+    The new `OptimizerFactory` abstraction simultaneously handles the creation of learning rate schedulers
+    for the optimizers created (via method `with_lr_scheduler_factory` and accompanying factory abstraction
+    `LRSchedulerFactory`).
+    The parameter `lr_scheduler` has thus been removed from all algorithm constructors.
+  * The flag `updating` has been removed (no internal usage, general usefulness questionable).
+  * Removed `max_action_num`, which is now read off from `action_space` instead.
+  * Parameter changes:
+    * `actor_step_size` -> `trust_region_size` in NPG
+    * `discount_factor` -> `gamma` (was already used internally almost everywhere)
+    * `reward_normalization` -> `return_standardization` or `return_scaling` (more precise naming) or removed (was actually unsupported by Q-learning algorithms)
+      * `return_standardization` in `Reinforce` and `DiscreteCRR` (as it applies standardization of returns)
+      * `return_scaling` in actor-critic on-policy algorithms (A2C, PPO, GAIL, NPG, TRPO)
+      * removed from Q-learning algorithms, where it was actually unsupported (DQN, C51, etc.)
+    * `clip_grad` -> `max_grad_norm` (for consistency)
+    * `clip_loss_grad` -> `huber_loss_delta` (allowing control of not only the use of the Huber loss but also its essential parameter)
+    * `estimation_step` -> `n_step_return_horizon` (more precise naming)
+
+* Internal design improvements:
+
+  * Introduced an abstraction for the alpha parameter (coefficient of the entropy term)
+    in `SAC`, `DiscreteSAC` and other algorithms.
+    * Class hierarchy:
+      * Abstract base class `Alpha` with a value property and an update method
+      * `FixedAlpha` for constant entropy coefficients
+      * `AutoAlpha` for automatic entropy tuning (replaces the old tuple-based representation)
+    * The (auto-)updating logic is now completely encapsulated, reducing the complexity of the algorithms.
+    * Implementations for continuous and discrete cases now share the same abstraction,
+      making the codebase more consistent while preserving the original functionality.
+
+  * Introduced a policy base class `ContinuousPolicyWithExplorationNoise` which encapsulates noise generation
+    for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`).
+
+  * Multi-agent RL methods are now differentiated by the type of the sub-algorithms being employed
+    (`MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm`), which renders all interfaces clean.
+    Helper class `MARLDispatcher` has been factored out to manage the dispatching of data to the respective agents.
+
+  * Algorithms now internally use a wrapper (`Algorithm.Optimizer`) around the optimizers; creation is handled
+    by method `_create_optimizer`.
+    * This facilitates backpropagation steps with gradient clipping.
+    * The optimizers of an Algorithm instance are now centrally tracked, such that we can ensure that the
+      optimizers' states are handled alongside the model parameters when calling `state_dict` or `load_state_dict`
+      on the `Algorithm` instance.
+      Special handling of the restoration of optimizers' state dicts was thus removed from examples and tests.
+
+  * Lagged networks (target networks) are now conveniently handled via the new algorithm mixins
+    `LaggedNetworkPolyakUpdateAlgorithmMixin` and `LaggedNetworkFullUpdateAlgorithmMixin`.
+    Using these mixins,
+
+    * a lagged network can simply be added by calling `_add_lagged_network`,
+    * the torch method `train` no longer needs to be overridden to ensure that the target networks
+      are never set to train mode/remain in eval mode (which was prone to errors),
+    * a method which updates all target networks with their source networks is automatically
+      provided and does not need to be implemented specifically for every algorithm
+      (`_update_lagged_network_weights`).
+
+    All classes which make use of lagged networks were updated to use these mixins, simplifying
+    the implementations and reducing the potential for implementation errors.
+    (In the BCQ implementation, the VAE network was not correctly handled, but due to the way
+    in which examples were structured, it did not result in an error.)
+
+* Fixed issues in the class hierarchy (particularly critical violations of the Liskov substitution principle):
+  * Introduced base classes (to retain factorization without abusive inheritance):
+    * `ActorCriticOnPolicyAlgorithm`
+    * `ActorCriticOffPolicyAlgorithm`
+    * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`)
+    * `QLearningOffPolicyAlgorithm`
+  * `A2C`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `Reinforce`
+  * `BDQN`:
+    * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN`
+    * Remove parameter `clip_loss_grad` (unused; only passed on to former base class)
+    * Remove parameter `estimation_step`, for which only one option was valid
+  * `C51`:
+    * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN`
+    * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class)
+  * `CQL`:
+    * Inherit directly from `OfflineAlgorithm` instead of `SAC` (off-policy).
+    * Remove parameter `estimation_step` (now `n_step_return_horizon`), which was not actually used (it was only
+      passed on to its superclass).
+  * `DiscreteBCQ`:
+    * Inherit directly from `OfflineAlgorithm` instead of `DQN`
+    * Remove parameters `clip_loss_grad` and `is_double`, which were only passed on to
+      the former base class and actually unused.
+  * `DiscreteCQL`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to
+    base class `QRDQN` (and unused by it).
+  * `DiscreteCRR`: Inherit directly from `OfflineAlgorithm` instead of `Reinforce` (on-policy)
+  * `FQF`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to
+    base class `QRDQN` (and unused by it).
+  * `IQN`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to
+    base class `QRDQN` (and unused by it).
+  * `NPG`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `A2C`
+  * `QRDQN`:
+    * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN`
+    * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class)
+  * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG`
+  * `SAC`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG`
+  * `TD3`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG`
+
+### High-Level API
+
+* Detailed optimizer configuration (analogous to the procedural API) is now possible:
+  * All optimizers can be configured in the respective algorithm-specific `Params` object by using
+    `OptimizerFactoryFactory` instances as parameter values (e.g. `optim`, `actor_optim`, `critic_optim`, etc.).
+  * Learning rate schedulers remain separate parameters and now use `LRSchedulerFactoryFactory`
+    instances. The respective parameter names now use the suffix `lr_scheduler` instead of `lr_scheduler_factory`
+    (as the precise nature need not be reflected in the name; brevity is preferable).
+
+* `SamplingConfig` is replaced by `TrainingConfig` and subclasses differentiating off-policy and on-policy cases
+  appropriately (`OnPolicyTrainingConfig`, `OffPolicyTrainingConfig`).
+  * The `test_in_train` parameter is now exposed (default False).
+  * Inapplicable arguments can no longer be set in the respective subclass (e.g. `OffPolicyTrainingConfig` does not
+    contain parameter `repeat_per_collect`).
+  * All parameter names have been aligned with the new names used by `TrainerParams` (see above).
+
+### Peripheral Changes
+
+* The `Actor` classes have been renamed for clarity (#1091):
+  * `BaseActor` -> `Actor`
+  * `continuous.ActorProb` -> `ContinuousActorProbabilistic`
+  * `continuous.Actor` -> `ContinuousActorDeterministic`
+  * `discrete.Actor` -> `DiscreteActor`
+* The `Critic` classes have been renamed for clarity (#1091):
+  * `continuous.Critic` -> `ContinuousCritic`
+  * `discrete.Critic` -> `DiscreteCritic`
+* Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`.
+* Fix issues pertaining to the torch device assignment of network components (#810):
+  * Remove `device` member (and the corresponding constructor argument) from the following classes:
+    `BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProbabilistic`, `ContinuousCritic`,
+    `DiscreteActor`, `DiscreteCritic`, `DQNet`, `FullQuantileFunction`, `ImplicitQuantileNetwork`,
+    `IntrinsicCuriosityModule`, `MLPActor`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`,
+    `RecurrentActorProb`, `RecurrentCritic`, `VAE`
+  * (Peripheral change:) Require the use of keyword arguments for the constructors of all of these classes
+* Clean up handling of modules that define attribute `output_dim`, introducing the explicit base class
+  `ModuleWithVectorOutput`
+  * Interfaces where one could specify either a module with `output_dim` or additionally provide the output
+    dimension as an argument were changed to use `ModuleWithVectorOutput`.
+  * The high-level API class `IntermediateModule` can now provide a `ModuleWithVectorOutput` instance
+    (via adaptation if necessary).
+* The class hierarchy of supporting `nn.Module` implementations was cleaned up (#1091):
+  * With the fundamental base classes `ActionReprNet` and `ActionReprNetWithVectorOutput`, we established a
+    well-defined interface for the most commonly used `forward` interface in Tianshou's algorithms & policies. #948
+  * Some network classes were renamed:
+    * `ScaledObsInputModule` -> `ScaledObsInputActionReprNet`
+    * `Rainbow` -> `RainbowNet`
+* All modules containing base classes were renamed from `base` to a more descriptive name, rendering
+  file names unique.
 
 ## Release 1.2.0
 
@@ -222,7 +483,7 @@ A detailed list of changes can be found below.
   distribution type. #1032
 - Exception no longer raised on `len` of empty `Batch`. #1084
 - tests and examples are covered by `mypy`. #1077
-- `NetBase` is more used, stricter typing by making it generic. #1077
+- `Actor` is more used, stricter typing by making it generic. #1077
 - Use explicit multiprocessing context for creating `Pipe` in `subproc.py`. #1102
 
diff --git a/README.md b/README.md
index ee69f6ae6..4eb9d2154 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,6 @@
 1. Convenient high-level interfaces for applications of RL (training an implemented algorithm on a custom environment).
 1. Large scope: online (on- and off-policy) and offline RL, experimental support for multi-agent RL (MARL), experimental support for model-based RL, and more
-
 Unlike other reinforcement learning libraries, which may have complex codebases,
 unfriendly high-level APIs, or are not optimized for speed, Tianshou provides a high-performance, modularized framework
 and user-friendly interfaces for building deep reinforcement learning agents. One more aspect that sets Tianshou apart is its
@@ -149,9 +148,11 @@ If no errors are reported, you have successfully installed Tianshou.
 
 ## Documentation
 
-Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
+Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
 
-Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
+Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
+**Important**: The documentation is currently being updated to reflect the changes in Tianshou v2.0.0. Not all features are documented yet, and some parts are outdated (they are marked as such). The documentation will be fully updated when
+the v2.0.0 release is finalized.
 
 ## Why Tianshou?
 
@@ -180,20 +181,23 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page
 
 Atari and MuJoCo benchmark results can be found in the [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders respectively.
 **Our MuJoCo results reach or exceed the level of performance of most existing benchmarks.**
 
-### Policy Interface
+### Algorithm Abstraction
+
+Reinforcement learning algorithms are built on abstractions for
+
+- on-policy algorithms (`OnPolicyAlgorithm`),
+- off-policy algorithms (`OffPolicyAlgorithm`), and
+- offline algorithms (`OfflineAlgorithm`),
+
+all of which clearly separate the core algorithm from the training process and the respective environment interactions.
 
-All algorithms implement the following, highly general API:
+In each case, the implementation of an algorithm involves only the implementation of methods for
 
-- `__init__`: initialize the policy;
-- `forward`: compute actions based on given observations;
-- `process_buffer`: process initial buffer, which is useful for some offline learning algorithms
-- `process_fn`: preprocess data from the replay buffer (since we have reformulated _all_ algorithms to replay buffer-based algorithms);
-- `learn`: learn from a given batch of data;
-- `post_process_fn`: update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
-- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`.
+- pre-processing a batch of data, augmenting it with necessary information/sufficient statistics for learning (`_preprocess_batch`),
+- updating model parameters based on an augmented batch of data (`_update_with_batch`).
 
-The implementation of this API suffices for a new algorithm to be applicable within Tianshou,
-making experimenation with new approaches particularly straightforward.
+The implementation of these methods suffices for a new algorithm to be applicable within Tianshou,
+making experimentation with new approaches particularly straightforward.
 
 ## Quick Start
 
@@ -203,70 +207,68 @@ Tianshou provides two API levels:
 - the procedural interface, which provides a maximum of control, especially for very advanced users and developers of reinforcement learning algorithms.
 
 In the following, let us consider an example application using the _CartPole_ gymnasium environment.
-We shall apply the deep Q network (DQN) learning algorithm using both APIs.
+We shall apply the deep Q-network (DQN) learning algorithm using both APIs.
 
 ### High-Level API
 
-To get started, we need some imports.
-
-```python
-from tianshou.highlevel.config import SamplingConfig
-from tianshou.highlevel.env import (
-    EnvFactoryRegistered,
-    VectorEnvType,
-)
-from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
-from tianshou.highlevel.params.policy_params import DQNParams
-from tianshou.highlevel.trainer import (
-    EpochTestCallbackDQNSetEps,
-    EpochTrainCallbackDQNSetEps,
-    EpochStopCallbackRewardThreshold
-)
-```
-
 In the high-level API, the basis for an RL experiment is an `ExperimentBuilder`
 with which we can build the experiment we then seek to run.
 Since we want to use DQN, we use the specialization `DQNExperimentBuilder`.
-The other imports serve to provide configuration options for our experiment.
 
 The high-level API provides largely declarative semantics, i.e. the code is almost
 exclusively concerned with configuration that controls what to do (rather than how to do it).
```python +from tianshou.highlevel.config import OffPolicyTrainingConfig +from tianshou.highlevel.env import ( + EnvFactoryRegistered, + VectorEnvType, +) +from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig +from tianshou.highlevel.params.algorithm_params import DQNParams +from tianshou.highlevel.trainer import ( + EpochStopCallbackRewardThreshold, +) + experiment = ( - DQNExperimentBuilder( - EnvFactoryRegistered(task="CartPole-v1", train_seed=0, test_seed=0, venv_type=VectorEnvType.DUMMY), - ExperimentConfig( - persistence_enabled=False, - watch=True, - watch_render=1 / 35, - watch_num_episodes=100, - ), - SamplingConfig( - num_epochs=10, - step_per_epoch=10000, - batch_size=64, - num_train_envs=10, - num_test_envs=100, - buffer_size=20000, - step_per_collect=10, - update_per_step=1 / 10, - ), - ) - .with_dqn_params( - DQNParams( - lr=1e-3, - discount_factor=0.9, - estimation_step=3, - target_update_freq=320, - ), - ) - .with_model_factory_default(hidden_sizes=(64, 64)) - .with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3)) - .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0)) - .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) - .build() + DQNExperimentBuilder( + EnvFactoryRegistered( + task="CartPole-v1", + venv_type=VectorEnvType.DUMMY, + train_seed=0, + test_seed=10, + ), + ExperimentConfig( + persistence_enabled=False, + watch=True, + watch_render=1 / 35, + watch_num_episodes=100, + ), + OffPolicyTrainingConfig( + max_epochs=10, + epoch_num_steps=10000, + batch_size=64, + num_train_envs=10, + num_test_envs=100, + buffer_size=20000, + collection_step_num_env_steps=10, + update_step_num_gradient_steps_per_sample=1 / 10, + ), + ) + .with_dqn_params( + DQNParams( + lr=1e-3, + gamma=0.9, + n_step_return_horizon=3, + target_update_freq=320, + eps_training=0.3, + eps_inference=0.0, + ), + ) + .with_model_factory_default(hidden_sizes=(64, 64)) + .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) + .build() ) experiment.run() ``` @@ -281,24 +283,25 @@ The experiment builder takes three arguments: episodes (`watch_num_episodes=100`). We have disabled persistence, because we do not want to save training logs, the agent or its configuration for future use. -- the sampling configuration, which controls fundamental training parameters, +- the training configuration, which controls fundamental training parameters, such as the total number of epochs we run the experiment for (`num_epochs=10`) and the number of environment steps each epoch shall consist of - (`step_per_epoch=10000`). + (`epoch_num_steps=10000`). Every epoch consists of a series of data collection (rollout) steps and training steps. - The parameter `step_per_collect` controls the amount of data that is + The parameter `collection_step_num_env_steps` controls the amount of data that is collected in each collection step and after each collection step, we perform a training step, applying a gradient-based update based on a sample of data (`batch_size=64`) taken from the buffer of data that has been - collected. For further details, see the documentation of `SamplingConfig`. + collected. For further details, see the documentation of configuration class. -We then proceed to configure some of the parameters of the DQN algorithm itself -and of the neural network model we want to use. -A DQN-specific detail is the use of callbacks to configure the algorithm's -epsilon parameter for exploration. 
We want to use random exploration during rollouts -(train callback), but we don't when evaluating the agent's performance in the test -environments (test callback). +We then proceed to configure some of the parameters of the DQN algorithm itself: +For instance, we control the epsilon parameter for exploration. +We want to use random exploration during rollouts for training (`eps_training`), +but we don't when evaluating the agent's performance in the test environments +(`eps_inference`). +Furthermore, we configure model parameters of the network for the Q function, +parametrising the number of hidden layers of the default MLP factory. Find the script in [examples/discrete/discrete_dqn_hl.py](examples/discrete/discrete_dqn_hl.py). Here's a run (with the training time cut short): @@ -309,7 +312,7 @@ Here's a run (with the training time cut short): Find many further applications of the high-level API in the `examples/` folder; look for scripts ending with `_hl.py`. -Note that most of these examples require the extra package `argparse` +Note that most of these examples require the extra `argparse` (install it by adding `--extras argparse` when invoking poetry). ### Procedural API @@ -317,7 +320,7 @@ Note that most of these examples require the extra package `argparse` Let us now consider an analogous example in the procedural API. Find the full script in [examples/discrete/discrete_dqn.py](https://github.com/thu-ml/tianshou/blob/master/examples/discrete/discrete_dqn.py). -First, import some relevant packages: +First, import the relevant packages: ```python import gymnasium as gym @@ -326,7 +329,7 @@ from torch.utils.tensorboard import SummaryWriter import tianshou as ts ``` -Define some hyper-parameters: +Define hyper-parameters: ```python task = 'CartPole-v1' @@ -335,14 +338,13 @@ train_num, test_num = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 -step_per_epoch, step_per_collect = 10000, 10 +epoch_num_steps, collection_step_num_env_steps = 10000, 10 ``` Initialize the logger: ```python logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) -# For other loggers, see https://tianshou.readthedocs.io/en/master/01_tutorials/05_logger.html ``` Make environments: @@ -353,53 +355,78 @@ train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_ test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)]) ``` -Create the network as well as its optimizer: +Create the network, policy, and algorithm: ```python from tianshou.utils.net.common import Net +from tianshou.algorithm import DQN +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory # Note: You can easily define other networks. 
 # See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
 env = gym.make(task, render_mode="human")
 state_shape = env.observation_space.shape or env.observation_space.n
 action_shape = env.action_space.shape or env.action_space.n
-net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
-optim = torch.optim.Adam(net.parameters(), lr=lr)
-```
-
-Set up the policy and collectors:
+net = Net(
+    state_shape=state_shape, action_shape=action_shape,
+    hidden_sizes=[128, 128, 128]
+)
 
-```python
-policy = ts.policy.DQNPolicy(
+policy = DiscreteQLearningPolicy(
     model=net,
-    optim=optim,
-    discount_factor=gamma,
     action_space=env.action_space,
-    estimation_step=n_step,
+    eps_training=eps_train,
+    eps_inference=eps_test
+)
+
+# Create the algorithm with the policy and optimizer factory
+algorithm = DQN(
+    policy=policy,
+    optim=AdamOptimizerFactory(lr=lr),
+    gamma=gamma,
+    n_step_return_horizon=n_step,
    target_update_freq=target_freq
 )
-train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
-test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)  # because DQN uses epsilon-greedy method
 ```
 
-Let's train it:
+Set up the collectors:
 
 ```python
-result = ts.trainer.OffpolicyTrainer(
-    policy=policy,
+train_collector = ts.data.Collector(policy, train_envs,
+    ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
+test_collector = ts.data.Collector(policy, test_envs,
+    exploration_noise=True)  # because DQN uses epsilon-greedy method
+```
+
+Let's train it using the algorithm:
+
+```python
+from tianshou.highlevel.config import OffPolicyTrainingConfig
+
+# Create training configuration
+training_config = OffPolicyTrainingConfig(
+    max_epochs=epoch,
+    epoch_num_steps=epoch_num_steps,
+    batch_size=batch_size,
+    num_train_envs=train_num,
+    num_test_envs=test_num,
+    buffer_size=buffer_size,
+    collection_step_num_env_steps=collection_step_num_env_steps,
+    update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps,
+    test_step_num_episodes=test_num,
+)
+
+# Run training (trainer is created automatically by the algorithm)
+result = algorithm.run_training(
+    training_config=training_config,
     train_collector=train_collector,
     test_collector=test_collector,
-    max_epoch=epoch,
-    step_per_epoch=step_per_epoch,
-    step_per_collect=step_per_collect,
-    episode_per_test=test_num,
-    batch_size=batch_size,
-    update_per_step=1 / step_per_collect,
+    logger=logger,
-    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
-    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
+    train_fn=lambda epoch, env_step: policy.set_eps_training(eps_train),
+    test_fn=lambda epoch, env_step: policy.set_eps_inference(eps_test),
     stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
-    logger=logger,
-).run()
+)
 print(f"Finished training in {result.timing.total_time} seconds")
 ```
 
diff --git a/docs/01_tutorials/00_dqn.rst b/docs/01_tutorials/00_dqn.rst
index 263ee3709..3c28e7163 100644
--- a/docs/01_tutorials/00_dqn.rst
+++ b/docs/01_tutorials/00_dqn.rst
@@ -112,7 +112,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of cour
     import torch, numpy as np
     from torch import nn
 
-    class Net(nn.Module):
+    class MLPActor(nn.Module):
         def __init__(self, state_shape, action_shape):
             super().__init__()
             self.model = nn.Sequential(
@@ -129,10 +129,15 @@ Tianshou supports any user-defined PyTorch networks and optimizers. 
Yet, of cour logits = self.model(obs.view(batch, -1)) return logits, state - state_shape = env.observation_space.shape or env.observation_space.n - action_shape = env.action_space.shape or env.action_space.n - net = Net(state_shape, action_shape) - optim = torch.optim.Adam(net.parameters(), lr=1e-3) + from tianshou.utils.net.common import Net + from tianshou.utils.space_info import SpaceInfo + from tianshou.algorithm.optim import AdamOptimizerFactory + + space_info = SpaceInfo.from_env(env) + state_shape = space_info.observation_info.obs_shape + action_shape = space_info.action_info.action_shape + net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128, 128]) + optim = AdamOptimizerFactory(lr=1e-3) You can also use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: @@ -150,13 +155,22 @@ Setup Policy We use the defined ``net`` and ``optim`` above, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with a target network: :: - policy = ts.policy.DQNPolicy( + from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy + from tianshou.algorithm import DQN + + policy = DiscreteQLearningPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=0.9, - estimation_step=3, - target_update_freq=320 + observation_space=env.observation_space, + eps_training=0.1, + eps_inference=0.05, + ) + algorithm = DQN( + policy=policy, + optim=optim, + gamma=0.9, + n_step_return_horizon=3, + target_update_freq=320, ) @@ -170,8 +184,11 @@ The following code shows how to set up a collector in practice. It is worth noti :: - train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True) - test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) + from tianshou.data import Collector, CollectStats, VectorReplayBuffer + + buf = VectorReplayBuffer(20000, buffer_num=len(train_envs)) + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) The main function of collector is the collect function, which can be summarized in the following lines: @@ -188,29 +205,41 @@ The main function of collector is the collect function, which can be summarized Train Policy with a Trainer --------------------------- -Tianshou provides :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, +Tianshou provides :class:`~tianshou.trainer.OnPolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, and :class:`~tianshou.trainer.OfflineTrainer`. The trainer will automatically stop training when the policy reaches the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.OffpolicyTrainer` as follows: :: - result = ts.trainer.OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=10, step_per_epoch=10000, step_per_collect=10, - update_per_step=0.1, episode_per_test=100, batch_size=64, - train_fn=lambda epoch, env_step: policy.set_eps(0.1), - test_fn=lambda epoch, env_step: policy.set_eps(0.05), - stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold - ).run() - print(f'Finished training! 
Use {result["duration"]}') + from tianshou.trainer import OffPolicyTrainerParams + + def train_fn(epoch, env_step): + policy.set_eps_training(0.1) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=10, + epoch_num_steps=10000, + collection_step_num_env_steps=10, + test_step_num_episodes=100, + batch_size=64, + update_step_num_gradient_steps_per_sample=0.1, + train_fn=train_fn, + stop_fn=stop_fn, + ) + ) + print(f'Finished training! Use {result.duration}') The meaning of each parameter is as follows (full description can be found at :class:`~tianshou.trainer.OffpolicyTrainer`): * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``; -* ``step_per_epoch``: The number of environment step (a.k.a. transition) collected per epoch; -* ``step_per_collect``: The number of transition the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update"; +* ``epoch_num_steps``: The number of environment step (a.k.a. transition) collected per epoch; +* ``collection_step_num_env_steps``: The number of transition the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". @@ -232,18 +261,10 @@ The returned result is a dictionary as follows: :: { - 'train_step': 9246, - 'train_episode': 504.0, - 'train_time/collector': '0.65s', - 'train_time/model': '1.97s', - 'train_speed': '3518.79 step/s', - 'test_step': 49112, - 'test_episode': 400.0, - 'test_time': '1.38s', - 'test_speed': '35600.52 step/s', - 'best_reward': 199.03, - 'duration': '4.01s' - } + TrainingResult object with attributes like: + best_reward: 199.03 + duration: 4.01s + And other training statistics It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean returns over 100 consecutive episodes is 199.03. @@ -265,8 +286,8 @@ Watch the Agent's Performance :: policy.eval() - policy.set_eps(0.05) - collector = ts.data.Collector(policy, env, exploration_noise=True) + policy.set_eps_inference(0.05) + collector = ts.data.Collector(algorithm, env, exploration_noise=True) collector.collect(n_episode=1, render=1 / 35) If you'd like to manually see the action generated by a well-trained agent: @@ -289,24 +310,24 @@ Tianshou supports user-defined training code. 
Here is the code snippet: # pre-collect at least 5000 transitions with random action before training train_collector.collect(n_step=5000, random=True) - policy.set_eps(0.1) + policy.set_eps_training(0.1) for i in range(int(1e6)): # total step collect_result = train_collector.collect(n_step=10) # once if the collected episodes' mean returns reach the threshold, # or every 1000 steps, we test it on test_collector if collect_result['rews'].mean() >= env.spec.reward_threshold or i % 1000 == 0: - policy.set_eps(0.05) + policy.set_eps_inference(0.05) result = test_collector.collect(n_episode=100) if result['rews'].mean() >= env.spec.reward_threshold: print(f'Finished training! Test mean returns: {result["rews"].mean()}') break else: # back to training eps - policy.set_eps(0.1) + policy.set_eps_training(0.1) # train policy with a sampled batch data from buffer - losses = policy.update(64, train_collector.buffer) + losses = algorithm.update(64, train_collector.buffer) For further usage, you can refer to the :doc:`/01_tutorials/07_cheatsheet`. diff --git a/docs/01_tutorials/01_concepts.rst b/docs/01_tutorials/01_concepts.rst index 5107bd690..28b0dc276 100644 --- a/docs/01_tutorials/01_concepts.rst +++ b/docs/01_tutorials/01_concepts.rst @@ -1,18 +1,22 @@ Basic concepts in Tianshou ========================== -Tianshou splits a Reinforcement Learning agent training procedure into these parts: trainer, collector, policy, and data buffer. The general control flow can be described as: +Tianshou splits a Reinforcement Learning agent training procedure into these parts: algorithm, trainer, collector, policy, a data buffer and batches from the buffer. +The algorithm encapsulates the specific RL learning method (e.g., DQN, PPO), which contains a policy and defines how to update it. -.. image:: /_static/images/concepts_arch.png - :align: center - :height: 300 +.. + The general control flow can be described as: + .. image:: /_static/images/concepts_arch.png + :align: center + :height: 300 -Here is a more detailed description, where ``Env`` is the environment and ``Model`` is the neural network: -.. image:: /_static/images/concepts_arch2.png - :align: center - :height: 300 + Here is a more detailed description, where ``Env`` is the environment and ``Model`` is the neural network: + + .. image:: /_static/images/concepts_arch2.png + :align: center + :height: 300 Batch @@ -220,19 +224,28 @@ The following code snippet illustrates the usage, including: Tianshou provides other type of data buffer such as :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``) and :class:`~tianshou.data.VectorReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. -Policy ------- +Algorithm and Policy +-------------------- + +Tianshou's RL framework is built around two key abstractions: :class:`~tianshou.algorithm.Algorithm` and :class:`~tianshou.algorithm.Policy`. + +**Algorithm**: The core abstraction that encapsulates a complete RL learning method (e.g., DQN, PPO, SAC). Each algorithm contains a policy and defines how to update it using training data. All algorithm classes inherit from :class:`~tianshou.algorithm.Algorithm`. 
+ +An algorithm class typically has the following parts: + +* :meth:`~tianshou.algorithm.Algorithm.__init__`: initialize the algorithm with a policy and optimization configuration; +* :meth:`~tianshou.algorithm.Algorithm._preprocess_batch`: pre-process data from the replay buffer (e.g., compute n-step returns); +* :meth:`~tianshou.algorithm.Algorithm._update_with_batch`: the algorithm-specific network update logic; +* :meth:`~tianshou.algorithm.Algorithm._postprocess_batch`: post-process the batch data (e.g., update prioritized replay buffer weights); +* :meth:`~tianshou.algorithm.Algorithm.create_trainer`: create the appropriate trainer for this algorithm; -Tianshou aims to modularize RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. +**Policy**: Represents the mapping from observations to actions. Policy classes inherit from :class:`~tianshou.algorithm.Policy`. -A policy class typically has the following parts: +A policy class typically provides: -* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on; -* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation; -* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; -* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data. -* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data. -* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. +* :meth:`~tianshou.algorithm.Policy.forward`: compute action distribution or Q-values given observations; +* :meth:`~tianshou.algorithm.Policy.compute_action`: get concrete actions from observations for environment interaction; +* :meth:`~tianshou.algorithm.Policy.map_action`: transform raw network outputs to environment action space; .. _policy_state: @@ -245,22 +258,10 @@ During the training process, the policy has two main states: training state and The meaning of training and testing state is obvious: the agent interacts with environment, collects training data and performs update, that's training state; the testing state is to evaluate the performance of the current policy during training process. As for the collecting state, it is defined as interacting with environments and collecting training data into the buffer; -we define the updating state as performing a model update by :meth:`~tianshou.policy.BasePolicy.update` during training process. +we define the updating state as performing a model update by the algorithm's update methods during training process. - -In order to distinguish these states, you can check the policy state by ``policy.training`` and ``policy.updating``. 
The state setting is as follows: - -+-----------------------------------+-----------------+-----------------+ -| State for policy | policy.training | policy.updating | -+================+==================+=================+=================+ -| | Collecting state | True | False | -| Training state +------------------+-----------------+-----------------+ -| | Updating state | True | True | -+----------------+------------------+-----------------+-----------------+ -| Testing state | False | False | -+-----------------------------------+-----------------+-----------------+ - -``policy.updating`` is helpful to distinguish the different exploration state, for example, in DQN we don't have to use epsilon-greedy in a pure network update, so ``policy.updating`` is helpful for setting epsilon in this case. +The collection of data from the env may differ in training and in inference (for example, in training one may add exploration noise, or sample from the predicted action distribution instead of taking its mode). The switch between the different collection strategies in training and inference is controlled by ``policy.is_within_training_step``, see also the docstring of it +for more details. policy.forward @@ -270,7 +271,7 @@ The ``forward`` function computes the action over given observations. The input The input batch is the environment data (e.g., observation, reward, done flag and info). It comes from either :meth:`~tianshou.data.Collector.collect` or :meth:`~tianshou.data.ReplayBuffer.sample`. The first dimension of all variables in the input ``batch`` should be equal to the batch-size. -The output is also a ``Batch`` which must contain "act" (action) and may contain "state" (hidden state of policy), "policy" (the intermediate result of policy which needs to save into the buffer, see :meth:`~tianshou.policy.BasePolicy.forward`), and some other algorithm-specific keys. +The output is also a ``Batch`` which must contain "act" (action) and may contain "state" (hidden state of policy), "policy" (the intermediate result of policy which needs to save into the buffer, see :meth:`~tianshou.algorithm.BasePolicy.forward`), and some other algorithm-specific keys. For example, if you try to use your policy to evaluate one episode (and don't want to use :meth:`~tianshou.data.Collector.collect`), use the following code-snippet: :: @@ -282,15 +283,17 @@ For example, if you try to use your policy to evaluate one episode (and don't wa act = policy(batch).act[0] # policy.forward return a batch, use ".act" to extract the action obs, rew, done, info = env.step(act) +For inference, it is recommended to use the shortcut method :meth:`~tianshou.algorithm.Policy.compute_action` to compute the action directly from the observation. + Here, ``Batch(obs=[obs])`` will automatically create the 0-dimension to be the batch-size. Otherwise, the network cannot determine the batch-size. .. _process_fn: -policy.process_fn -^^^^^^^^^^^^^^^^^ +Algorithm Preprocessing and N-step Returns +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The ``process_fn`` function computes some variables that depends on time-series. For example, compute the N-step or GAE returns. +The algorithm handles data preprocessing, including computing variables that depend on time-series such as N-step or GAE returns. This functionality is implemented in :meth:`~tianshou.algorithm.Algorithm._preprocess_batch` and the static methods :meth:`~tianshou.algorithm.Algorithm.compute_nstep_return` and :meth:`~tianshou.algorithm.Algorithm.compute_episodic_return`. 
Take 2-step return DQN as an example. The 2-step return DQN compute each transition's return as: @@ -304,42 +307,19 @@ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is # pseudocode, cannot work obs = env.reset() buffer = Buffer(size=10000) - agent = DQN() + algorithm = DQN(...) for i in range(int(1e6)): - act = agent.compute_action(obs) + act = algorithm.policy.compute_action(obs) obs_next, rew, done, _ = env.step(act) buffer.store(obs, act, obs_next, rew, done) obs = obs_next if i % 1000 == 0: - b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) - # compute 2-step returns. How? - b_ret = compute_2_step_return(buffer, b_r, b_d, ...) - # update DQN policy - agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) - -Thus, we need a time-related interface for calculating the 2-step return. :meth:`~tianshou.policy.BasePolicy.process_fn` finishes this work by providing the replay buffer, the sample index, and the sample batch data. Since we store all the data in the order of time, you can simply compute the 2-step return as: -:: - - class DQN_2step(BasePolicy): - """some code""" + # algorithm handles sampling, preprocessing, and updating + algorithm.update(sample_size=64, buffer=buffer) - def process_fn(self, batch, buffer, indices): - buffer_len = len(buffer) - batch_2 = buffer[(indices + 2) % buffer_len] - # this will return a batch data where batch_2.obs is s_t+2 - # we can also get s_t+2 through: - # batch_2_obs = buffer.obs[(indices + 2) % buffer_len] - # in short, buffer.obs[i] is equal to buffer[i].obs, but the former is more effecient. - Q = self(batch_2, eps=0) # shape: [batchsize, action_shape] - maxQ = Q.max(dim=-1) - batch.returns = batch.rew \ - + self._gamma * buffer.rew[(indices + 1) % buffer_len] \ - + self._gamma ** 2 * maxQ - return batch +The algorithm's :meth:`~tianshou.algorithm.Algorithm._preprocess_batch` method automatically handles n-step return computation by calling :meth:`~tianshou.algorithm.Algorithm.compute_nstep_return`, which provides the replay buffer, sample indices, and batch data. Since we store all the data in the order of time, the n-step return can be computed efficiently using the buffer's temporal structure. -This code does not consider the done flag, so it may not work very well. It shows two ways to get :math:`s_{t + 2}` from the replay buffer easily in :meth:`~tianshou.policy.BasePolicy.process_fn`. - -For other method, you can check out :doc:`/03_api/policy/index`. We give the usage of policy class a high-level explanation in :ref:`pseudocode`. +For custom preprocessing logic, you can override :meth:`~tianshou.algorithm.Algorithm._preprocess_batch` in your algorithm subclass. The method receives the sampled batch, buffer, and indices, allowing you to add computed values like returns, advantages, or other algorithm-specific preprocessing steps. Collector @@ -352,7 +332,7 @@ The :class:`~tianshou.data.Collector` enables the policy to interact with differ The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation. Here are some example usages: :: - policy = PGPolicy(...) # or other policies if you wish + policy = DiscreteQLearningPolicy(...) 
# or other policies if you wish
     env = gym.make("CartPole-v1")
     replay_buffer = ReplayBuffer(size=10000)
 
@@ -380,29 +360,33 @@ There is also another type of collector :class:`~tianshou.data.AsyncCollector` w
 Trainer
 -------
 
-Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.
+Once you have an algorithm and a collector, you can start the training process. The trainer orchestrates the training loop and calls upon the algorithm's specific network updating logic. Each algorithm creates its appropriate trainer type through the :meth:`~tianshou.algorithm.Algorithm.create_trainer` method.
 
-Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/03_api/trainer/index` for the usage.
+Tianshou has three main trainer classes: :class:`~tianshou.trainer.OnPolicyTrainer` for on-policy algorithms such as Policy Gradient, :class:`~tianshou.trainer.OffPolicyTrainer` for off-policy algorithms such as DQN, and :class:`~tianshou.trainer.OfflineTrainer` for offline algorithms such as BCQ.
 
-We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic:
+The typical workflow is:
 ::
 
-    trainer = OnpolicyTrainer(...)
-    for epoch, epoch_stat, info in trainer:
-        print(f"Epoch: {epoch}")
-        print(epoch_stat)
-        print(info)
-        do_something_with_policy()
-        query_something_about_policy()
-        make_a_plot_with(epoch_stat)
-        display(info)
+    # Create algorithm with policy
+    algorithm = DQN(policy=policy, optim=optimizer_factory, ...)
+
+    # Create trainer parameters
+    params = OffPolicyTrainerParams(
+        max_epochs=100,
+        epoch_num_steps=1000,
+        train_collector=train_collector,
+        test_collector=test_collector,
+        ...
+    )
+
+    # Run training (trainer is created automatically)
+    result = algorithm.run_training(params)
 
-    # or even iterate on several trainers at the same time
+You can also create trainers manually for more control:
+::
 
-    trainer1 = OnpolicyTrainer(...)
-    trainer2 = OnpolicyTrainer(...)
-    for result1, result2, ... in zip(trainer1, trainer2, ...):
-        compare_results(result1, result2, ...)
+    trainer = algorithm.create_trainer(params)
+    result = trainer.run()
 
 .. _pseudocode:
 
@@ -416,22 +400,31 @@ We give a high-level explanation through the pseudocode used in section :ref:`pr
 
     # pseudocode, cannot work                                       # methods in tianshou
     obs = env.reset()
     buffer = Buffer(size=10000)                                     # buffer = tianshou.data.ReplayBuffer(size=10000)
-    agent = DQN()                                                   # policy.__init__(...)
+    algorithm = DQN(policy=policy, ...)                             # algorithm.__init__(...)
     for i in range(int(1e6)):                                       # done in trainer
-        act = agent.compute_action(obs)                             # act = policy(batch, ...).act
+        act = algorithm.policy.compute_action(obs)                  # act = policy.compute_action(obs)
         obs_next, rew, done, _ = env.step(act)                      # collector.collect(...)
         buffer.store(obs, act, obs_next, rew, done)                 # collector.collect(...)
         obs = obs_next                                              # collector.collect(...)
if i % 1000 == 0: # done in trainer - # the following is done in policy.update(batch_size, buffer) + # the following is done in algorithm.update(batch_size, buffer) b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indices = buffer.sample(batch_size) # compute 2-step returns. How? - b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indices) + b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # algorithm._preprocess_batch(batch, buffer, indices) # update DQN policy - agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # policy.learn(batch, ...) + algorithm.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # algorithm._update_with_batch(batch) Conclusion ---------- -So far, we go through the overall framework of Tianshou. Really simple, isn't it? +So far, we've covered the overall framework of Tianshou with its new architecture centered around the Algorithm abstraction. The key components are: + +- **Algorithm**: Encapsulates the complete RL learning method, containing a policy and defining how to update it +- **Policy**: Handles the mapping from observations to actions +- **Collector**: Manages environment interaction and data collection +- **Trainer**: Orchestrates the training loop and calls the algorithm's update logic +- **Buffer**: Stores and manages experience data +- **Batch**: A flexible data structure for passing data between components. Batches are collected to the buffer by the Collector and are sampled from the buffer by the `Algorithm` where they are used for learning. + +This modular design cleanly separates concerns while maintaining the flexibility to implement various RL algorithms. diff --git a/docs/01_tutorials/04_tictactoe.rst b/docs/01_tutorials/04_tictactoe.rst index 232f6c349..c5d6a87c0 100644 --- a/docs/01_tutorials/04_tictactoe.rst +++ b/docs/01_tutorials/04_tictactoe.rst @@ -122,13 +122,13 @@ Two Random Agents .. Figure:: ../_static/images/marl.png -Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.policy.MARLRandomPolicy` and :class:`~tianshou.policy.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation. +Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.algorithm.MARLRandomPolicy` and :class:`~tianshou.algorithm.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation. :: >>> from tianshou.data import Collector >>> from tianshou.env import DummyVectorEnv - >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager + >>> from tianshou.algorithm import RandomPolicy, MultiAgentPolicyManager >>> >>> # agents should be wrapped into one policy, >>> # which is responsible for calling the acting agent correctly @@ -198,7 +198,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv - from tianshou.policy import ( + from tianshou.algorithm import ( BasePolicy, DQNPolicy, MultiAgentPolicyManager, @@ -206,7 +206,7 @@ So let's start to train our Tic-Tac-Toe agent! 
First, import some required modul ) from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger - from tianshou.utils.net.common import Net + from tianshou.utils.net.common import MLPActor The explanation of each Tianshou class/function will be deferred to their first usages. Here we define some arguments and hyperparameters of the experiment. The meaning of arguments is clear by just looking at their names. :: @@ -224,15 +224,15 @@ The explanation of each Tianshou class/function will be deferred to their first parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=50) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--epoch_num_steps', type=int, default=1000) + parser.add_argument('--collection_step_num_env_steps', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--batch_size', type=int, default=64) parser.add_argument( '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] ) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--num_train_envs', type=int, default=10) + parser.add_argument('--num_test_envs', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.1) parser.add_argument( @@ -284,11 +284,11 @@ The explanation of each Tianshou class/function will be deferred to their first The following ``get_agents`` function returns agents and their optimizers from either constructing a new policy, or loading from disk, or using the pass-in arguments. For the models: -- The action model we use is an instance of :class:`~tianshou.utils.net.common.Net`, essentially a multi-layer perceptron with the ReLU activation function; -- The network model is passed to a :class:`~tianshou.policy.DQNPolicy`, where actions are selected according to both the action mask and their Q-values; -- The opponent can be either a random agent :class:`~tianshou.policy.MARLRandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves. +- The action model we use is an instance of :class:`~tianshou.utils.net.common.MLPActor`, essentially a multi-layer perceptron with the ReLU activation function; +- The network model is passed to a :class:`~tianshou.algorithm.DQNPolicy`, where actions are selected according to both the action mask and their Q-values; +- The opponent can be either a random agent :class:`~tianshou.algorithm.MARLRandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.algorithm.DQNPolicy` allowing learned agents to play with themselves. -Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, which is responsible to call the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.policy.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment. 
+Both agents are passed to :class:`~tianshou.algorithm.MultiAgentPolicyManager`, which is responsible to call the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.algorithm.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment. Here it is: :: @@ -307,7 +307,7 @@ Here it is: args.action_shape = env.action_space.shape or env.action_space.n if agent_learn is None: # model - net = Net( + net = MLPActor( args.state_shape, args.action_shape, hidden_sizes=args.hidden_sizes, @@ -356,8 +356,8 @@ With the above preparation, we are close to the first learned agent. The followi ) -> Tuple[dict, BasePolicy]: # ======== environment setup ========= - train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) - test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -378,7 +378,7 @@ With the above preparation, we are close to the first learned agent. The followi ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # ======== tensorboard logging setup ========= log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') @@ -416,9 +416,9 @@ With the above preparation, we are close to the first learned agent. The followi train_collector, test_collector, args.epoch, - args.step_per_epoch, - args.step_per_collect, - args.test_num, + args.epoch_num_steps, + args.collection_step_num_env_steps, + args.num_test_envs, args.batch_size, train_fn=train_fn, test_fn=test_fn, diff --git a/docs/01_tutorials/07_cheatsheet.rst b/docs/01_tutorials/07_cheatsheet.rst index 51fece131..fc747d66f 100644 --- a/docs/01_tutorials/07_cheatsheet.rst +++ b/docs/01_tutorials/07_cheatsheet.rst @@ -1,6 +1,8 @@ Cheat Sheet =========== +**IMPORTANT**: The content here has not yet been adjusted to the v2 version of Tianshou. It is partially outdated and will be updated soon. + This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios. @@ -23,7 +25,7 @@ See :ref:`build_the_network`. Build New Policy ---------------- -See :class:`~tianshou.policy.BasePolicy`. +See :class:`~tianshou.algorithm.BasePolicy`. .. _eval_policy: @@ -283,12 +285,12 @@ Multi-GPU Training To enable training an RL agent with multiple GPUs for a standard environment (i.e., without nested observation) with default networks provided by Tianshou: 1. Import :class:`~tianshou.utils.net.common.DataParallelNet` from ``tianshou.utils.net.common``; -2. Change the ``device`` argument to ``None`` in the existing networks such as ``Net``, ``Actor``, ``Critic``, ``ActorProb`` +2. Change the ``device`` argument to ``None`` in the existing networks such as ``MLPActor``, ``Actor``, ``Critic``, ``ActorProb`` 3. Apply ``DataParallelNet`` wrapper to these networks. 
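For illustration, here is a minimal sketch of the full multi-GPU pattern described in the three steps above. It is based on the pre-v2 class names and signatures used in the snippet that follows (``Net``/``Actor``/``Critic``/``DataParallelNet``); the exact v2 constructor arguments (e.g. for the renamed ``MLPActor``) may differ:

::

    import torch
    from tianshou.utils.net.common import DataParallelNet, Net
    from tianshou.utils.net.discrete import Actor, Critic

    device = "cuda" if torch.cuda.is_available() else "cpu"
    state_shape, action_shape = (4,), 2  # e.g. CartPole-v1

    # step 2: build the networks with device=None ...
    net = Net(state_shape=state_shape, hidden_sizes=[64, 64], device=None)
    # step 3: ... and wrap actor and critic in DataParallelNet before moving them to the GPU(s)
    actor = DataParallelNet(Actor(net, action_shape, device=None).to(device))
    critic = DataParallelNet(Critic(net, device=None).to(device))
    optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()), lr=3e-4)

The document's own (diffed) snippet follows below.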
:: - from tianshou.utils.net.common import Net, DataParallelNet + from tianshou.utils.net.common import MLPActor, DataParallelNet from tianshou.utils.net.discrete import Actor, Critic actor = DataParallelNet(Actor(net, args.action_shape, device=None).to(args.device)) diff --git a/docs/02_notebooks/0_intro.md b/docs/02_notebooks/0_intro.md index e4d839c2a..e68b36e63 100644 --- a/docs/02_notebooks/0_intro.md +++ b/docs/02_notebooks/0_intro.md @@ -5,3 +5,5 @@ directly in colab, or download them and run them locally. They will guide you step by step to show you how the most basic modules in Tianshou work and how they collaborate with each other to conduct a classic DRL experiment. + +**IMPORTANT**: The notebooks are not yet adjusted to the v2 version of Tianshou! Their content is partly outdated and will be updated soon. diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb deleted file mode 100644 index a9bf617bc..000000000 --- a/docs/02_notebooks/L0_overview.ipynb +++ /dev/null @@ -1,250 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "editable": true, - "id": "r7aE6Rq3cAEE", - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "# Overview\n", - "To begin, ensure you have Tianshou and the Gym environment installed by executing the following commands. This tutorials will always keep up with the latest version of Tianshou since they also serve as a test for the latest version. For users on older versions of Tianshou, please consult the [documentation](https://tianshou.readthedocs.io/en/latest/) corresponding to your version..\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1_mLTSEIcY2c" - }, - "source": [ - "## Run the code" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IcFNmCjYeIIU" - }, - "source": [ - "Below is a short script that use a certain DRL algorithm (PPO) to solve the classic CartPole-v1\n", - "problem in Gym. Simply run it and **don't worry** if you can't understand the code very well. That is\n", - "exactly what this tutorial is for.\n", - "\n", - "If the script ends normally, you will see the evaluation result printed out before the first\n", - "epoch is finished." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PPOPolicy\n", - "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import ActorCritic, Net\n", - "from tianshou.utils.net.discrete import Actor, Critic\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [], - "source": [ - "# environments\n", - "env = gym.make(\"CartPole-v1\")\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])\n", - "\n", - "# model & optimizer\n", - "assert env.observation_space.shape is not None # for mypy\n", - "net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", - "\n", - "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", - "critic = Critic(preprocess_net=net, device=device).to(device)\n", - "actor_critic = ActorCritic(actor, critic)\n", - "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)\n", - "\n", - "# PPO policy\n", - "dist = torch.distributions.Categorical\n", - "policy: PPOPolicy = PPOPolicy(\n", - " actor=actor,\n", - " critic=critic,\n", - " optim=optim,\n", - " dist_fn=dist,\n", - " action_space=env.action_space,\n", - " action_scaling=False,\n", - ")\n", - "\n", - "# collector\n", - "train_collector = Collector[CollectStats](\n", - " policy,\n", - " train_envs,\n", - " VectorReplayBuffer(20000, len(train_envs)),\n", - ")\n", - "test_collector = Collector[CollectStats](policy, test_envs)\n", - "\n", - "# trainer\n", - "train_result = OnpolicyTrainer(\n", - " policy=policy,\n", - " batch_size=256,\n", - " train_collector=train_collector,\n", - " test_collector=test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=50000,\n", - " repeat_per_collect=10,\n", - " episode_per_test=10,\n", - " step_per_collect=2000,\n", - " stop_fn=lambda mean_reward: mean_reward >= 195,\n", - ").run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "train_result.pprint_asdict()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "G9YEQptYvCgx", - "outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a" - }, - "outputs": [], - "source": [ - "# Let's watch its performance!\n", - "policy.eval()\n", - "eval_result = test_collector.collect(n_episode=3, render=False)\n", - "print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xFYlcPo8fpPU" - }, - "source": [ - "## Tutorial Introduction\n", - "\n", - "A common DRL experiment as is shown above may require many components to work together. 
The agent, the\n", - "environment (possibly parallelized ones), the replay buffer and the trainer all work together to complete a\n", - "training task.\n", - "\n", - "
\n", - "\n", - "\n", - "
\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kV_uOyimj-bk" - }, - "source": [ - "In Tianshou, all of these main components are factored out as different building blocks, which you\n", - "can use to create your own algorithm and finish your own experiment.\n", - "\n", - "Building blocks may include:\n", - "- Batch\n", - "- Replay Buffer\n", - "- Vectorized Environment Wrapper\n", - "- Policy (the agent and the training algorithm)\n", - "- Data Collector\n", - "- Trainer\n", - "- Logger\n", - "\n", - "\n", - "These notebooks tutorials will guide you through all the modules one by one." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S0mNKwH9i6Ek" - }, - "source": [ - "## Further reading" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "M3NPSUnAov4L" - }, - "source": [ - "### What if I am not familiar with the PPO algorithm itself?\n", - "As for the DRL algorithms themselves, we will refer you to the [Spinning up documentation](https://spinningup.openai.com/en/latest/algorithms/ppo.html), where they provide\n", - "plenty of resources and guides if you want to study the DRL algorithms. In Tianshou's tutorials, we will\n", - "focus on the usages of different modules, but not the algorithms themselves." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/02_notebooks/L1_Batch.ipynb b/docs/02_notebooks/L1_Batch.ipynb index 4e56c4a1c..d40869287 100644 --- a/docs/02_notebooks/L1_Batch.ipynb +++ b/docs/02_notebooks/L1_Batch.ipynb @@ -31,8 +31,6 @@ }, "outputs": [], "source": [ - "%%capture\n", - "\n", "import pickle\n", "\n", "import numpy as np\n", @@ -401,7 +399,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/docs/02_notebooks/L2_Buffer.ipynb b/docs/02_notebooks/L2_Buffer.ipynb index 892aaff4f..4f51abca5 100644 --- a/docs/02_notebooks/L2_Buffer.ipynb +++ b/docs/02_notebooks/L2_Buffer.ipynb @@ -6,8 +6,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%capture\n", - "\n", "import pickle\n", "\n", "import numpy as np\n", @@ -421,7 +419,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/docs/02_notebooks/L3_Vectorized__Environment.ipynb b/docs/02_notebooks/L3_Vectorized__Environment.ipynb index 374ee2ba8..19e5489a2 100644 --- a/docs/02_notebooks/L3_Vectorized__Environment.ipynb +++ b/docs/02_notebooks/L3_Vectorized__Environment.ipynb @@ -47,8 +47,6 @@ }, "outputs": [], "source": [ - "%%capture\n", - "\n", "import time\n", "\n", "import gymnasium as gym\n", @@ -223,7 +221,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/docs/02_notebooks/L4_GAE.ipynb b/docs/02_notebooks/L4_GAE.ipynb new file mode 100644 index 000000000..8393d6f92 --- /dev/null +++ b/docs/02_notebooks/L4_GAE.ipynb @@ -0,0 +1,265 @@ +{ + "cells": [ + { + "cell_type": "markdown", + 
"metadata": { + "id": "QJ5krjrcbuiA" + }, + "source": [ + "# Notes on Generalized Advantage Estimation\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UPVl5LBEWJ0t" + }, + "source": [ + "## How to compute GAE on your own?\n", + "(Note that for this reading you need to understand the calculation of [GAE](https://arxiv.org/abs/1506.02438) advantage first)\n", + "\n", + "In terms of code implementation, perhaps the most difficult and annoying part is computing GAE advantage. Just now, we use the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all those troubles. However, it is still important that we know the details behind this.\n", + "\n", + "To compute GAE advantage, the usage of `self.compute_episodic_return()` may go like:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "D34GlVvPNz08", + "outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215" + }, + "source": [ + "```python\n", + "batch, indices = dummy_buffer.sample(0) # 0 means sampling all the data from the buffer\n", + "returns, advantage = Algorithm.compute_episodic_return(\n", + " batch=batch,\n", + " buffer=dummy_buffer,\n", + " indices=indices,\n", + " v_s_=np.zeros(10),\n", + " v_s=np.zeros(10),\n", + " gamma=1.0,\n", + " gae_lambda=1.0,\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the code above, we sample all the 10 data in the buffer and try to compute the GAE advantage. However, the way the returns are computed here might be a bit misleading. In fact, the last episode is unfinished, but its last step saved in the batch is treated as a terminal state, since it assumes that there are no future rewards. The episode is not terminated yet, it is truncated, so the agent could still get rewards in the future. Terminated and truncated episodes should indeed be treated differently.\n", + "The return of a step is the (discounted) sum of the future rewards from that step until the end of the episode. \n", + "\\begin{equation}\n", + "R_{t}=\\sum_{t}^{T} \\gamma^{t} r_{t}\n", + "\\end{equation}\n", + "Thus, at the last step of a terminated episode the return is equal to the reward at that state, since there are no future states.\n", + "\\begin{equation}\n", + "R_{T,terminated}=r_{T}\n", + "\\end{equation}\n", + "\n", + "However, if the episode was truncated the return at the last step is usually better represented by the estimated value of that state, which is the expected return from that state onwards.\n", + "\\begin{align*}\n", + "R_{T,truncated}=V^{\\pi}\\left(s_{T}\\right) \\quad & \\text{or} \\quad R_{T,truncated}=Q^{\\pi}(s_{T},a_{T})\n", + "\\end{align*}\n", + "Moreover, if the next state was also observed (but not its reward), then an even better estimate would be the reward of the last step plus the discounted value of the next state.\n", + "\\begin{align*}\n", + "R_{T,truncated}=r_T+\\gamma V^{\\pi}\\left(s_{T+1}\\right)\n", + "\\end{align*}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h_5Dt6XwQLXV" + }, + "source": [ + "\n", + "As we know, we need to estimate the value function of every observation to compute GAE advantage. So in `v_s` is the value of `batch.obs`, and in `v_s_` is the value of `batch.obs_next`. 
This is usually computed by:\n", + "\n", + "`v_s = critic(batch.obs)`,\n", + "\n", + "`v_s_ = critic(batch.obs_next)`,\n", + "\n", + "where both `v_s` and `v_s_` are 10 dimensional arrays and `critic` is usually a neural network.\n", + "\n", + "After we've got all those values, GAE can be computed following the equation below." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ooHNIICGUO19" + }, + "source": [ + "\\begin{aligned}\n", + "\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)}: =& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n", + "\\end{aligned}\n", + "\n", + "where\n", + "\n", + "\\begin{equation}\n", + "\\delta_{t}^{V} \\quad=-V\\left(s_{t}\\right)+r_{t}+\\gamma V\\left(s_{t+1}\\right)\n", + "\\end{equation}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eV6XZaouU7EV" + }, + "source": [ + "Unfortunately, if you follow this equation, which is taken from the paper, you probably will get a slightly lower performance than you expected. There are at least 3 \"bugs\" in this equation." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FCxD9gNNVYbd" + }, + "source": [ + "**First** is that Gym always returns you a `obs_next` even if this is already the last step. The value of this timestep is exactly 0 and you should not let the neural network estimate it." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rNZNUNgQVvRJ", + "outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d" + }, + "source": [ + "```python\n", + "# Assume v_s_ is got by calling critic(batch.obs_next)\n", + "v_s_ = np.ones(10)\n", + "v_s_ *= ~batch.done\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2EtMi18QWXTN" + }, + "source": [ + "After the fix above, we will perhaps get a more accurate estimate.\n", + "\n", + "**Secondly**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we meet a `done` flag. However, in the buffer above, the last (10th) step is not marked by done=True, because the collecting has not finished. We must know all those unfinished steps so that we know when to stop bootstrapping.\n", + "\n", + "Luckily, this can be done under the assistance of buffer because buffers in Tianshou not only store data, but also help you manage data trajectories." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "saluvX4JU6bC", + "outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5" + }, + "source": [ + "```python\n", + "unfinished_indexes = dummy_buffer.unfinished_index()\n", + "done_indexes = np.where(batch.done)[0]\n", + "stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qp6vVE4dYWv1" + }, + "source": [ + "**Thirdly**, there are some special indexes which are marked by done flag, however its value for obs_next should not be zero. It is again because done does not differentiate between terminated and truncated. These steps are usually those at the last step of an episode, but this episode stops not because the agent can no longer get any rewards (value=0), but because the episode is too long so we have to truncate it. These kind of steps are always marked with `info['TimeLimit.truncated']=True` in Gym." 
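Putting these pieces together, the following is a small self-contained NumPy sketch of a truncation-aware GAE computation. It is illustrative only and is not Tianshou's implementation: the rewards, flags and value estimates are made up, and for simplicity the cut-off second episode is marked as truncated rather than left unfinished in a buffer.

```python
import numpy as np

# 10 toy transitions: episode 1 terminates at index 2, episode 2 is cut off (truncated) at index 9
rew = np.ones(10)
terminated = np.zeros(10, dtype=bool)
terminated[2] = True
truncated = np.zeros(10, dtype=bool)
truncated[9] = True
done = terminated | truncated

v_s = np.zeros(10)   # critic estimates for batch.obs
v_s_ = np.ones(10)   # critic estimates for batch.obs_next
gamma, gae_lambda = 0.99, 0.95

# only a *terminated* next state is worth 0; truncated steps keep their bootstrap value
v_s_ = v_s_ * ~terminated

# delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
delta = rew + gamma * v_s_ - v_s

# accumulate (gamma * lambda)^l * delta backwards, restarting at every episode boundary
advantage = np.zeros(10)
gae = 0.0
for t in reversed(range(10)):
    gae = delta[t] + gamma * gae_lambda * gae * (1.0 - done[t])
    advantage[t] = gae

returns = advantage + v_s
print(advantage)
print(returns)
```

In practice, `Algorithm.compute_episodic_return()` (together with `Algorithm.value_mask()`) handles these boundary cases for you, as the summary below explains.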
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tWkqXRJfZTvV" + }, + "source": [ + "As a result, we need to rewrite the equation above\n", + "\n", + "`v_s_ *= ~batch.done`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kms-QtxKZe-M" + }, + "source": [ + "to\n", + "\n", + "```\n", + "mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n", + "v_s_ *= mask\n", + "\n", + "```\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u_aPPoKraBu6" + }, + "source": [ + "## Summary\n", + "If you already felt bored by now, simply remember that Tianshou can help handle all these little details so that you can focus on the algorithm itself. Just call `Algorithm.compute_episodic_return()`.\n", + "\n", + "If you still feel interested, we would recommend you check Appendix C in this [paper](https://arxiv.org/abs/2107.14171v2) and implementation of `Algorithm.value_mask()` and `Algorithm.compute_episodic_return()` for details." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2cPnUXRBWKD9" + }, + "source": [ + "
\n", + "\n", + "
\n", + "
\n", + "\n", + "
" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb deleted file mode 100644 index eed8ea344..000000000 --- a/docs/02_notebooks/L4_Policy.ipynb +++ /dev/null @@ -1,1009 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "PNM9wqstBSY_" - }, - "source": [ - "# Policy\n", - "In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the agent part. In Tianshou, both the agent and the core DRL algorithm are implemented in the Policy module. Tianshou provides more than 20 Policy modules, each representing one DRL algorithm. See supported algorithms [here](https://tianshou.readthedocs.io/en/master/03_api/policy/index.html).\n", - "\n", - "
\n", - "\n", - "\n", - " The agents interacting with the environment \n", - "
\n", - "\n", - "All Policy modules inherit from a BasePolicy Class and share the same interface." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZqdHYdoJJS51" - }, - "source": [ - "## Creating your own Policy\n", - "We will use the simple PGPolicy, also called REINFORCE algorithm Policy, to show the implementation of a Policy Module. The Policy we implement here will be a scaled-down version of [PGPolicy](https://tianshou.readthedocs.io/en/master/03_api/policy/modelfree/pg.html) in Tianshou." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "cDlSjASbJmy-", - "slideshow": { - "slide_type": "" - }, - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "from typing import Any, cast\n", - "\n", - "import gymnasium as gym\n", - "import numpy as np\n", - "import torch\n", - "\n", - "from tianshou.data import (\n", - " Batch,\n", - " ReplayBuffer,\n", - " SequenceSummaryStats,\n", - " to_torch,\n", - " to_torch_as,\n", - ")\n", - "from tianshou.data.batch import BatchProtocol\n", - "from tianshou.data.types import (\n", - " BatchWithReturnsProtocol,\n", - " DistBatchProtocol,\n", - " ObsBatchProtocol,\n", - " RolloutBatchProtocol,\n", - ")\n", - "from tianshou.policy import BasePolicy\n", - "from tianshou.policy.modelfree.pg import (\n", - " PGTrainingStats,\n", - " TDistFnDiscrOrCont,\n", - " TPGTrainingStats,\n", - ")\n", - "from tianshou.utils import RunningMeanStd\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor\n", - "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Protocols\n", - "Note: as we learned in tutorial [L1_batch](https://tianshou.readthedocs.io/en/master/02_notebooks/L1_Batch.html#), Tianshou uses `Batch` to store data. `Batch` is a dataclass that can store any data you want. In order to have more control about what kind of batch data is expected and produced in each processing step we use protocols. \n", - "For example, `BatchWithReturnsProtocol` specifies that the batch should have fields `obs`, `act`, `rew`, `done`, `obs_next`, `info` and `returns`. This is not only for type checking, but also for IDE support.\n", - "To learn more about protocols, please refer to the official documentation ([PEP 544](https://www.python.org/dev/peps/pep-0544/)) or to mypy documentation ([Protocols](https://mypy.readthedocs.io/en/stable/protocols.html)).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Initialization\n", - "Firstly we create the `PGPolicy` by inheriting from `BasePolicy` in Tianshou." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "class PGPolicy(BasePolicy):\n", - " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", - "\n", - " def __init__(self) -> None:\n", - " super().__init__(\n", - " action_space=action_space,\n", - " observation_space=observation_space,\n", - " )\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qc1RnIBbLCDN" - }, - "source": [ - "The Policy Module mainly does two things:\n", - "\n", - "1. `policy.forward()` receives observation and other information (stored in a Batch) from the environment and returns a new Batch containing the next action and other information.\n", - "2. 
`policy.update()` receives training data sampled from the replay buffer and updates the policy network. It returns a dataclass containing logging details.\n", - "\n", - "
\n", - "\n", - "\n", - " policy.forward() and policy.update() \n", - "
\n", - "\n", - "We also need to take care of the following things:\n", - "\n", - "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network and a Torch optimizer in our Policy Module.\n", - "2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to \n", - "preprocess training data and computes quantities like episodic returns (gradient free), \n", - "then it will call `Policy.learn()` to perform the back-propagation.\n", - "3. Each Policy is accompanied by a dedicated implementation of `TrainingStats` , which store details of each training step.\n", - "\n", - "This is how we get the implementation below." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "class PGPolicy(BasePolicy[TPGTrainingStats]):\n", - " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", - "\n", - " def __init__(\n", - " self, \n", - " actor: torch.nn.Module, \n", - " optim: torch.optim.Optimizer, \n", - " action_space: gym.Space\n", - " ):\n", - " super().__init__(\n", - " action_space=action_space,\n", - " observation_space=observation_space\n", - " )\n", - " self.actor = model\n", - " self.optim = optim\n", - "\n", - " def process_fn(\n", - " self, \n", - " batch: RolloutBatchProtocol, \n", - " buffer: ReplayBuffer, \n", - " indices: np.ndarray\n", - " ) -> BatchWithReturnsProtocol:\n", - " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", - " return batch\n", - "\n", - " def forward(\n", - " self, \n", - " batch: ObsBatchProtocol,\n", - " state: dict | BatchProtocol | np.ndarray | None = None,\n", - " **kwargs: Any\n", - " ) -> DistBatchProtocol:\n", - " \"\"\"Compute action over the given batch data.\"\"\"\n", - " act = None\n", - " return Batch(act=act)\n", - "\n", - " def learn(\n", - " self,\n", - " batch: BatchWithReturnsProtocol, \n", - " batch_size: int | None, \n", - " repeat: int,\n", - " *args: Any,\n", - " **kwargs: Any,\n", - " ) -> TPGTrainingStats:\n", - " \"\"\"Perform the back-propagation.\"\"\"\n", - " return PGTrainingStats(loss=loss_summary_stat)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tjtqjt8WRY5e" - }, - "source": [ - "### Policy.forward()\n", - "According to the equation of REINFORCE algorithm in Spinning Up's [documentation](https://spinningup.openai.com/en/latest/algorithms/vpg.html), we need to map the observation to an action distribution in action space using the neural network (`self.actor`).\n", - "\n", - "
\n", - "\n", - "
\n", - "\n", - "Let us suppose the action space is discrete, and the distribution is a simple categorical distribution." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "def forward(\n", - " self,\n", - " batch: ObsBatchProtocol,\n", - " state: dict | BatchProtocol | np.ndarray | None = None,\n", - " **kwargs: Any,\n", - ") -> DistBatchProtocol:\n", - " \"\"\"Compute action over the given batch data.\"\"\"\n", - " logits, hidden = self.actor(batch.obs, state=state)\n", - " dist = self.dist_fn(logits)\n", - " act = dist.sample()\n", - " result = Batch(logits=logits, act=act, state=hidden, dist=dist)\n", - " return result\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CultfOeuTx2V" - }, - "source": [ - "### Policy.process_fn()\n", - "Now that we have defined our actor, if given training data we can set up a loss function and optimize our neural network. However, before that, we must first calculate episodic returns for every step in our training data to construct the REINFORCE loss function.\n", - "\n", - "Calculating episodic return is not hard, given `ReplayBuffer.next()` allows us to access every reward to go in an episode. A more convenient way would be to simply use the built-in method `BasePolicy.compute_episodic_return()` inherited from BasePolicy.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "def process_fn(\n", - " self,\n", - " batch: RolloutBatchProtocol,\n", - " buffer: ReplayBuffer,\n", - " indices: np.ndarray,\n", - ") -> BatchWithReturnsProtocol:\n", - " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", - " v_s_ = np.full(indices.shape, self.ret_rms.mean)\n", - " returns, _ = self.compute_episodic_return(batch, buffer, indices, v_s_=v_s_, gamma=0.99, gae_lambda=1.0)\n", - " batch.returns = returns\n", - " return batch\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XA8OF4GnWWr5" - }, - "source": [ - "`BasePolicy.compute_episodic_return()` could also be used to compute [GAE](https://arxiv.org/abs/1506.02438). Another similar method is `BasePolicy.compute_nstep_return()`. Check the [source code](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L304) for more details." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7UsdzNaOXPpC" - }, - "source": [ - "### Policy.learn()\n", - "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finally,\n", - "we can construct our loss function and perform the back-propagation. 
The method \n", - "should look something like this:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "def learn(\n", - " self,\n", - " batch: BatchWithReturnsProtocol,\n", - " batch_size: int | None,\n", - " repeat: int,\n", - " *args: Any,\n", - " **kwargs: Any,\n", - ") -> TPGTrainingStats:\n", - " \"\"\"Perform the back-propagation.\"\"\"\n", - " losses = []\n", - " split_batch_size = batch_size or -1\n", - " for _ in range(repeat):\n", - " for minibatch in batch.split(split_batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", - " result = self(minibatch)\n", - " dist = result.dist\n", - " act = to_torch_as(minibatch.act, result.act)\n", - " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", - " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", - " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " losses.append(loss.item())\n", - " loss_summary_stat = SequenceSummaryStats.from_sequence(losses)\n", - "\n", - " return PGTrainingStats(loss=loss_summary_stat)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1BtuV2W0YJTi" - }, - "source": [ - "## Implementation\n", - "Now we can assemble the methods and form a PGPolicy. The outputs of\n", - "`learn` will be collected to a dedicated dataclass." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class PGPolicy(BasePolicy[TPGTrainingStats]):\n", - " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " *,\n", - " actor: torch.nn.Module,\n", - " optim: torch.optim.Optimizer,\n", - " dist_fn: TDistFnDiscrOrCont,\n", - " action_space: gym.Space,\n", - " discount_factor: float = 0.99,\n", - " observation_space: gym.Space | None = None,\n", - " ) -> None:\n", - " super().__init__(\n", - " action_space=action_space,\n", - " observation_space=observation_space,\n", - " )\n", - " self.actor = actor\n", - " self.optim = optim\n", - " self.dist_fn = dist_fn\n", - " assert 0.0 <= discount_factor <= 1.0, \"discount factor should be in [0, 1]\"\n", - " self.gamma = discount_factor\n", - " self.ret_rms = RunningMeanStd()\n", - "\n", - " def process_fn(\n", - " self,\n", - " batch: RolloutBatchProtocol,\n", - " buffer: ReplayBuffer,\n", - " indices: np.ndarray,\n", - " ) -> BatchWithReturnsProtocol:\n", - " \"\"\"Compute the discounted returns (Monte Carlo estimates) for each transition.\n", - "\n", - " They are added to the batch under the field `returns`.\n", - " Note: this function will modify the input batch!\n", - " \"\"\"\n", - " v_s_ = np.full(indices.shape, self.ret_rms.mean)\n", - " # use a function inherited from BasePolicy to compute returns\n", - " # gae_lambda = 1.0 means we use Monte Carlo estimate\n", - " batch.returns, _ = self.compute_episodic_return(\n", - " batch,\n", - " buffer,\n", - " indices,\n", - " v_s_=v_s_,\n", - " gamma=self.gamma,\n", - " gae_lambda=1.0,\n", - " )\n", - " batch: BatchWithReturnsProtocol\n", - " return batch\n", - "\n", - " def forward(\n", - " self,\n", - " batch: ObsBatchProtocol,\n", - " state: dict | BatchProtocol | np.ndarray | None = None,\n", - " **kwargs: Any,\n", - " ) -> DistBatchProtocol:\n", - " \"\"\"Compute action over the given batch data by applying the actor.\n", - "\n", - " Will sample from the dist_fn, if appropriate.\n", - " Returns a new object representing the processed batch data\n", - " (contrary to other methods that 
modify the input batch inplace).\n", - " \"\"\"\n", - " logits, hidden = self.actor(batch.obs, state=state)\n", - "\n", - " if isinstance(logits, tuple):\n", - " dist = self.dist_fn(*logits)\n", - " else:\n", - " dist = self.dist_fn(logits)\n", - "\n", - " act = dist.sample()\n", - " return cast(DistBatchProtocol, Batch(logits=logits, act=act, state=hidden, dist=dist))\n", - "\n", - " def learn( # type: ignore\n", - " self,\n", - " batch: BatchWithReturnsProtocol,\n", - " batch_size: int | None,\n", - " repeat: int,\n", - " *args: Any,\n", - " **kwargs: Any,\n", - " ) -> TPGTrainingStats:\n", - " losses = []\n", - " split_batch_size = batch_size or -1\n", - " for _ in range(repeat):\n", - " for minibatch in batch.split(split_batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", - " result = self(minibatch)\n", - " dist = result.dist\n", - " act = to_torch_as(minibatch.act, result.act)\n", - " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", - " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", - " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " losses.append(loss.item())\n", - "\n", - " loss_summary_stat = SequenceSummaryStats.from_sequence(losses)\n", - "\n", - " return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xlPAbh0lKti8" - }, - "source": [ - "## Use the policy\n", - "Note that `BasePolicy` itself inherits from `torch.nn.Module`. As a result, you can consider all Policy modules as a Torch Module. They share similar APIs.\n", - "\n", - "Firstly we will initialize a new PGPolicy." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JkLFA9Z1KjuX" - }, - "outputs": [], - "source": [ - "state_shape = 4\n", - "action_shape = 2\n", - "# Usually taken from an env by using env.action_space\n", - "action_space = gym.spaces.Box(low=-1, high=1, shape=(2,))\n", - "net = Net(state_shape, hidden_sizes=[16, 16], device=\"cpu\")\n", - "actor = Actor(net, action_shape, device=\"cpu\").to(\"cpu\")\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", - "dist_fn = torch.distributions.Categorical\n", - "\n", - "policy: BasePolicy\n", - "policy = PGPolicy(actor=actor, optim=optim, dist_fn=dist_fn, action_space=action_space)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LAo_0t2fekUD" - }, - "source": [ - "PGPolicy shares same APIs with the Torch Module." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "UiuTc8RhJiEi", - "outputId": "9b5bc54c-6303-45f3-ba81-2216a44931e8" - }, - "outputs": [], - "source": [ - "print(policy)\n", - "print(\"========================================\")\n", - "for param in policy.parameters():\n", - " print(param.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-RCrsttYgAG-" - }, - "source": [ - "### Making decision\n", - "Given a batch of observations, the policy can return a batch of actions and other data." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0jkBb6AAgUla", - "outputId": "37948844-cdd8-4567-9481-89453c80a157" - }, - "outputs": [], - "source": [ - "obs_batch = Batch(obs=np.ones(shape=(256, 4)))\n", - "dist_batch = policy(obs_batch) # forward() method is called\n", - "print(\"Next action for each observation: \\n\", dist_batch.act)\n", - "print(\"Dsitribution: \\n\", dist_batch.dist)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "swikhnuDfKep" - }, - "source": [ - "### Save and Load models\n", - "Naturally, Tianshou Policy can be saved and loaded like a normal Torch Network." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tYOoWM_OJRnA" - }, - "outputs": [], - "source": [ - "torch.save(policy.state_dict(), \"policy.pth\")\n", - "assert policy.load_state_dict(torch.load(\"policy.pth\"))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gp8PzOYsg5z-" - }, - "source": [ - "### Algorithm Updating\n", - "We have to collect some data and save them in the ReplayBuffer before updating our agent(policy). Typically we use collector to collect data, but we leave this part till later when we have learned the Collector in Tianshou. For now we generate some **fake** data." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrrPxOUAYShR" - }, - "source": [ - "#### Generating fake data\n", - "Firstly, we need to \"pretend\" that we are using the \"Policy\" to collect data. We plan to collect 10 data so that we can update our algorithm." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a14CmzSfYh5C", - "outputId": "aaf45a1f-5e21-4bc8-cbe3-8ce798258af0" - }, - "outputs": [], - "source": [ - "dummy_buffer = ReplayBuffer(size=10)\n", - "print(dummy_buffer)\n", - "print(f\"maxsize: {dummy_buffer.maxsize}, data length: {len(dummy_buffer)}\")\n", - "env = gym.make(\"CartPole-v1\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8S94cV7yZITR" - }, - "source": [ - "Now we are pretending to collect the first episode. The first episode ends at step 3 (perhaps because we are performing too badly)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "a_mtvbmBZbfs" - }, - "outputs": [], - "source": [ - "obs, info = env.reset()\n", - "for i in range(3):\n", - " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", - " obs_next, rew, _, truncated, info = env.step(act)\n", - " # pretend ending at step 3\n", - " terminated = i == 2\n", - " info[\"id\"] = i\n", - " dummy_buffer.add(\n", - " Batch(\n", - " obs=obs,\n", - " act=act,\n", - " rew=rew,\n", - " terminated=terminated,\n", - " truncated=truncated,\n", - " obs_next=obs_next,\n", - " info=info,\n", - " ),\n", - " )\n", - " obs = obs_next" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(dummy_buffer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pkxq4gu9bGkt" - }, - "source": [ - "Now we are pretending to collect the second episode. At step 7 the second episode still doesn't end, but we are unwilling to wait, so we stop collecting to update the algorithm." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pAoKe02ybG68" - }, - "outputs": [], - "source": [ - "obs, info = env.reset()\n", - "for i in range(3, 10):\n", - " # For retrieving actions to be used for training, we set the policy to training mode,\n", - " # but the wrapped torch module should be in eval mode.\n", - " with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n", - " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", - " obs_next, rew, _, truncated, info = env.step(act)\n", - " # pretend this episode never end\n", - " terminated = False\n", - " info[\"id\"] = i\n", - " dummy_buffer.add(\n", - " Batch(\n", - " obs=obs,\n", - " act=act,\n", - " rew=rew,\n", - " terminated=terminated,\n", - " truncated=truncated,\n", - " obs_next=obs_next,\n", - " info=info,\n", - " ),\n", - " )\n", - " obs = obs_next" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MKM6aWMucv-M" - }, - "source": [ - "Our replay buffer looks like this now." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "CSJEEWOqXdTU", - "outputId": "2b3bb75c-f219-4e82-ca78-0ea6173a91f9" - }, - "outputs": [], - "source": [ - "print(dummy_buffer)\n", - "print(f\"maxsize: {dummy_buffer.maxsize}, data length: {len(dummy_buffer)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "55VWhWpkdfEb" - }, - "source": [ - "#### Updates\n", - "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train.\n", - "\n", - "However, we need to manually set the torch module to training mode prior to that, \n", - "and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n", - "but users need to consider it when calling `.update` outside of the trainer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "i_O1lJDWdeoc", - "outputId": "b154741a-d6dc-46cb-898f-6e84fa14e5a7" - }, - "outputs": [], - "source": [ - "# 0 means sample all data from the buffer\n", - "\n", - "# For updating the policy, the policy should be in training mode\n", - "# and the wrapped torch module should also be in training mode (unlike when collecting data).\n", - "with policy_within_training_step(policy), torch_train_mode(policy):\n", - " policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QJ5krjrcbuiA" - }, - "source": [ - "## Further Reading\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pmWi3HuXWcV8" - }, - "source": [ - "### Pre-defined Networks\n", - "Tianshou provides numerous pre-defined networks usually used in DRL so that you don't have to bother yourself. Check this [documentation](https://tianshou.readthedocs.io/en/master/03_api/utils/net/index.html) for details." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UPVl5LBEWJ0t" - }, - "source": [ - "### How to compute GAE on your own?\n", - "(Note that for this reading you need to understand the calculation of [GAE](https://arxiv.org/abs/1506.02438) advantage first)\n", - "\n", - "In terms of code implementation, perhaps the most difficult and annoying part is computing GAE advantage. 
Just now, we use the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all those troubles. However, it is still important that we know the details behind this.\n", - "\n", - "To compute GAE advantage, the usage of `self.compute_episodic_return()` may go like:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "D34GlVvPNz08", - "outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215" - }, - "outputs": [], - "source": [ - "batch, indices = dummy_buffer.sample(0) # 0 means sampling all the data from the buffer\n", - "returns, advantage = BasePolicy.compute_episodic_return(\n", - " batch=batch,\n", - " buffer=dummy_buffer,\n", - " indices=indices,\n", - " v_s_=np.zeros(10),\n", - " v_s=np.zeros(10),\n", - " gamma=1.0,\n", - " gae_lambda=1.0,\n", - ")\n", - "print(f\"{batch.rew=}\")\n", - "print(f\"{batch.done=}\")\n", - "print(f\"{returns=}\")\n", - "print(f\"{advantage=}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In the code above, we sample all the 10 data in the buffer and try to compute the GAE advantage. However, the way the returns are computed here might be a bit misleading. In fact, the last episode is unfinished, but its last step saved in the batch is treated as a terminal state, since it assumes that there are no future rewards. The episode is not terminated yet, it is truncated, so the agent could still get rewards in the future. Terminated and truncated episodes should indeed be treated differently.\n", - "The return of a step is the (discounted) sum of the future rewards from that step until the end of the episode. \n", - "\\begin{equation}\n", - "R_{t}=\\sum_{t}^{T} \\gamma^{t} r_{t}\n", - "\\end{equation}\n", - "Thus, at the last step of a terminated episode the return is equal to the reward at that state, since there are no future states.\n", - "\\begin{equation}\n", - "R_{T,terminated}=r_{T}\n", - "\\end{equation}\n", - "\n", - "However, if the episode was truncated the return at the last step is usually better represented by the estimated value of that state, which is the expected return from that state onwards.\n", - "\\begin{align*}\n", - "R_{T,truncated}=V^{\\pi}\\left(s_{T}\\right) \\quad & \\text{or} \\quad R_{T,truncated}=Q^{\\pi}(s_{T},a_{T})\n", - "\\end{align*}\n", - "Moreover, if the next state was also observed (but not its reward), then an even better estimate would be the reward of the last step plus the discounted value of the next state.\n", - "\\begin{align*}\n", - "R_{T,truncated}=r_T+\\gamma V^{\\pi}\\left(s_{T+1}\\right)\n", - "\\end{align*}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "h_5Dt6XwQLXV" - }, - "source": [ - "\n", - "As we know, we need to estimate the value function of every observation to compute GAE advantage. So in `v_s` is the value of `batch.obs`, and in `v_s_` is the value of `batch.obs_next`. This is usually computed by:\n", - "\n", - "`v_s = critic(batch.obs)`,\n", - "\n", - "`v_s_ = critic(batch.obs_next)`,\n", - "\n", - "where both `v_s` and `v_s_` are 10 dimensional arrays and `critic` is usually a neural network.\n", - "\n", - "After we've got all those values, GAE can be computed following the equation below." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ooHNIICGUO19" - }, - "source": [ - "\\begin{aligned}\n", - "\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)}: =& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n", - "\\end{aligned}\n", - "\n", - "where\n", - "\n", - "\\begin{equation}\n", - "\\delta_{t}^{V} \\quad=-V\\left(s_{t}\\right)+r_{t}+\\gamma V\\left(s_{t+1}\\right)\n", - "\\end{equation}\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eV6XZaouU7EV" - }, - "source": [ - "Unfortunately, if you follow this equation, which is taken from the paper, you probably will get a slightly lower performance than you expected. There are at least 3 \"bugs\" in this equation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FCxD9gNNVYbd" - }, - "source": [ - "**First** is that Gym always returns you a `obs_next` even if this is already the last step. The value of this timestep is exactly 0 and you should not let the neural network estimate it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "rNZNUNgQVvRJ", - "outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d" - }, - "outputs": [], - "source": [ - "# Assume v_s_ is got by calling critic(batch.obs_next)\n", - "v_s_ = np.ones(10)\n", - "v_s_ *= ~batch.done\n", - "print(f\"{v_s_=}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2EtMi18QWXTN" - }, - "source": [ - "After the fix above, we will perhaps get a more accurate estimate.\n", - "\n", - "**Secondly**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we meet a `done` flag. However, in the buffer above, the last (10th) step is not marked by done=True, because the collecting has not finished. We must know all those unfinished steps so that we know when to stop bootstrapping.\n", - "\n", - "Luckily, this can be done under the assistance of buffer because buffers in Tianshou not only store data, but also help you manage data trajectories." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "saluvX4JU6bC", - "outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5" - }, - "outputs": [], - "source": [ - "unfinished_indexes = dummy_buffer.unfinished_index()\n", - "print(\"unfinished_indexes: \", unfinished_indexes)\n", - "done_indexes = np.where(batch.done)[0]\n", - "print(\"done_indexes: \", done_indexes)\n", - "stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n", - "print(\"stop_bootstrap_ids: \", stop_bootstrap_ids)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qp6vVE4dYWv1" - }, - "source": [ - "**Thirdly**, there are some special indexes which are marked by done flag, however its value for obs_next should not be zero. It is again because done does not differentiate between terminated and truncated. These steps are usually those at the last step of an episode, but this episode stops not because the agent can no longer get any rewards (value=0), but because the episode is too long so we have to truncate it. These kind of steps are always marked with `info['TimeLimit.truncated']=True` in Gym." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tWkqXRJfZTvV" - }, - "source": [ - "As a result, we need to rewrite the equation above\n", - "\n", - "`v_s_ *= ~batch.done`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kms-QtxKZe-M" - }, - "source": [ - "to\n", - "\n", - "```\n", - "mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n", - "v_s_ *= mask\n", - "\n", - "```\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u_aPPoKraBu6" - }, - "source": [ - "### Summary\n", - "If you already felt bored by now, simply remember that Tianshou can help handle all these little details so that you can focus on the algorithm itself. Just call `BasePolicy.compute_episodic_return()`.\n", - "\n", - "If you still feel interested, we would recommend you check Appendix C in this [paper](https://arxiv.org/abs/2107.14171v2) and implementation of `BasePolicy.value_mask()` and `BasePolicy.compute_episodic_return()` for details." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2cPnUXRBWKD9" - }, - "source": [ - "
\n", - "\n", - "
\n", - "
\n", - "\n", - "
" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index d10df1666..a52dd25eb 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -8,10 +8,6 @@ "source": [ "# Collector\n", "From its literal meaning, we can easily know that the Collector in Tianshou is used to collect training data. More specifically, the Collector controls the interaction between Policy (agent) and the environment. It also helps save the interaction data into the ReplayBuffer and returns episode statistics.\n", - "\n", - "
\n", - "\n", - "
\n", "\n" ] }, @@ -53,16 +49,10 @@ }, "outputs": [], "source": [ - "%%capture\n", - "\n", "import gymnasium as gym\n", "import torch\n", "\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PGPolicy\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" + "from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy" ] }, { @@ -71,25 +61,28 @@ "metadata": {}, "outputs": [], "source": [ + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.utils.net.common import Net\n", + "from tianshou.utils.net.discrete import DiscreteActor\n", + "\n", "env = gym.make(\"CartPole-v1\")\n", "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", "\n", "# model\n", "assert env.observation_space.shape is not None # for mypy\n", - "net = Net(\n", - " env.observation_space.shape,\n", + "preprocess_net = Net(\n", + " state_shape=env.observation_space.shape,\n", " hidden_sizes=[\n", " 16,\n", " ],\n", ")\n", "\n", "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "actor = Actor(net, env.action_space.n)\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", + "actor = DiscreteActor(preprocess_net=preprocess_net, action_shape=env.action_space.n)\n", "\n", - "policy: PGPolicy = PGPolicy(\n", + "policy = ProbabilisticActorPolicy(\n", " actor=actor,\n", - " optim=optim,\n", " dist_fn=torch.distributions.Categorical,\n", " action_space=env.action_space,\n", " action_scaling=False,\n", @@ -270,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb deleted file mode 100644 index c10cfcbe2..000000000 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ /dev/null @@ -1,283 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "S3-tJZy35Ck_" - }, - "source": [ - "# Trainer\n", - "Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the media.\n", - "\n", - "
\n", - "\n", - "
\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ifsEQMzZ6mmz" - }, - "source": [ - "## Usages\n", - "In Tianshou v0.5.1, there are three types of Trainer. They are designed to be used in on-policy training, off-policy training and offline training respectively. We will use on-policy trainer as an example and leave the other two for further reading." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XfsuU2AAE52C" - }, - "source": [ - "### Pseudocode\n", - "
\n", - "\n", - "
\n", - "\n", - "For the on-policy trainer, the main difference is that we clear the buffer after Line 10." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Hcp_o0CCFz12" - }, - "source": [ - "### Training without trainer\n", - "As we have learned the usages of the Collector and the Policy, it's possible that we write our own training logic.\n", - "\n", - "First, let us create the instances of Environment, ReplayBuffer, Policy and Collector." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "do-xZ-8B7nVH", - "slideshow": { - "slide_type": "" - }, - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PGPolicy\n", - "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor\n", - "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_env_num = 4\n", - "# Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n", - "buffer_size = 2000\n", - "\n", - "\n", - "# Create the environments, used for training and evaluation\n", - "env = gym.make(\"CartPole-v1\")\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n", - "\n", - "# Create the Policy instance\n", - "assert env.observation_space.shape is not None\n", - "net = Net(\n", - " env.observation_space.shape,\n", - " hidden_sizes=[\n", - " 16,\n", - " ],\n", - ")\n", - "\n", - "assert isinstance(env.action_space, gym.spaces.Discrete)\n", - "actor = Actor(net, env.action_space.n)\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n", - "\n", - "# We choose to use REINFORCE algorithm, also known as Policy Gradient\n", - "policy: PGPolicy = PGPolicy(\n", - " actor=actor,\n", - " optim=optim,\n", - " dist_fn=torch.distributions.Categorical,\n", - " action_space=env.action_space,\n", - " action_scaling=False,\n", - ")\n", - "\n", - "# Create the replay buffer and the collector\n", - "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", - "test_collector = Collector[CollectStats](policy, test_envs)\n", - "train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wiEGiBgQIiFM" - }, - "source": [ - "Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMUNPN5SI_kd", - "outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d" - }, - "outputs": [], - "source": [ - "train_collector.reset()\n", - "train_envs.reset()\n", - "test_collector.reset()\n", - "test_envs.reset()\n", - "replayBuffer.reset()\n", - "\n", - "n_episode = 10\n", - "for _i in range(n_episode):\n", - " # for test collector, we set the wrapped torch module to evaluation mode\n", - " # by default, the policy object itself is not within the training step\n", - " with torch_train_mode(policy, enabled=False):\n", - " evaluation_result = test_collector.collect(n_episode=n_episode)\n", - " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", - " # for collecting data for training, the policy object should be within the training step\n", - " # (affecting e.g. whether the policy is stochastic or deterministic)\n", - " with policy_within_training_step(policy):\n", - " train_collector.collect(n_step=2000)\n", - " # 0 means taking all data stored in train_collector.buffer\n", - " # for updating the policy, the wrapped torch module should be in training mode\n", - " with torch_train_mode(policy):\n", - " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", - " train_collector.reset_buffer(keep_statistics=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QXBHIBckMs_2" - }, - "source": [ - "The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and REINFORCE algorithm is actually not very stable. Don't worry, we will solve this problem in the end. Still we get some idea on how to start a training loop." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p-7U_cwgF5Ej" - }, - "source": [ - "### Training with trainer\n", - "The trainer does almost the same thing. The only difference is that it has considered many details and is more modular." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vcvw9J8RNtFE", - "outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5", - "tags": [ - "remove-output" - ] - }, - "outputs": [], - "source": [ - "train_collector.reset()\n", - "train_envs.reset()\n", - "test_collector.reset()\n", - "test_envs.reset()\n", - "replayBuffer.reset()\n", - "\n", - "result = OnpolicyTrainer(\n", - " policy=policy,\n", - " train_collector=train_collector,\n", - " test_collector=test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=1,\n", - " repeat_per_collect=1,\n", - " episode_per_test=10,\n", - " step_per_collect=2000,\n", - " batch_size=512,\n", - ").run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "result.pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_j3aUJZQ7nml" - }, - "source": [ - "## Further Reading\n", - "### Logger usages\n", - "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. 
Check the [documentation](https://tianshou.org/en/master/03_api/utils/logger/base.html#tianshou.utils.logger.base.BaseLogger) for details.\n", - "\n", - "### Learn more about the APIs of Trainers\n", - "[documentation](https://tianshou.org/en/master/03_api/trainer/index.html)" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [ - "S3-tJZy35Ck_", - "XfsuU2AAE52C", - "p-7U_cwgF5Ej", - "_j3aUJZQ7nml" - ], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb deleted file mode 100644 index 47e4cb0c9..000000000 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ /dev/null @@ -1,341 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "_UaXOSRjDUF9" - }, - "source": [ - "# Experiment\n", - "Finally, we can assemble building blocks that we have came across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2QRbCJvDHNAd" - }, - "source": [ - "## Experiment\n", - "To conduct this experiment, we need the following building blocks.\n", - "\n", - "\n", - "* Two vectorized environments, one for training and one for evaluation\n", - "* A PPO agent\n", - "* A replay buffer to store transition data\n", - "* Two collectors to manage the data collecting process, one for training and one for evaluation\n", - "* A trainer to manage the training loop\n", - "\n", - "
\n", - "\n", - "\n", - "
\n", - "\n", - "Let us do this step by step." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-Hh4E6i0Hj0I" - }, - "source": [ - "## Preparation\n", - "Firstly, install Tianshou if you haven't installed it before." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7E4EhiBeHxD5" - }, - "source": [ - "Import libraries we might need later." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "ao9gWJDiHgG-", - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PPOPolicy\n", - "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import ActorCritic, Net\n", - "from tianshou.utils.net.discrete import Actor, Critic\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QnRg5y7THRYw" - }, - "source": [ - "## Environment" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YZERKCGtH8W1" - }, - "source": [ - "We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Mpuj5PFnDKVS" - }, - "outputs": [], - "source": [ - "env = gym.make(\"CartPole-v1\")\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BJtt_Ya8DTAh" - }, - "source": [ - "## Policy\n", - "Next we need to initialize our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n", - "\n", - "The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n", - "\n", - "Luckily, Tianshou already provides basic network modules that we can use in this experiment." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_Vy8uPWXP4m_" - }, - "outputs": [], - "source": [ - "# net is the shared head of the actor and the critic\n", - "assert env.observation_space.shape is not None # for mypy\n", - "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", - "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", - "critic = Critic(preprocess_net=net, device=device).to(device)\n", - "actor_critic = ActorCritic(actor=actor, critic=critic)\n", - "\n", - "# optimizer of the actor and the critic\n", - "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Lh2-hwE5Dn9I" - }, - "source": [ - "Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. 
CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OiJ2GkT0Qnbr" - }, - "outputs": [], - "source": [ - "dist = torch.distributions.Categorical\n", - "policy: PPOPolicy = PPOPolicy(\n", - " actor=actor,\n", - " critic=critic,\n", - " optim=optim,\n", - " dist_fn=dist,\n", - " action_space=env.action_space,\n", - " deterministic_eval=True,\n", - " action_scaling=False,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "okxfj6IEQ-r8" - }, - "source": [ - "`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n5XAAbuBZarO" - }, - "source": [ - "## Collector\n", - "We can set up the collectors now. Train collector is used to collect and store training data, so an additional replay buffer has to be passed in." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ezwz0qerZhQM" - }, - "outputs": [], - "source": [ - "train_collector = Collector[CollectStats](\n", - " policy=policy,\n", - " env=train_envs,\n", - " buffer=VectorReplayBuffer(20000, len(train_envs)),\n", - ")\n", - "test_collector = Collector[CollectStats](policy=policy, env=test_envs)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZaoPxOd2hm0b" - }, - "source": [ - "We use `VectorReplayBuffer` here because it's more efficient to collaborate with vectorized environments, you can simply consider `VectorReplayBuffer` as a a list of ordinary replay buffers." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qBoE9pLUiC-8" - }, - "source": [ - "## Trainer\n", - "Finally, we can use the trainer to help us set up the training loop." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "editable": true, - "id": "i45EDnpxQ8gu", - "outputId": "b1666b88-0bfa-4340-868e-58611872d988", - "tags": [ - "remove-output" - ] - }, - "outputs": [], - "source": [ - "result = OnpolicyTrainer(\n", - " policy=policy,\n", - " train_collector=train_collector,\n", - " test_collector=test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=50000,\n", - " repeat_per_collect=10,\n", - " episode_per_test=10,\n", - " batch_size=256,\n", - " step_per_collect=2000,\n", - " stop_fn=lambda mean_reward: mean_reward >= 195,\n", - ").run()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ckgINHE2iTFR" - }, - "source": [ - "## Results\n", - "Print the training result." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "tJCPgmiyiaaX", - "outputId": "40123ae3-3365-4782-9563-46c43812f10f", - "tags": [] - }, - "outputs": [], - "source": [ - "result.pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "A-MJ9avMibxN" - }, - "source": [ - "We can also test our trained agent." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mnMANFcciiAQ", - "outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21" - }, - "outputs": [], - "source": [ - "# Let's watch its performance!\n", - "policy.eval()\n", - "result = test_collector.collect(n_episode=1, render=False)\n", - "print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")" - ] - } - ], - "metadata": { - "colab": { - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/_config.yml b/docs/_config.yml index a0bb290a2..fce609211 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -103,6 +103,9 @@ sphinx: config : # key-value pairs to directly over-ride the Sphinx configuration autodoc_typehints_format: "short" autodoc_member_order: "bysource" + autodoc_mock_imports: + # mock imports for optional dependencies (e.g. dependencies of atari/atari_wrapper) + - cv2 autoclass_content: "both" autodoc_default_options: show-inheritance: True diff --git a/docs/index.rst b/docs/index.rst index c7c217759..4bfbbdd17 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,37 +9,37 @@ Welcome to Tianshou! **Tianshou** (`天授 `_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. 
The supported interface algorithms include: -* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network `_ -* :class:`~tianshou.policy.DQNPolicy` `Double DQN `_ -* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN `_ -* :class:`~tianshou.policy.BranchingDQNPolicy` `Branching DQN `_ -* :class:`~tianshou.policy.C51Policy` `Categorical DQN `_ -* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN `_ -* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN `_ -* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network `_ -* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function `_ -* :class:`~tianshou.policy.PGPolicy` `Policy Gradient `_ -* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient `_ -* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic `_ -* :class:`~tianshou.policy.TRPOPolicy` `Trust Region Policy Optimization `_ -* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization `_ -* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient `_ -* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ -* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ -* :class:`~tianshou.policy.REDQPolicy` `Randomized Ensembled Double Q-Learning `_ -* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ -* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning -* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning `_ -* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning `_ -* :class:`~tianshou.policy.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning `_ -* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning `_ -* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning `_ -* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression `_ -* :class:`~tianshou.policy.GAILPolicy` `Generative Adversarial Imitation Learning `_ -* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ -* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module `_ +* :class:`~tianshou.algorithm.DQNPolicy` `Deep Q-Network `_ +* :class:`~tianshou.algorithm.DQNPolicy` `Double DQN `_ +* :class:`~tianshou.algorithm.DQNPolicy` `Dueling DQN `_ +* :class:`~tianshou.algorithm.BranchingDQNPolicy` `Branching DQN `_ +* :class:`~tianshou.algorithm.C51Policy` `Categorical DQN `_ +* :class:`~tianshou.algorithm.RainbowPolicy` `Rainbow DQN `_ +* :class:`~tianshou.algorithm.QRDQNPolicy` `Quantile Regression DQN `_ +* :class:`~tianshou.algorithm.IQNPolicy` `Implicit Quantile Network `_ +* :class:`~tianshou.algorithm.FQFPolicy` `Fully-parameterized Quantile Function `_ +* :class:`~tianshou.algorithm.PGPolicy` `Policy Gradient `_ +* :class:`~tianshou.algorithm.NPGPolicy` `Natural Policy Gradient `_ +* :class:`~tianshou.algorithm.A2CPolicy` `Advantage Actor-Critic `_ +* :class:`~tianshou.algorithm.TRPOPolicy` `Trust Region Policy Optimization `_ +* :class:`~tianshou.algorithm.PPOPolicy` `Proximal Policy Optimization `_ +* :class:`~tianshou.algorithm.DDPGPolicy` `Deep Deterministic Policy Gradient `_ +* :class:`~tianshou.algorithm.TD3Policy` `Twin Delayed DDPG `_ +* :class:`~tianshou.algorithm.SACPolicy` `Soft Actor-Critic `_ +* :class:`~tianshou.algorithm.REDQPolicy` `Randomized Ensembled Double Q-Learning `_ +* :class:`~tianshou.algorithm.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ +* :class:`~tianshou.algorithm.ImitationPolicy` Imitation Learning +* :class:`~tianshou.algorithm.BCQPolicy` 
`Batch-Constrained deep Q-Learning `_ +* :class:`~tianshou.algorithm.CQLPolicy` `Conservative Q-Learning `_ +* :class:`~tianshou.algorithm.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning `_ +* :class:`~tianshou.algorithm.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning `_ +* :class:`~tianshou.algorithm.DiscreteCQLPolicy` `Discrete Conservative Q-Learning `_ +* :class:`~tianshou.algorithm.DiscreteCRRPolicy` `Critic Regularized Regression `_ +* :class:`~tianshou.algorithm.GAILPolicy` `Generative Adversarial Imitation Learning `_ +* :class:`~tianshou.algorithm.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ +* :class:`~tianshou.algorithm.ICMPolicy` `Intrinsic Curiosity Module `_ * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ -* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ +* :meth:`~tianshou.algorithm.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ * :class:`~tianshou.data.HERReplayBuffer` `Hindsight Experience Replay `_ Here is Tianshou's other features: @@ -51,7 +51,7 @@ Here is Tianshou's other features: * Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training` * Support any type of environment state/action (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env` * Support :ref:`customize_training` -* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation +* Support n-step returns estimation :meth:`~tianshou.algorithm.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation * Support :doc:`/01_tutorials/04_tictactoe` * Support both `TensorBoard `_ and `W&B `_ log tools * Support multi-GPU training :ref:`multi_gpu` diff --git a/examples/atari/README.md b/examples/atari/README.md index 62e58487b..ae80141d8 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -24,13 +24,13 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | time cost | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- | -| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch-size 64` | ~30 min (~15 epoch) | -| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | -| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test-num 100` | 3~4h (100 epoch) | -| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | -| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | -| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | -| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | +| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) | +| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | +| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --num_test_envs 100` | 3~4h (100 epoch) | +| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | +| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | +| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | +| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | Note: The `eps_train_final` and `eps_test` in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed. @@ -42,7 +42,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20 | ![](results/c51/Pong_rew.png) | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 20 | ![](results/c51/Pong_rew.png) | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 536.6 | ![](results/c51/Breakout_rew.png) | `python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1032 | ![](results/c51/Enduro_rew.png) | `python3 atari_c51.py --task "EnduroNoFrameskip-v4 " ` | | QbertNoFrameskip-v4 | 16245 | ![](results/c51/Qbert_rew.png) | `python3 atari_c51.py --task "QbertNoFrameskip-v4"` | @@ -58,7 +58,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20 | ![](results/qrdqn/Pong_rew.png) | `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 20 | ![](results/qrdqn/Pong_rew.png) | `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 409.2 | ![](results/qrdqn/Breakout_rew.png) | `python3 atari_qrdqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1055.9 | ![](results/qrdqn/Enduro_rew.png) | `python3 atari_qrdqn.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 14990 | ![](results/qrdqn/Qbert_rew.png) | `python3 atari_qrdqn.py --task "QbertNoFrameskip-v4"` | @@ -72,7 +72,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.3 | ![](results/iqn/Pong_rew.png) | `python3 atari_iqn.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 20.3 | ![](results/iqn/Pong_rew.png) | `python3 atari_iqn.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 496.7 | ![](results/iqn/Breakout_rew.png) | `python3 atari_iqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1545 | ![](results/iqn/Enduro_rew.png) | `python3 atari_iqn.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 15342.5 | ![](results/iqn/Qbert_rew.png) | `python3 atari_iqn.py --task "QbertNoFrameskip-v4"` | @@ -86,7 +86,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.7 | ![](results/fqf/Pong_rew.png) | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 20.7 | ![](results/fqf/Pong_rew.png) | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 517.3 | ![](results/fqf/Breakout_rew.png) | `python3 atari_fqf.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 2240.5 | ![](results/fqf/Enduro_rew.png) | `python3 atari_fqf.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 16172.5 | ![](results/fqf/Qbert_rew.png) | `python3 atari_fqf.py --task "QbertNoFrameskip-v4"` | @@ -100,7 +100,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 21 | ![](results/rainbow/Pong_rew.png) | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 21 | ![](results/rainbow/Pong_rew.png) | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 684.6 | ![](results/rainbow/Breakout_rew.png) | `python3 atari_rainbow.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1625.9 | ![](results/rainbow/Enduro_rew.png) | `python3 atari_rainbow.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 16192.5 | ![](results/rainbow/Qbert_rew.png) | `python3 atari_rainbow.py --task "QbertNoFrameskip-v4"` | diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 3c757f9b6..481562855 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -6,39 +6,41 @@ import numpy as np import torch -from atari_network import C51 -from atari_wrapper import make_atari_env +from tianshou.algorithm import C51 +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import C51Net +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51Policy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--scale-obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--scale_obs", type=int, default=0) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", 
type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", type=float, default=10.0) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -46,62 +48,72 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() -def test_c51(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # seed np.random.seed(args.seed) 
torch.manual_seed(args.seed) + # define model - net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: C51Policy = C51Policy( + c, h, w = args.state_shape + net = C51Net(c=c, h=h, w=w, action_shape=args.action_shape, num_atoms=args.num_atoms) + + # define policy and algorithm + optim = AdamOptimizerFactory(lr=args.lr) + policy = C51Policy( model=net, - optim=optim, - discount_factor=args.gamma, action_space=env.action_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: C51 = C51( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) - # load a previous policy + + # load a previous model if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -111,9 +123,10 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -136,12 +149,12 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False @@ -152,17 +165,12 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -173,7 +181,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom 
with 1M buffer size @@ -181,7 +191,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -190,29 +200,29 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_c51(get_args()) + main(get_args()) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 023a961f5..c516520ea 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -6,15 +6,17 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import make_atari_env +from tianshou.algorithm import DQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelbased.icm import ICMOffPolicyWrapper +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelbased.icm import ICMPolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -22,22 +24,22 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--scale-obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--scale_obs", type=int, default=0) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + 
parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -45,37 +47,37 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) parser.add_argument( - "--icm-lr-scale", + "--icm_lr_scale", type=float, default=0.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--icm-reward-scale", + "--icm_reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--icm-forward-loss-weight", + "--icm_forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", @@ -87,57 +89,67 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H 
x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model - net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: DQNPolicy | ICMPolicy - policy = DQNPolicy( + c, h, w = args.state_shape + net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape).to(args.device) + optim = AdamOptimizerFactory(lr=args.lr) + + # define policy and algorithm + policy = DiscreteQLearningPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: DQN | ICMOffPolicyWrapper + algorithm = DQN( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) if args.icm_lr_scale > 0: - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) - action_dim = np.prod(args.action_shape) + c, h, w = args.state_shape + feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) + action_dim = int(np.prod(args.action_shape)) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( - feature_net.net, - feature_dim, - action_dim, + feature_net=feature_net.net, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=[512], - device=args.device, ) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMPolicy( - policy=policy, + icm_optim = AdamOptimizerFactory(lr=args.lr) + algorithm = ICMOffPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, ).to(args.device) - # load a previous policy + + # load a previous model if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -147,9 +159,10 @@ def main(args: argparse.Namespace = get_args()) -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -172,12 +185,12 @@ def main(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False @@ 
-188,23 +201,18 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") - torch.save({"model": policy.state_dict()}, ckpt_path) + torch.save({"model": algorithm.state_dict()}, ckpt_path) return ckpt_path - # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -215,7 +223,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -223,7 +233,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -232,27 +242,28 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - resume_from_log=args.resume_id is not None, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 3bcb0f6c3..4310deff7 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -5,19 +5,19 @@ from sensai.util import logging from sensai.util.logging import datetime_tag -from examples.atari.atari_network import ( +from 
tianshou.env.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback -from tianshou.highlevel.config import SamplingConfig +from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( DQNExperimentBuilder, ExperimentConfig, ) -from tianshou.highlevel.params.policy_params import DQNParams -from tianshou.highlevel.params.policy_wrapper import ( - PolicyWrapperFactoryIntrinsicCuriosity, +from tianshou.highlevel.params.algorithm_params import DQNParams +from tianshou.highlevel.params.algorithm_wrapper import ( + AlgorithmWrapperFactoryIntrinsicCuriosity, ) from tianshou.highlevel.trainer import ( EpochTestCallbackDQNSetEps, @@ -38,30 +38,28 @@ def main( n_step: int = 3, target_update_freq: int = 500, epoch: int = 100, - step_per_epoch: int = 100000, - step_per_collect: int = 10, + epoch_num_steps: int = 100000, + collection_step_num_env_steps: int = 10, update_per_step: float = 0.1, batch_size: int = 32, - training_num: int = 10, - test_num: int = 10, + num_train_envs: int = 10, + num_test_envs: int = 10, frames_stack: int = 4, - save_buffer_name: str | None = None, # TODO support? icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, ) -> None: log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OffPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_per_step=update_per_step, - repeat_per_collect=None, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_gradient_steps_per_sample=update_per_step, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, @@ -74,11 +72,11 @@ def main( ) builder = ( - DQNExperimentBuilder(env_factory, experiment_config, sampling_config) + DQNExperimentBuilder(env_factory, experiment_config, training_config) .with_dqn_params( DQNParams( - discount_factor=gamma, - estimation_step=n_step, + gamma=gamma, + n_step_return_horizon=n_step, lr=lr, target_update_freq=target_update_freq, ), @@ -91,8 +89,8 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: - builder.with_policy_wrapper_factory( - PolicyWrapperFactoryIntrinsicCuriosity( + builder.with_algorithm_wrapper_factory( + AlgorithmWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), hidden_sizes=[512], lr=lr, diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index c25002613..f6b1c1deb 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -6,14 +6,16 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import make_atari_env +from tianshou.algorithm import FQF +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.fqf import FQFPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.data import Collector, 
CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import FQFPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -21,27 +23,27 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=3128) - parser.add_argument("--scale-obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--scale_obs", type=int, default=0) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=5e-5) - parser.add_argument("--fraction-lr", type=float, default=2.5e-9) + parser.add_argument("--fraction_lr", type=float, default=2.5e-9) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-fractions", type=int, default=32) - parser.add_argument("--num-cosines", type=int, default=64) - parser.add_argument("--ent-coef", type=float, default=10.0) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--num_fractions", type=int, default=32) + parser.add_argument("--num_cosines", type=int, default=64) + parser.add_argument("--ent_coef", type=float, default=10.0) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -49,72 +51,81 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, 
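The command-line flags of all example scripts switch from hyphenated to underscore-separated spellings (e.g. `--eps-train-final` becomes `--eps_train_final`). The parsed attribute name does not change, because argparse already derives it by replacing hyphens with underscores; only the string typed on the command line differs. A small self-contained sketch:

    import argparse

    parser = argparse.ArgumentParser()
    # old spelling:  parser.add_argument("--eps-train-final", type=float, default=0.05)
    # new spelling:
    parser.add_argument("--eps_train_final", type=float, default=0.05)

    # the value is read as args.eps_train_final in both cases; on the command line the
    # scripts are now invoked e.g. as:  python atari_fqf.py --eps_train_final 0.05
    args = parser.parse_args(["--eps_train_final", "0.1"])
    assert args.eps_train_final == 0.1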
default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() -def test_fqf(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore + # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + c, h, w = args.state_shape + feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) net = FullQuantileFunction( - feature_net, - args.action_shape, - args.hidden_sizes, - args.num_cosines, - device=args.device, + preprocess_net=feature_net, + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, + num_cosines=args.num_cosines, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) - fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) - # define policy - policy: FQFPolicy = FQFPolicy( + fraction_optim = RMSpropOptimizerFactory(lr=args.fraction_lr) + + # define policy and algorithm + policy = FQFPolicy( model=net, - optim=optim, fraction_model=fraction_net, - fraction_optim=fraction_optim, action_space=env.action_space, - discount_factor=args.gamma, + ) + algorithm: FQF = FQF( + policy=policy, + optim=optim, + fraction_optim=fraction_optim, + gamma=args.gamma, num_fractions=args.num_fractions, ent_coef=args.ent_coef, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -124,9 +135,10 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - # collector - 
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -149,12 +161,12 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False @@ -165,17 +177,16 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(args.eps_test) - # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) + policy.set_eps_training(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -186,7 +197,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -194,7 +207,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -203,29 +216,31 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + 
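As the FQF script above shows, optimizers are no longer built from module parameters by the user; optimizer factories are handed to the algorithm, which instantiates them internally, and the learning-related options (discount, n-step horizon, target updates) move from the former policy constructor to the algorithm. A condensed sketch of that split, assuming `net`, `fraction_net` and `env` are constructed as in the script; the hyperparameter values are the script's defaults.

    from tianshou.algorithm import FQF
    from tianshou.algorithm.modelfree.fqf import FQFPolicy
    from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory

    # v1 created torch optimizers directly, e.g. torch.optim.Adam(net.parameters(), lr=5e-5).
    # In v2 the policy only wraps the models ...
    policy = FQFPolicy(
        model=net,
        fraction_model=fraction_net,
        action_space=env.action_space,
    )
    # ... while optimizer factories and learning options go to the algorithm:
    algorithm = FQF(
        policy=policy,
        optim=AdamOptimizerFactory(lr=5e-5),
        fraction_optim=RMSpropOptimizerFactory(lr=2.5e-9),
        gamma=0.99,                 # was: discount_factor
        num_fractions=32,
        ent_coef=10.0,
        n_step_return_horizon=3,    # was: estimation_step
        target_update_freq=500,
    )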
test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_fqf(get_args()) + main(get_args()) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 8a30ca75d..b89738e53 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -6,14 +6,16 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import make_atari_env +from tianshou.algorithm import IQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.iqn import IQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import IQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import ImplicitQuantileNetwork @@ -21,27 +23,27 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1234) - parser.add_argument("--scale-obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--scale_obs", type=int, default=0) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--sample-size", type=int, default=32) - parser.add_argument("--online-sample-size", type=int, default=8) - parser.add_argument("--target-sample-size", type=int, default=8) - parser.add_argument("--num-cosines", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--sample_size", type=int, default=32) + parser.add_argument("--online_sample_size", type=int, default=8) + parser.add_argument("--target_sample_size", type=int, default=8) + parser.add_argument("--num_cosines", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, 
default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -49,69 +51,79 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() -def test_iqn(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + c, h, w = args.state_shape + feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) net = ImplicitQuantileNetwork( - feature_net, - args.action_shape, - args.hidden_sizes, + preprocess_net=feature_net, + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, num_cosines=args.num_cosines, - device=args.device, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: IQNPolicy = IQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + + # define policy and algorithm + policy = IQNPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, sample_size=args.sample_size, online_sample_size=args.online_sample_size, target_sample_size=args.target_sample_size, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: IQN = IQN( + 
policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) - # load a previous policy + + # load previous model if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -122,8 +134,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -146,12 +158,12 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False @@ -162,17 +174,13 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -183,7 +191,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -191,7 +201,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -200,30 +210,30 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - 
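The epsilon-greedy handling in the discrete Q-learning examples follows the pattern visible here: constant epsilons for training and inference are passed to the policy at construction time, which makes the former `test_fn` callback unnecessary, and only the training epsilon still gets a schedule via `train_fn` and the new setter. A minimal sketch, assuming `net` and `env` are built as in the IQN script; the decay constants mirror the example defaults.

    from tianshou.algorithm.modelfree.iqn import IQNPolicy

    policy = IQNPolicy(
        model=net,
        action_space=env.action_space,
        sample_size=32,
        online_sample_size=8,
        target_sample_size=8,
        eps_training=1.0,     # exploration used while collecting training data
        eps_inference=0.005,  # exploration used when collecting test episodes
    )

    def train_fn(epoch: int, env_step: int) -> None:
        # nature-DQN style linear decay over the first 1M environment steps
        eps = max(1.0 - env_step / 1e6 * (1.0 - 0.05), 0.05)
        policy.set_eps_training(eps)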
episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_iqn(get_args()) + main(get_args()) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index c644b2469..ec884ba9c 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -6,18 +6,17 @@ from sensai.util import logging from sensai.util.logging import datetime_tag -from examples.atari.atari_network import ( +from tianshou.env.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback -from tianshou.highlevel.config import SamplingConfig +from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, IQNExperimentBuilder, ) -from tianshou.highlevel.params.policy_params import IQNParams +from tianshou.highlevel.params.algorithm_params import IQNParams from tianshou.highlevel.trainer import ( - EpochTestCallbackDQNSetEps, EpochTrainCallbackDQNEpsLinearDecay, ) @@ -40,27 +39,25 @@ def main( n_step: int = 3, target_update_freq: int = 500, epoch: int = 100, - step_per_epoch: int = 100000, - step_per_collect: int = 10, + epoch_num_steps: int = 100000, + collection_step_num_env_steps: int = 10, update_per_step: float = 0.1, batch_size: int = 32, - training_num: int = 10, - test_num: int = 10, + num_train_envs: int = 10, + num_test_envs: int = 10, frames_stack: int = 4, - save_buffer_name: str | None = None, # TODO support? 
) -> None: log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OffPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_per_step=update_per_step, - repeat_per_collect=None, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_gradient_steps_per_sample=update_per_step, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, @@ -73,11 +70,11 @@ def main( ) experiment = ( - IQNExperimentBuilder(env_factory, experiment_config, sampling_config) + IQNExperimentBuilder(env_factory, experiment_config, training_config) .with_iqn_params( IQNParams( - discount_factor=gamma, - estimation_step=n_step, + gamma=gamma, + n_step_return_horizon=n_step, lr=lr, sample_size=sample_size, online_sample_size=online_sample_size, @@ -85,13 +82,14 @@ def main( target_sample_size=target_sample_size, hidden_sizes=hidden_sizes, num_cosines=num_cosines, + eps_training=eps_train, + eps_inference=eps_test, ), ) .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True)) .with_epoch_train_callback( EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final), ) - .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test)) .with_epoch_stop_callback(AtariEpochStopCallback(task)) .build() ) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 2f3832d23..16d72fb15 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -3,50 +3,60 @@ import os import pprint import sys +from collections.abc import Sequence +from typing import cast import numpy as np import torch -from atari_network import DQN, layer_init, scale_obs -from atari_wrapper import make_atari_env -from torch.distributions import Categorical -from torch.optim.lr_scheduler import LambdaLR +from tianshou.algorithm import PPO +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import ( + DQNet, + ScaledObsInputActionReprNet, + layer_init, +) +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import ICMPolicy, PPOPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer -from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule +from tianshou.trainer import OnPolicyTrainerParams +from tianshou.utils.net.discrete import ( + DiscreteActor, + DiscreteCritic, + IntrinsicCuriosityModule, +) def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=4213) - parser.add_argument("--scale-obs", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=100000) + 
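In the high-level API, `SamplingConfig` is replaced by a training-config class per learning paradigm, and the `repeat_per_collect=None` entry is simply dropped from the off-policy configs in this patch. A condensed sketch of the off-policy case, assuming `env_factory` and `experiment_config` are created as in these scripts; the values are the example defaults.

    from tianshou.highlevel.config import OffPolicyTrainingConfig
    from tianshou.highlevel.experiment import IQNExperimentBuilder

    training_config = OffPolicyTrainingConfig(
        max_epochs=100,                                 # was: num_epochs
        epoch_num_steps=100_000,                        # was: step_per_epoch
        batch_size=32,
        num_train_envs=10,
        num_test_envs=10,
        buffer_size=100_000,
        collection_step_num_env_steps=10,               # was: step_per_collect
        update_step_num_gradient_steps_per_sample=0.1,  # was: update_per_step
        replay_buffer_stack_num=4,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )
    builder = IQNExperimentBuilder(env_factory, experiment_config, training_config)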
parser.add_argument("--scale_obs", type=int, default=1) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=2.5e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=1000) - parser.add_argument("--repeat-per-collect", type=int, default=4) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--hidden-size", type=int, default=512) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--rew-norm", type=int, default=False) - parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.01) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--eps-clip", type=float, default=0.1) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1000) + parser.add_argument("--update_step_num_repetitions", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--hidden_size", type=int, default=512) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--return_scaling", type=int, default=False) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--lr_decay", type=int, default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--eps_clip", type=float, default=0.1) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) + parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -54,37 +64,37 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch 
the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) parser.add_argument( - "--icm-lr-scale", + "--icm_lr_scale", type=float, default=0.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--icm-reward-scale", + "--icm_reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--icm-forward-loss-weight", + "--icm_forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", @@ -92,17 +102,17 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_ppo(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, scale=0, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = cast(tuple[int, ...], env.observation_space.shape) + args.action_shape = cast(Sequence[int] | int, env.action_space.shape or env.action_space.n) # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -110,75 +120,76 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = DQN( - *args.state_shape, - args.action_shape, - device=args.device, + c, h, w = args.state_shape + net: ScaledObsInputActionReprNet | DQNet + net = DQNet( + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, output_dim_added_layer=args.hidden_size, layer_init=layer_init, ) if args.scale_obs: - net = scale_obs(net) - actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) - critic = Critic(net, device=args.device) - optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5) + net = ScaledObsInputActionReprNet(net) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) + critic = DiscreteCritic(preprocess_net=net) + optim = AdamOptimizerFactory(lr=args.lr, eps=1e-5) - lr_scheduler = None if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - - # define policy - def dist(logits: torch.Tensor) -> Categorical: - return Categorical(logits=logits) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + ) + ) - policy: PPOPolicy = PPOPolicy( + # define algorithm + policy = DiscreteActorPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: PPO = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - action_scaling=False, - lr_scheduler=lr_scheduler, - action_space=env.action_space, + return_scaling=args.return_scaling, eps_clip=args.eps_clip, 
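Learning-rate decay is likewise expressed through the optimizer factory instead of wrapping a concrete optimizer in a `LambdaLR`, as the PPO script above shows; the linear schedule factory is given the training extents from which the total number of updates was previously computed by hand. A minimal sketch with the script's default values:

    from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear

    # v1: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
    # v2: the schedule is attached to the optimizer factory and resolved internally
    optim = AdamOptimizerFactory(lr=2.5e-4, eps=1e-5)
    optim.with_lr_scheduler_factory(
        LRSchedulerFactoryLinear(
            max_epochs=100,
            epoch_num_steps=100_000,
            collection_step_num_env_steps=1000,
        )
    )
    # the factory is then passed to the algorithm, e.g. PPO(..., optim=optim, ...)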
value_clip=args.value_clip, dual_clip=args.dual_clip, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) - action_dim = np.prod(args.action_shape) + c, h, w = args.state_shape + feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) + action_dim = int(np.prod(args.action_shape)) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( - feature_net.net, - feature_dim, - action_dim, + feature_net=feature_net.net, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=[args.hidden_size], - device=args.device, ) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy: ICMPolicy = ICMPolicy( # type: ignore[no-redef] - policy=policy, + icm_optim = AdamOptimizerFactory(lr=args.lr) + algorithm = ICMOnPolicyWrapper( # type: ignore[assignment] + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM @@ -190,8 +201,8 @@ def dist(logits: torch.Tensor) -> Categorical: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -214,12 +225,12 @@ def dist(logits: torch.Tensor) -> Categorical: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False @@ -227,10 +238,9 @@ def stop_fn(mean_rewards: float) -> bool: def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") - torch.save({"model": policy.state_dict()}, ckpt_path) + torch.save({"model": algorithm.state_dict()}, ckpt_path) return ckpt_path - # watch agent's performance def watch() -> None: print("Setup test envs ...") test_envs.seed(args.seed) @@ -243,7 +253,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = 
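Intrinsic-curiosity wrapping follows the same policy-to-algorithm shift: instead of wrapping the policy in an `ICMPolicy`, the finished on-policy algorithm is wrapped. A condensed sketch assuming `algorithm` (the `PPO` instance) and `icm_net` (an `IntrinsicCuriosityModule`) are built as in the script above; the scale values are illustrative.

    from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper
    from tianshou.algorithm.optim import AdamOptimizerFactory

    algorithm = ICMOnPolicyWrapper(
        wrapped_algorithm=algorithm,            # the PPO algorithm built above
        model=icm_net,                          # the IntrinsicCuriosityModule
        optim=AdamOptimizerFactory(lr=2.5e-4),
        lr_scale=1.0,                           # cf. --icm_lr_scale; the script wraps only when this is > 0
        reward_scale=0.01,                      # scaling factor for the intrinsic reward
        forward_loss_weight=0.2,                # weight of the forward-model loss in ICM
    )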
collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -251,7 +263,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -260,29 +272,31 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - resume_from_log=args.resume_id is not None, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_ppo(get_args()) + main(get_args()) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 983608293..393040e54 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -6,21 +6,21 @@ from sensai.util import logging from sensai.util.logging import datetime_tag -from examples.atari.atari_network import ( +from tianshou.env.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback -from tianshou.highlevel.config import SamplingConfig +from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear -from tianshou.highlevel.params.policy_params import PPOParams -from tianshou.highlevel.params.policy_wrapper import ( - PolicyWrapperFactoryIntrinsicCuriosity, +from tianshou.highlevel.params.algorithm_params import PPOParams +from tianshou.highlevel.params.algorithm_wrapper import ( + AlgorithmWrapperFactoryIntrinsicCuriosity, ) +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( @@ -31,14 +31,14 @@ def main( lr: float = 2.5e-4, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 100000, - step_per_collect: int = 1000, - repeat_per_collect: int = 4, + epoch_num_steps: int = 100000, + collection_step_num_env_steps: int = 1000, + 
update_step_num_repetitions: int = 4, batch_size: int = 256, hidden_sizes: Sequence[int] = (512,), - training_num: int = 10, - test_num: int = 10, - rew_norm: bool = False, + num_train_envs: int = 10, + num_test_envs: int = 10, + return_scaling: bool = False, vf_coef: float = 0.25, ent_coef: float = 0.01, gae_lambda: float = 0.95, @@ -47,7 +47,7 @@ def main( eps_clip: float = 0.1, dual_clip: float | None = None, value_clip: bool = True, - norm_adv: bool = True, + advantage_normalization: bool = True, recompute_adv: bool = False, frames_stack: int = 4, save_buffer_name: str | None = None, # TODO add support in high-level API? @@ -57,15 +57,15 @@ def main( ) -> None: log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OnPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, @@ -78,24 +78,22 @@ def main( ) builder = ( - PPOExperimentBuilder(env_factory, experiment_config, sampling_config) + PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, - reward_normalization=rew_norm, + return_scaling=return_scaling, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, value_clip=value_clip, - advantage_normalization=norm_adv, + advantage_normalization=advantage_normalization, eps_clip=eps_clip, dual_clip=dual_clip, recompute_advantage=recompute_adv, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory(ActorFactoryAtariDQN(scale_obs=scale_obs, features_only=True)) @@ -103,8 +101,8 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: - builder.with_policy_wrapper_factory( - PolicyWrapperFactoryIntrinsicCuriosity( + builder.with_algorithm_wrapper_factory( + AlgorithmWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), hidden_sizes=hidden_sizes, lr=lr, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index e47c08d92..99706ed0e 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -6,37 +6,39 @@ import numpy as np import torch -from atari_network import QRDQN -from atari_wrapper import make_atari_env +from tianshou.algorithm import QRDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import QRDQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import QRDQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import 
OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--scale-obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--scale_obs", type=int, default=0) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -44,67 +46,76 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() -def test_qrdqn(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = 
make_atari_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore + # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model c, h, w = args.state_shape - net = QRDQN( + net = QRDQNet( c=c, h=h, w=w, action_shape=args.action_shape, num_quantiles=args.num_quantiles, - device=args.device, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: QRDQNPolicy = QRDQNPolicy( + + # define policy and algorithm + optim = AdamOptimizerFactory(lr=args.lr) + policy = QRDQNPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: QRDQN = QRDQN( + policy=policy, + optim=optim, + gamma=args.gamma, num_quantiles=args.num_quantiles, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM @@ -115,9 +126,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) + # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -140,12 +152,12 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False @@ -156,17 +168,13 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if 
args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -177,7 +185,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -185,7 +195,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -194,29 +204,30 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_qrdqn(get_args()) + main(get_args()) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 5373d0536..27f13c3d7 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -6,53 +6,55 @@ import numpy as np import torch -from atari_network import Rainbow -from atari_wrapper import make_atari_env +from tianshou.algorithm import C51, RainbowDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) +from tianshou.env.atari.atari_network import RainbowNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51Policy, RainbowPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--scale-obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - 
parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--scale_obs", type=int, default=0) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0000625) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--noisy-std", type=float, default=0.1) - parser.add_argument("--no-dueling", action="store_true", default=False) - parser.add_argument("--no-noisy", action="store_true", default=False) - parser.add_argument("--no-priority", action="store_true", default=False) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", type=float, default=10.0) + parser.add_argument("--noisy_std", type=float, default=0.1) + parser.add_argument("--no_dueling", action="store_true", default=False) + parser.add_argument("--no_noisy", action="store_true", default=False) + parser.add_argument("--no_priority", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.5) parser.add_argument("--beta", type=float, default=0.4) - parser.add_argument("--beta-final", type=float, default=1.0) - parser.add_argument("--beta-anneal-step", type=int, default=5000000) - parser.add_argument("--no-weight-norm", action="store_true", default=False) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--beta_final", type=float, default=1.0) + parser.add_argument("--beta_anneal_step", type=int, default=5000000) + parser.add_argument("--no_weight_norm", action="store_true", default=False) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -60,23 +62,23 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, 
default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() @@ -84,46 +86,58 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model - net = Rainbow( - *args.state_shape, - args.action_shape, - args.num_atoms, - args.noisy_std, - args.device, + c, h, w = args.state_shape + net = RainbowNet( + c=c, + h=h, + w=w, + action_shape=args.action_shape, + num_atoms=args.num_atoms, + noisy_std=args.noisy_std, is_dueling=not args.no_dueling, is_noisy=not args.no_noisy, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: C51Policy = RainbowPolicy( + + # define policy and algorithm + policy = C51Policy( model=net, - optim=optim, - discount_factor=args.gamma, action_space=env.action_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + optim = AdamOptimizerFactory(lr=args.lr) + algorithm: C51 = RainbowDQN( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer @@ -146,9 +160,10 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: beta=args.beta, weight_norm=not args.no_weight_norm, ) + # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = 
datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -171,12 +186,12 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False @@ -187,7 +202,7 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) if not args.no_priority: @@ -199,13 +214,8 @@ def train_fn(epoch: int, env_step: int) -> None: if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/beta": beta}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -218,7 +228,9 @@ def watch() -> None: alpha=args.alpha, beta=args.beta, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -226,7 +238,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -235,25 +247,26 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_sac.py 
b/examples/atari/atari_sac.py index df43e49ac..047c6b435 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -6,40 +6,47 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import make_atari_env +from tianshou.algorithm import DiscreteSAC, ICMOffPolicyWrapper +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.algorithm.modelfree.sac import AutoAlpha +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteSACPolicy, ICMPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer -from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule +from tianshou.trainer import OffPolicyTrainerParams +from tianshou.utils.net.discrete import ( + DiscreteActor, + DiscreteCritic, + IntrinsicCuriosityModule, +) def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=4213) - parser.add_argument("--scale-obs", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=100000) - parser.add_argument("--actor-lr", type=float, default=1e-5) - parser.add_argument("--critic-lr", type=float, default=1e-5) + parser.add_argument("--scale_obs", type=int, default=0) + parser.add_argument("--buffer_size", type=int, default=100000) + parser.add_argument("--actor_lr", type=float, default=1e-5) + parser.add_argument("--critic_lr", type=float, default=1e-5) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.05) - parser.add_argument("--auto-alpha", action="store_true", default=False) - parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--auto_alpha", action="store_true", default=False) + parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-size", type=int, default=512) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--rew-norm", type=int, default=False) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_size", type=int, default=512) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--return_scaling", type=int, default=False) 
parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -47,37 +54,37 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) parser.add_argument( - "--icm-lr-scale", + "--icm_lr_scale", type=float, default=0.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--icm-reward-scale", + "--icm_reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--icm-forward-loss-weight", + "--icm_forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", @@ -89,80 +96,87 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + c, h, w = env.observation_space.shape # type: ignore + args.action_shape = env.action_space.n # type: ignore + # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model - net = DQN( - *args.state_shape, - args.action_shape, - device=args.device, + net = DQNet( + c, + h, + w, + action_shape=args.action_shape, features_only=True, output_dim_added_layer=args.hidden_size, ) - actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic(net, last_size=args.action_shape, device=args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net, last_size=args.action_shape, device=args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) + critic1 = DiscreteCritic(preprocess_net=net, last_size=args.action_shape) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) + critic2 = DiscreteCritic(preprocess_net=net, last_size=args.action_shape) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) - # define policy + # define policy and algorithm if args.auto_alpha: target_entropy = 0.98 * np.log(np.prod(args.action_shape)) 
- log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) - - policy: DiscreteSACPolicy | ICMPolicy + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) + algorithm: DiscreteSAC | ICMOffPolicyWrapper policy = DiscreteSACPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm = DiscreteSAC( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, - action_space=env.action_space, tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ).to(args.device) if args.icm_lr_scale > 0: - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + c, h, w = args.state_shape + feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( - feature_net.net, - feature_dim, - action_dim, + feature_net=feature_net.net, + feature_dim=feature_dim, + action_dim=int(action_dim), hidden_sizes=[args.hidden_size], - device=args.device, ) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.actor_lr) - policy = ICMPolicy( - policy=policy, + icm_optim = AdamOptimizerFactory(lr=args.actor_lr) + algorithm = ICMOffPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, ).to(args.device) - # load a previous policy + + # load a previous model if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -173,8 +187,8 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -197,12 +211,12 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False @@ -210,10 +224,9 @@ def stop_fn(mean_rewards: float) -> bool: def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: 
https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") - torch.save({"model": policy.state_dict()}, ckpt_path) + torch.save({"model": algorithm.state_dict()}, ckpt_path) return ckpt_path - # watch agent's performance def watch() -> None: print("Setup test envs ...") test_envs.seed(args.seed) @@ -226,7 +239,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -234,7 +249,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -243,25 +258,27 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - resume_from_log=args.resume_id is not None, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 76f18f55f..b21ed5e44 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -6,21 +6,21 @@ from sensai.util import logging from sensai.util.logging import datetime_tag -from examples.atari.atari_network import ( +from tianshou.env.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback -from tianshou.highlevel.config import SamplingConfig +from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( DiscreteSACExperimentBuilder, ExperimentConfig, ) -from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault -from tianshou.highlevel.params.policy_params import DiscreteSACParams -from tianshou.highlevel.params.policy_wrapper import ( - 
PolicyWrapperFactoryIntrinsicCuriosity, +from tianshou.highlevel.params.algorithm_params import DiscreteSACParams +from tianshou.highlevel.params.algorithm_wrapper import ( + AlgorithmWrapperFactoryIntrinsicCuriosity, ) +from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault def main( @@ -37,13 +37,13 @@ def main( auto_alpha: bool = False, alpha_lr: float = 3e-4, epoch: int = 100, - step_per_epoch: int = 100000, - step_per_collect: int = 10, + epoch_num_steps: int = 100000, + collection_step_num_env_steps: int = 10, update_per_step: float = 0.1, batch_size: int = 64, hidden_sizes: Sequence[int] = (512,), - training_num: int = 10, - test_num: int = 10, + num_train_envs: int = 10, + num_test_envs: int = 10, frames_stack: int = 4, icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, @@ -51,16 +51,15 @@ def main( ) -> None: log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, - update_per_step=update_per_step, + training_config = OffPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, + update_step_num_gradient_steps_per_sample=update_per_step, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - repeat_per_collect=None, + collection_step_num_env_steps=collection_step_num_env_steps, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, @@ -73,7 +72,7 @@ def main( ) builder = ( - DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config) + DiscreteSACExperimentBuilder(env_factory, experiment_config, training_config) .with_sac_params( DiscreteSACParams( actor_lr=actor_lr, @@ -81,10 +80,12 @@ def main( critic2_lr=critic_lr, gamma=gamma, tau=tau, - alpha=AutoAlphaFactoryDefault(lr=alpha_lr, target_entropy_coefficient=0.98) - if auto_alpha - else alpha, - estimation_step=n_step, + alpha=( + AutoAlphaFactoryDefault(lr=alpha_lr, target_entropy_coefficient=0.98) + if auto_alpha + else alpha + ), + n_step_return_horizon=n_step, ), ) .with_actor_factory(ActorFactoryAtariDQN(scale_obs=False, features_only=True)) @@ -92,8 +93,8 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: - builder.with_policy_wrapper_factory( - PolicyWrapperFactoryIntrinsicCuriosity( + builder.with_algorithm_wrapper_factory( + AlgorithmWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), hidden_sizes=hidden_sizes, lr=actor_lr, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index b25f35c15..97e6c56d1 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -7,11 +7,13 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from 
tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -21,23 +23,23 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Acrobot-v1") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.5) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.5) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=100) - parser.add_argument("--update-per-step", type=float, default=0.01) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128]) - parser.add_argument("--dueling-q-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--dueling-v-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=100) + parser.add_argument("--update_per_step", type=float, default=0.01) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--dueling_q_hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--dueling_v_hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -56,9 +58,9 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -71,35 +73,38 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, dueling_param=(Q_param, V_param), - ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy = DQNPolicy( + 
) + optim = AdamOptimizerFactory(lr=args.lr) + policy = DiscreteQLearningPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, - estimation_step=args.n_step, - target_update_freq=args.target_update_freq, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) + algorithm: DQN = DQN( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, + target_update_freq=args.target_update_freq, + ).to(args.device) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -112,42 +117,39 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: if env_step <= 100000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 500000: eps = args.eps_train - (env_step - 100000) / 400000 * (0.5 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.5 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.5 * args.eps_train) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
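
Note on the epsilon handling shown in the hunks above and the `set_eps` removal just below: the two exploration rates are now attached to the policy when it is constructed, so the explicit calls before test/watch collection (and the `test_fn` callback) are no longer needed; only the training-time decay schedule still goes through `set_eps_training`. A minimal sketch of the pattern, assuming `net`, `env` and `args` as defined in the surrounding example script (the decay schedule below is purely illustrative):

    from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy

    # both exploration rates are fixed once at construction time
    policy = DiscreteQLearningPolicy(
        model=net,
        action_space=env.action_space,
        eps_training=args.eps_train,   # used while collecting training data
        eps_inference=args.eps_test,   # used for test/watch collection
    )

    def train_fn(epoch: int, env_step: int) -> None:
        # illustrative decay that only touches the training epsilon
        eps = max(0.5 * args.eps_train, args.eps_train * (1 - env_step / 500000))
        policy.set_eps_training(eps)
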
- policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index d88379b23..41882525b 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -8,11 +8,13 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import BDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.bdqn import BDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv -from tianshou.policy import BranchingDQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import BranchingNet @@ -22,26 +24,26 @@ def get_args() -> argparse.Namespace: # task parser.add_argument("--task", type=str, default="BipedalWalker-v3") # network architecture - parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[512, 256]) - parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[128]) - parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[128]) - parser.add_argument("--action-per-branch", type=int, default=25) + parser.add_argument("--common_hidden_sizes", type=int, nargs="*", default=[512, 256]) + parser.add_argument("--action_hidden_sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--value_hidden_sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--action_per_branch", type=int, default=25) # training hyperparameters parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.0) - parser.add_argument("--eps-train", type=float, default=0.73) - parser.add_argument("--eps-decay", type=float, default=5e-6) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--eps_test", type=float, default=0.0) + parser.add_argument("--eps_train", type=float, default=0.73) + parser.add_argument("--eps_decay", type=float, default=5e-6) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--target-update-freq", type=int, default=1000) + parser.add_argument("--target_update_freq", type=int, default=1000) parser.add_argument("--epoch", type=int, default=25) - parser.add_argument("--step-per-epoch", type=int, default=80000) - parser.add_argument("--step-per-collect", type=int, default=16) - parser.add_argument("--update-per-step", type=float, default=0.0625) - parser.add_argument("--batch-size", type=int, default=512) - parser.add_argument("--training-num", type=int, default=20) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=80000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=16) + parser.add_argument("--update_per_step", type=float, default=0.0625) + parser.add_argument("--batch_size", type=int, default=512) + 
parser.add_argument("--num_train_envs", type=int, default=20) + parser.add_argument("--num_test_envs", type=int, default=10) # other parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) @@ -53,7 +55,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_bdq(args: argparse.Namespace = get_args()) -> None: +def run_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) @@ -75,14 +77,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: train_envs = SubprocVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) - for _ in range(args.training_num) + for _ in range(args.num_train_envs) ], ) # test_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) test_envs = SubprocVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) - for _ in range(args.test_num) + for _ in range(args.num_test_envs) ], ) # seed @@ -92,40 +94,43 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = BranchingNet( - args.state_shape, - args.num_branches, - args.action_per_branch, - args.common_hidden_sizes, - args.value_hidden_sizes, - args.action_hidden_sizes, - device=args.device, + state_shape=args.state_shape, + num_branches=args.num_branches, + action_per_branch=args.action_per_branch, + common_hidden_sizes=args.common_hidden_sizes, + value_hidden_sizes=args.value_hidden_sizes, + action_hidden_sizes=args.action_hidden_sizes, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BranchingDQNPolicy = BranchingDQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = BDQNPolicy( model=net, - optim=optim, - discount_factor=args.gamma, action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? 
+ eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: BDQN = BDQN( + policy=policy, + optim=optim, + gamma=args.gamma, target_update_freq=args.target_update_freq, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=False) - # policy.set_eps(1) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") log_path = os.path.join(args.logdir, "bdq", args.task, current_time) writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -135,39 +140,37 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) - policy.set_eps(eps) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(eps) # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - train_fn=train_fn, - test_fn=test_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + stop_fn=stop_fn, + train_fn=train_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
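
As in the other scripts, the hunks above for `bipedal_bdq.py` show the split between the policy (network plus action selection) and the algorithm (optimization, discounting, target updates), with the collectors now built around the algorithm. A condensed sketch of that wiring, assuming `net`, `optim`, `env`, `train_envs` and `args` as set up in the surrounding script:

    from tianshou.algorithm import BDQN
    from tianshou.algorithm.modelfree.bdqn import BDQNPolicy
    from tianshou.data import Collector, CollectStats, VectorReplayBuffer

    # the policy wraps the network and action space; training concerns live on the algorithm
    policy = BDQNPolicy(
        model=net,
        action_space=env.action_space,
        eps_training=args.eps_train,
        eps_inference=args.eps_test,
    )
    algorithm = BDQN(
        policy=policy,
        optim=optim,
        gamma=args.gamma,
        target_update_freq=args.target_update_freq,
    )

    # collectors receive the algorithm rather than the bare policy
    train_collector = Collector[CollectStats](
        algorithm,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True,
    )
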
- policy.set_eps(args.eps_test) + policy.set_eps_training(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": - test_bdq(get_args()) + run_bdq(get_args()) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index b377d7bb1..f27d72017 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -9,14 +9,16 @@ from gymnasium.core import WrapperActType, WrapperObsType from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import SAC +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.policy import SACPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -24,31 +26,31 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="BipedalWalkerHardcore-v3") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--actor-lr", type=float, default=3e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--actor_lr", type=float, default=3e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.1) - parser.add_argument("--auto-alpha", type=int, default=1) - parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--auto_alpha", type=int, default=1) + parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, 
default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--n-step", type=int, default=4) + parser.add_argument("--n_step", type=int, default=4) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) return parser.parse_args() @@ -91,13 +93,13 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action train_envs = SubprocVectorEnv( - [lambda: Wrapper(gym.make(args.task)) for _ in range(args.training_num)], + [lambda: Wrapper(gym.make(args.task)) for _ in range(args.num_train_envs)], ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [ lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False) - for _ in range(args.test_num) + for _ in range(args.num_test_envs) ], ) @@ -108,45 +110,46 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, - device=args.device, unbounded=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) - policy: SACPolicy = SACPolicy( + policy = SACPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm: SAC = SAC( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, @@ -154,29 +157,28 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path)) + 
algorithm.load_state_dict(torch.load(args.resume_path)) print("Loaded agent from: ", args.resume_path) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -187,29 +189,30 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= env.spec.reward_threshold return False - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - test_in_train=False, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
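
The auto-alpha hunk above replaces the former `(target_entropy, log_alpha, alpha_optim)` tuple with a dedicated `AutoAlpha` module that is passed to `SAC` via its `alpha` parameter, and the alpha optimizer is supplied as a factory rather than a constructed `torch.optim.Adam`. A minimal sketch under the same assumptions as the surrounding script (`action_dim` and `args` already defined):

    from tianshou.algorithm.modelfree.sac import AutoAlpha
    from tianshou.algorithm.optim import AdamOptimizerFactory

    target_entropy = -action_dim
    alpha = AutoAlpha(
        target_entropy,
        0.0,                                   # initial log-alpha
        AdamOptimizerFactory(lr=args.alpha_lr),
    ).to(args.device)
    # `alpha` is then passed to SAC(..., alpha=alpha, ...) in place of the old tuple
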
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 347da2cf9..8ecbda311 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -7,11 +7,13 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -22,23 +24,23 @@ def get_args() -> argparse.Namespace: # the parameters are found by Optuna parser.add_argument("--task", type=str, default="LunarLander-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.01) - parser.add_argument("--eps-train", type=float, default=0.73) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--eps_test", type=float, default=0.01) + parser.add_argument("--eps_train", type=float, default=0.73) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.013) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=4) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--n_step", type=int, default=4) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=80000) - parser.add_argument("--step-per-collect", type=int, default=16) - parser.add_argument("--update-per-step", type=float, default=0.0625) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--dueling-q-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--dueling-v-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=80000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=16) + parser.add_argument("--update_per_step", type=float, default=0.0625) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--dueling_q_hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--dueling_v_hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", 
type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -58,9 +60,9 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: args.max_action = space_info.action_info.max_action # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -73,35 +75,38 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, dueling_param=(Q_param, V_param), ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy = DQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = DiscreteQLearningPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: DQN = DQN( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -114,37 +119,34 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) - policy.set_eps(eps) + policy.set_eps_training(eps) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - train_fn=train_fn, - test_fn=test_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + 
batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + stop_fn=stop_fn, + train_fn=train_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 452eb02d6..031b759fb 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -7,15 +7,17 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import SAC +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise -from tianshou.policy import SACPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -23,23 +25,23 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="MountainCarContinuous-v0") parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--buffer-size", type=int, default=50000) - parser.add_argument("--actor-lr", type=float, default=3e-4) - parser.add_argument("--critic-lr", type=float, default=3e-4) - parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--buffer_size", type=int, default=50000) + parser.add_argument("--actor_lr", type=float, default=3e-4) + parser.add_argument("--critic_lr", type=float, default=3e-4) + parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--noise_std", type=float, default=1.2) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--auto_alpha", type=int, default=1) parser.add_argument("--alpha", type=float, default=0.2) parser.add_argument("--epoch", type=int, default=20) - parser.add_argument("--step-per-epoch", type=int, default=12000) - parser.add_argument("--step-per-collect", type=int, default=5) - parser.add_argument("--update-per-step", type=float, default=0.2) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=5) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=12000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=5) + parser.add_argument("--update_per_step", type=float, default=0.2) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--hidden_sizes", type=int, 
nargs="*", default=[128, 128]) + parser.add_argument("--num_train_envs", type=int, default=5) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -57,47 +59,52 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProbabilistic( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) - policy: SACPolicy = SACPolicy( + policy = SACPolicy( actor=actor, - actor_optim=actor_optim, + exploration_noise=OUNoise(0.0, args.noise_std), + action_space=env.action_space, + ) + algorithm: SAC = SAC( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, @@ -105,24 +112,22 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - exploration_noise=OUNoise(0.0, args.noise_std), - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = 
Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -133,21 +138,23 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= env.spec.reward_threshold return False - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) assert stop_fn(result.best_reward) if __name__ == "__main__": @@ -155,7 +162,7 @@ def stop_fn(mean_rewards: float) -> bool: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index bcf3e7f18..05604b626 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -1,27 +1,29 @@ import gymnasium as gym -import torch from torch.utils.tensorboard import SummaryWriter import tianshou as ts +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import CollectStats +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.space_info import SpaceInfo def main() -> None: task = "CartPole-v1" lr, epoch, batch_size = 1e-3, 10, 64 - train_num, test_num = 10, 100 + num_train_envs, num_test_envs = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 - step_per_epoch, step_per_collect = 10000, 10 + epoch_num_steps, collection_step_num_env_steps = 10000, 10 logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn")) # TensorBoard is supported! 
# For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html # You can also try SubprocVectorEnv, which will use parallelization - train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)]) - test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)]) + train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)]) + test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) from tianshou.utils.net.common import Net @@ -33,24 +35,26 @@ def main() -> None: state_shape = space_info.observation_info.obs_shape action_shape = space_info.action_info.action_shape net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) - optim = torch.optim.Adam(net.parameters(), lr=lr) + optim = AdamOptimizerFactory(lr=lr) - policy: ts.policy.DQNPolicy = ts.policy.DQNPolicy( - model=net, + policy = DiscreteQLearningPolicy( + model=net, action_space=env.action_space, eps_training=eps_train, eps_inference=eps_test + ) + algorithm = ts.algorithm.DQN( + policy=policy, optim=optim, - discount_factor=gamma, - action_space=env.action_space, - estimation_step=n_step, + gamma=gamma, + n_step_return_horizon=n_step, target_update_freq=target_freq, ) train_collector = ts.data.Collector[CollectStats]( - policy, + algorithm, train_envs, - ts.data.VectorReplayBuffer(buffer_size, train_num), + ts.data.VectorReplayBuffer(buffer_size, num_train_envs), exploration_noise=True, ) test_collector = ts.data.Collector[CollectStats]( - policy, + algorithm, test_envs, exploration_noise=True, ) # because DQN uses epsilon-greedy method @@ -63,27 +67,26 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= env.spec.reward_threshold return False - result = ts.trainer.OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=epoch, - step_per_epoch=step_per_epoch, - step_per_collect=step_per_collect, - episode_per_test=test_num, - batch_size=batch_size, - update_per_step=1 / step_per_collect, - train_fn=lambda epoch, env_step: policy.set_eps(eps_train), - test_fn=lambda epoch, env_step: policy.set_eps(eps_test), - stop_fn=stop_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, + collection_step_num_env_steps=collection_step_num_env_steps, + test_step_num_episodes=num_test_envs, + batch_size=batch_size, + update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps, + stop_fn=stop_fn, + logger=logger, + test_in_train=True, + ) + ) print(f"Finished training in {result.timing.total_time} seconds") # watch performance - policy.set_eps(eps_test) - collector = ts.data.Collector[CollectStats](policy, env, exploration_noise=True) - collector.collect(n_episode=100, render=1 / 35, reset_before_collect=True) + collector = ts.data.Collector[CollectStats](algorithm, env, exploration_noise=True) + collector.collect(n_episode=100, render=1 / 35) if __name__ == "__main__": diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index 464d06e71..0ba102f2b 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -1,16 +1,14 @@ from sensai.util.logging import run_main -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import 
OffPolicyTrainingConfig from tianshou.highlevel.env import ( EnvFactoryRegistered, VectorEnvType, ) from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig -from tianshou.highlevel.params.policy_params import DQNParams +from tianshou.highlevel.params.algorithm_params import DQNParams from tianshou.highlevel.trainer import ( EpochStopCallbackRewardThreshold, - EpochTestCallbackDQNSetEps, - EpochTrainCallbackDQNSetEps, ) @@ -29,28 +27,28 @@ def main() -> None: watch_render=1 / 35, watch_num_episodes=100, ), - SamplingConfig( - num_epochs=10, - step_per_epoch=10000, + OffPolicyTrainingConfig( + max_epochs=10, + epoch_num_steps=10000, batch_size=64, num_train_envs=10, num_test_envs=100, buffer_size=20000, - step_per_collect=10, - update_per_step=1 / 10, + collection_step_num_env_steps=10, + update_step_num_gradient_steps_per_sample=1 / 10, ), ) .with_dqn_params( DQNParams( lr=1e-3, - discount_factor=0.9, - estimation_step=3, + gamma=0.9, + n_step_return_horizon=3, target_update_freq=320, + eps_training=0.3, + eps_inference=0.0, ), ) .with_model_factory_default(hidden_sizes=(64, 64)) - .with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3)) - .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0)) .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) .build() ) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 815060d1c..22d35c4b7 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -12,9 +12,12 @@ import torch from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import GAIL +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import ( Batch, Collector, @@ -24,12 +27,10 @@ ) from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs -from tianshou.policy import GAILPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -51,34 +52,34 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2") - parser.add_argument("--buffer-size", type=int, default=4096) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") + parser.add_argument("--buffer_size", type=int, default=4096) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--lr", type=float, default=3e-4) - parser.add_argument("--disc-lr", type=float, default=2.5e-5) + parser.add_argument("--disc_lr", type=float, default=2.5e-5) 
parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=10) - parser.add_argument("--disc-update-num", type=int, default=2) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--training-num", type=int, default=64) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument("--update_step_num_repetitions", type=int, default=10) + parser.add_argument("--disc_update_num", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_train_envs", type=int, default=64) + parser.add_argument("--num_test_envs", type=int, default=10) # ppo special - parser.add_argument("--rew-norm", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) # In theory, `vf-coef` will not make any difference if using Adam optimizer. - parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.001) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=0) - parser.add_argument("--recompute-adv", type=int, default=1) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.001) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) + parser.add_argument("--advantage_normalization", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=1) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -86,7 +87,7 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, @@ -107,11 +108,11 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: print("Action range:", args.min_action, args.max_action) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.training_num)], + [lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.num_train_envs)], ) train_envs = VectorEnvNormObs(train_envs) # test_envs = gym.make(args.task) - test_envs = 
SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) test_envs.set_obs_rms(train_envs.get_obs_rms()) @@ -122,24 +123,21 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - actor = ActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): @@ -154,30 +152,31 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # discriminator net_d = Net( - args.state_shape, + state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, concat=True, ) - disc_net = Critic(net_d, device=args.device).to(args.device) + disc_net = ContinuousCritic(preprocess_net=net_d).to(args.device) for m in disc_net.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) - disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr) + disc_optim = AdamOptimizerFactory(lr=args.disc_lr) - lr_scheduler = None if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale @@ -205,45 +204,47 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: ) print("dataset loaded") - policy: GAILPolicy = GAILPolicy( + policy = ProbabilisticActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=True, + action_bound_method=args.bound_action_method, + action_space=env.action_space, + ) + algorithm: GAIL = GAIL( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, expert_buffer=expert_buffer, disc_net=disc_net, disc_optim=disc_optim, disc_update_num=args.disc_update_num, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, - action_space=env.action_space, + return_scaling=args.return_scaling, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, - 
advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector buffer: ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail' @@ -252,31 +253,32 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=100, train_interval=100) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! 
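The `lr_decay` branch in the `irl_gail.py` diff above replaces the manually built `LambdaLR` schedule with a scheduler factory attached to the optimizer factory. The sketch below restates that before/after in isolation; the numeric values are placeholders, and only the Tianshou calls already shown in the diff are used (how the `Algorithm` consumes the factories internally is an assumption).

```python
# Illustrative sketch of the linear LR-decay migration shown in irl_gail.py above.
# The numeric values are made up; the tianshou calls are those appearing in the diff.
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear

max_epochs, epoch_num_steps, collection_step_num_env_steps, lr = 100, 30000, 2048, 3e-4

# v1 style: build the schedule by hand and pass `lr_scheduler` to the policy
model = torch.nn.Linear(4, 2)
v1_optim = torch.optim.Adam(model.parameters(), lr=lr)
max_update_num = np.ceil(epoch_num_steps / collection_step_num_env_steps) * max_epochs
v1_scheduler = LambdaLR(v1_optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

# v2 style: attach a scheduler factory to the optimizer factory; the optimizer and
# scheduler are then (presumably) constructed by the Algorithm from these factories.
v2_optim = AdamOptimizerFactory(lr=lr)
v2_optim.with_lr_scheduler_factory(
    LRSchedulerFactoryLinear(
        max_epochs=max_epochs,
        epoch_num_steps=epoch_num_steps,
        collection_step_num_env_steps=collection_step_num_env_steps,
    )
)
```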
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/modelbased/README.md b/examples/modelbased/README.md index c3563f629..879847ac4 100644 --- a/examples/modelbased/README.md +++ b/examples/modelbased/README.md @@ -1,7 +1,7 @@ # PSRL -`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1` +`NChain-v0`: `python3 psrl.py --task NChain-v0 --epoch_num_steps 10 --rew-mean-prior 0 --rew-std-prior 1` -`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20` +`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch_num_steps 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20` -`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20` +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --epoch_num_steps 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20` diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md index f633273a7..3e6ceee29 100644 --- a/examples/mujoco/README.md +++ b/examples/mujoco/README.md @@ -3,15 +3,16 @@ We benchmarked Tianshou algorithm implementations in 9 out of 13 environments from the MuJoCo Gym task suite[[1]](#footnote1). For each supported algorithm and supported mujoco environments, we provide: + - Default hyperparameters used for benchmark and scripts to reproduce the benchmark; - A comparison of performance (or code level details) with other open source implementations or classic papers; - Graphs and raw data that can be used for research purposes[[2]](#footnote2); - Log details obtained during training[[2]](#footnote2); - Pretrained agents[[2]](#footnote2); - Some hints on how to tune the algorithm. - Supported algorithms are listed below: + - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec) - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec) - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec) @@ -79,62 +80,64 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai ### Notes -1. In offpolicy algorithms (DDPG, TD3, SAC), the shared hyperparameters are almost the same, and unless otherwise stated, hyperparameters are consistent with those used for benchmark in SpinningUp's implementations (e.g. we use batchsize 256 in DDPG/TD3/SAC while SpinningUp use 100. Minor difference also lies with `start-timesteps`, data loop method `step_per_collect`, method to deal with/bootstrap truncated steps because of timelimit and unfinished/collecting episodes (contribute to performance improvement), etc.). +1. In offpolicy algorithms (DDPG, TD3, SAC), the shared hyperparameters are almost the same, and unless otherwise stated, hyperparameters are consistent with those used for benchmark in SpinningUp's implementations (e.g. we use batchsize 256 in DDPG/TD3/SAC while SpinningUp use 100. 
Minor difference also lies with `start-timesteps`, data loop method `collection_step_num_env_steps`, method to deal with/bootstrap truncated steps because of timelimit and unfinished/collecting episodes (contribute to performance improvement), etc.). 2. By comparison to both classic literature and open source implementations (e.g., SpinningUp)[[1]](#footnote1)[[2]](#footnote2), Tianshou's implementations of DDPG, TD3, and SAC are roughly at-parity with or better than the best reported results for these algorithms, so you can definitely use Tianshou's benchmark for research purposes. 3. We didn't compare offpolicy algorithms to OpenAI baselines [benchmark](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm), because for now it seems that they haven't provided benchmark for offpolicy algorithms, but in [SpinningUp docs](https://spinningup.openai.com/en/latest/spinningup/bench.html) they stated that "SpinningUp implementations of DDPG, TD3, and SAC are roughly at-parity with the best-reported results for these algorithms", so we think lack of comparisons with OpenAI baselines is okay. ### DDPG | Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper (DDPG)](https://arxiv.org/abs/1802.09477) | [TD3 paper (OurDDPG)](https://arxiv.org/abs/1802.09477) | -| :--------------------: | :---------------: | :----------------------------------------------------------: | :--------------------------------------------------: | :-----------------------------------------------------: | -| Ant | 990.4±4.3 | ~840 | **1005.3** | 888.8 | -| HalfCheetah | **11718.7±465.6** | ~11000 | 3305.6 | 8577.3 | -| Hopper | **2197.0±971.6** | ~1800 | **2020.5** | 1860.0 | -| Walker2d | 1400.6±905.0 | ~1950 | 1843.6 | **3098.1** | -| Swimmer | **144.1±6.5** | ~137 | N | N | -| Humanoid | **177.3±77.6** | N | N | N | -| Reacher | **-3.3±0.3** | N | -6.51 | -4.01 | -| InvertedPendulum | **1000.0±0.0** | N | **1000.0** | **1000.0** | -| InvertedDoublePendulum | 8364.3±2778.9 | N | **9355.5** | 8370.0 | +| :--------------------: | :---------------: | :------------------------------------------------------------------------------------: | :--------------------------------------------------: | :-----------------------------------------------------: | +| Ant | 990.4±4.3 | ~840 | **1005.3** | 888.8 | +| HalfCheetah | **11718.7±465.6** | ~11000 | 3305.6 | 8577.3 | +| Hopper | **2197.0±971.6** | ~1800 | **2020.5** | 1860.0 | +| Walker2d | 1400.6±905.0 | ~1950 | 1843.6 | **3098.1** | +| Swimmer | **144.1±6.5** | ~137 | N | N | +| Humanoid | **177.3±77.6** | N | N | N | +| Reacher | **-3.3±0.3** | N | -6.51 | -4.01 | +| InvertedPendulum | **1000.0±0.0** | N | **1000.0** | **1000.0** | +| InvertedDoublePendulum | 8364.3±2778.9 | N | **9355.5** | 8370.0 | \* details[[4]](#footnote4)[[5]](#footnote5)[[6]](#footnote6) ### TD3 | Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper](https://arxiv.org/abs/1802.09477) | -| :--------------------: | :---------------: | :----------------------------------------------------------: | :-------------------------------------------: | -| Ant | **5116.4±799.9** | ~3800 | 4372.4±1000.3 | -| HalfCheetah | **10201.2±772.8** | ~9750 | 9637.0±859.1 | -| Hopper | 3472.2±116.8 | ~2860 | **3564.1±114.7** | -| Walker2d | 3982.4±274.5 | ~4000 | **4682.8±539.6** | -| Swimmer | **104.2±34.2** | ~78 | N | -| Humanoid | **5189.5±178.5** | N | N | -| 
Reacher | **-2.7±0.2** | N | -3.6±0.6 | -| InvertedPendulum | **1000.0±0.0** | N | **1000.0±0.0** | -| InvertedDoublePendulum | **9349.2±14.3** | N | **9337.5±15.0** | +| :--------------------: | :---------------: | :------------------------------------------------------------------------------------: | :-------------------------------------------: | +| Ant | **5116.4±799.9** | ~3800 | 4372.4±1000.3 | +| HalfCheetah | **10201.2±772.8** | ~9750 | 9637.0±859.1 | +| Hopper | 3472.2±116.8 | ~2860 | **3564.1±114.7** | +| Walker2d | 3982.4±274.5 | ~4000 | **4682.8±539.6** | +| Swimmer | **104.2±34.2** | ~78 | N | +| Humanoid | **5189.5±178.5** | N | N | +| Reacher | **-2.7±0.2** | N | -3.6±0.6 | +| InvertedPendulum | **1000.0±0.0** | N | **1000.0±0.0** | +| InvertedDoublePendulum | **9349.2±14.3** | N | **9337.5±15.0** | \* details[[4]](#footnote4)[[5]](#footnote5)[[6]](#footnote6) #### Hints for TD3 + 1. TD3's learning rate is set to 3e-4 while it is 1e-3 for DDPG/SAC. However, there is NO enough evidence to support our choice of such hyperparameters (we simply choose them because SpinningUp do so) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can! ### SAC | Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [SAC paper](https://arxiv.org/abs/1801.01290) | -| :--------------------: | :----------------: | :----------------------------------------------------------: | :-------------------------------------------: | -| Ant | **5850.2±475.7** | ~3980 | ~3720 | -| HalfCheetah | **12138.8±1049.3** | ~11520 | ~10400 | -| Hopper | **3542.2±51.5** | ~3150 | ~3370 | -| Walker2d | **5007.0±251.5** | ~4250 | ~3740 | -| Swimmer | **44.4±0.5** | ~41.7 | N | -| Humanoid | **5488.5±81.2** | N | ~5200 | -| Reacher | **-2.6±0.2** | N | N | -| InvertedPendulum | **1000.0±0.0** | N | N | -| InvertedDoublePendulum | **9359.5±0.4** | N | N | +| :--------------------: | :----------------: | :------------------------------------------------------------------------------------: | :-------------------------------------------: | +| Ant | **5850.2±475.7** | ~3980 | ~3720 | +| HalfCheetah | **12138.8±1049.3** | ~11520 | ~10400 | +| Hopper | **3542.2±51.5** | ~3150 | ~3370 | +| Walker2d | **5007.0±251.5** | ~4250 | ~3740 | +| Swimmer | **44.4±0.5** | ~41.7 | N | +| Humanoid | **5488.5±81.2** | N | ~5200 | +| Reacher | **-2.6±0.2** | N | N | +| InvertedPendulum | **1000.0±0.0** | N | N | +| InvertedDoublePendulum | **9359.5±0.4** | N | N | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for SAC + 1. SAC's start-timesteps is set to 10000 by default while it is 25000 is DDPG/TD3. However, there is NO enough evidence to support our choice of such hyperparameters (we simply choose them because SpinningUp do so) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can! 2. DO NOT share the same network with two critic networks. 3. The sigma (of the Gaussian policy) should be conditioned on input. @@ -143,6 +146,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai ## Onpolicy Algorithms ### Notes + 1. In A2C and PPO, unless otherwise stated, most hyperparameters are consistent with those used for benchmark in [ikostrikov/pytorch-a2c-ppo-acktr-gail](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail). 2. 
Gernally speaking, by comparison to both classic literature and open source implementations (e.g., OPENAI Baselines)[[1]](#footnote1)[[2]](#footnote2), Tianshou's implementations of REINFORCE, A2C, PPO are better than the best reported results for these algorithms, so you can definitely use Tianshou's benchmark for research purposes. @@ -160,18 +164,17 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai | InvertedPendulum | **1000.0±0.0** | | InvertedDoublePendulum | **7726.2±1287.3** | - | Environment | Tianshou (3M) | [Spinning Up (VPG PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html)[[7]](#footnote7) | -| :--------------------: | :---------------: | :----------------------------------------------------------: | -| Ant | **474.9+-133.5** | ~5 | -| HalfCheetah | **884.0+-41.0** | ~600 | -| Hopper | 395.8+-64.5* | **~800** | -| Walker2d | 412.0+-52.4 | **~460** | -| Swimmer | 35.3+-1.4 | **~51** | -| Humanoid | **438.2+-47.8** | N | -| Reacher | **-10.5+-0.7** | N | -| InvertedPendulum | **999.2+-2.4** | N | -| InvertedDoublePendulum | **1059.7+-307.7** | N | +| :--------------------: | :---------------: | :------------------------------------------------------------------------------------------------------------------------: | +| Ant | **474.9+-133.5** | ~5 | +| HalfCheetah | **884.0+-41.0** | ~600 | +| Hopper | 395.8+-64.5\* | **~800** | +| Walker2d | 412.0+-52.4 | **~460** | +| Swimmer | 35.3+-1.4 | **~51** | +| Humanoid | **438.2+-47.8** | N | +| Reacher | **-10.5+-0.7** | N | +| InvertedPendulum | **999.2+-2.4** | N | +| InvertedDoublePendulum | **1059.7+-307.7** | N | \* details[[4]](#footnote4)[[5]](#footnote5) @@ -188,35 +191,35 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai ### A2C | Environment | Tianshou (3M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html) | -| :--------------------: | :----------------: | :----------------------------------------------------------: | -| Ant | **5236.8+-236.7** | ~5 | -| HalfCheetah | **2377.3+-1363.7** | ~600 | -| Hopper | **1608.6+-529.5** | ~800 | -| Walker2d | **1805.4+-1055.9** | ~460 | -| Swimmer | 40.2+-1.8 | **~51** | -| Humanoid | **5316.6+-554.8** | N | -| Reacher | **-5.2+-0.5** | N | -| InvertedPendulum | **1000.0+-0.0** | N | -| InvertedDoublePendulum | **9351.3+-12.8** | N | +| :--------------------: | :----------------: | :----------------------------------------------------------------------------------------: | +| Ant | **5236.8+-236.7** | ~5 | +| HalfCheetah | **2377.3+-1363.7** | ~600 | +| Hopper | **1608.6+-529.5** | ~800 | +| Walker2d | **1805.4+-1055.9** | ~460 | +| Swimmer | 40.2+-1.8 | **~51** | +| Humanoid | **5316.6+-554.8** | N | +| Reacher | **-5.2+-0.5** | N | +| InvertedPendulum | **1000.0+-0.0** | N | +| InvertedDoublePendulum | **9351.3+-12.8** | N | | Environment | Tianshou (1M) | [PPO paper](https://arxiv.org/abs/1707.06347) A2C | [PPO paper](https://arxiv.org/abs/1707.06347) A2C + Trust Region | -| :--------------------: | :----------------: | :-----------------------------------------------: | :----------------------------------------------------------: | -| Ant | **3485.4+-433.1** | N | N | -| HalfCheetah | **1829.9+-1068.3** | ~1000 | ~930 | -| Hopper | **1253.2+-458.0** | ~900 | ~1220 | -| Walker2d | **1091.6+-709.2** | ~850 | ~700 | -| Swimmer | **36.6+-2.1** | ~31 | **~36** | -| Humanoid | **1726.0+-1070.1** | N | N | -| Reacher | **-6.7+-2.3** | ~-24 | ~-27 | -| 
InvertedPendulum | **1000.0+-0.0** | **~1000** | **~1000** | -| InvertedDoublePendulum | **9257.7+-277.4** | ~7100 | ~8100 | +| :--------------------: | :----------------: | :-----------------------------------------------: | :--------------------------------------------------------------: | +| Ant | **3485.4+-433.1** | N | N | +| HalfCheetah | **1829.9+-1068.3** | ~1000 | ~930 | +| Hopper | **1253.2+-458.0** | ~900 | ~1220 | +| Walker2d | **1091.6+-709.2** | ~850 | ~700 | +| Swimmer | **36.6+-2.1** | ~31 | **~36** | +| Humanoid | **1726.0+-1070.1** | N | N | +| Reacher | **-6.7+-2.3** | ~-24 | ~-27 | +| InvertedPendulum | **1000.0+-0.0** | **~1000** | **~1000** | +| InvertedDoublePendulum | **9257.7+-277.4** | ~7100 | ~8100 | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for A2C 1. We choose `clip` action method in A2C instead of `tanh` option as used in REINFORCE simply to be consistent with original implementation. `tanh` may be better or equally well but we didn't have a try. -2. (Initial) learning rate, lr\_decay, `step-per-collect` and `training-num` affect the performance of A2C to a great extend. These 4 hyperparameters also affect each other and should be tuned together. We have done full scale ablation studies on these 4 hyperparameters (more than 800 agents have been trained). Below are our findings. +2. (Initial) learning rate, lr_decay, `step-per-collect` and `training-num` affect the performance of A2C to a great extend. These 4 hyperparameters also affect each other and should be tuned together. We have done full scale ablation studies on these 4 hyperparameters (more than 800 agents have been trained). Below are our findings. 3. `step-per-collect` / `training-num` are equal to `bootstrap-lenghth`, which is the max length of an "episode" used in GAE estimator and 80/16=5 in default settings. When `bootstrap-lenghth` is small, (maybe) because GAE can look forward at most 5 steps and use bootstrap strategy very often, the critic is less well-trained leading the actor to a not very high score. However, if we increase `step-per-collect` to increase `bootstrap-lenghth` (e.g. 256/16=16), actor/critic will be updated less often, resulting in low sample efficiency and slow training process. To conclude, If you don't restrict env timesteps, you can try using larger `bootstrap-lenghth` and train with more steps to get a better converged score. Train slower, achieve higher. 4. The learning rate 7e-4 with decay strategy is appropriate for `step-per-collect=80` and `training-num=16`. But if you use a larger `step-per-collect`(e.g. 256 - 2048), 7e-4 is a little bit small for `lr` because each update will have more data, less noise and thus smaller deviation in this case. So it is more appropriate to use a higher learning rate (e.g. 1e-3) to boost performance in this setting. If plotting results arise fast in early stages and become unstable later, consider lr decay first before decreasing lr. 5. `max-grad-norm` didn't really help in our experiments. We simply keep it for consistency with other open-source implementations (e.g. SB3). 
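Hint 3 above ties the effective GAE segment ("bootstrap") length to the ratio of the two collection parameters. A tiny arithmetic sketch of that relation, using the values quoted in the hint:

```python
# Bootstrap length as described in the A2C hints: env steps collected per round,
# divided evenly among the parallel training envs.
collection_step_num_env_steps = 80  # formerly `step-per-collect`
num_train_envs = 16                 # formerly `training-num`

bootstrap_length = collection_step_num_env_steps // num_train_envs
assert bootstrap_length == 5        # the default setting discussed in hint 3

# Enlarging step-per-collect lengthens the segments GAE can look across,
# at the cost of fewer actor/critic updates per environment step.
assert 256 // 16 == 16
```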
@@ -227,58 +230,60 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai ### PPO | Environment | Tianshou (1M) | [ikostrikov/pytorch-a2c-ppo-acktr-gail](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) | -| :--------------------: | :----------------: | :----------------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | -| Ant | **3258.4+-1079.3** | N | N | N | ~650 | -| HalfCheetah | **5783.9+-1244.0** | ~3120 | ~1800 | ~1700 | ~1670 | -| Hopper | **2609.3+-700.8** | ~2300 | ~2330 | ~2400 | ~1850 | -| Walker2d | 3588.5+-756.6 | **~4000** | ~3460 | ~3510 | ~1230 | -| Swimmer | 66.7+-99.1 | N | ~108 | ~111 | **~120** | -| Humanoid | **787.1+-193.5** | N | N | N | N | -| Reacher | **-4.1+-0.3** | ~-5 | ~-7 | ~-6 | N | -| InvertedPendulum | **1000.0+-0.0** | N | **~1000** | ~940 | N | -| InvertedDoublePendulum | **9231.3+-270.4** | N | ~8000 | ~7350 | N | +| :--------------------: | :----------------: | :-----------------------------------------------------------------------------------------------: | :-----------------------------------------------: | :-----------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------: | +| Ant | **3258.4+-1079.3** | N | N | N | ~650 | +| HalfCheetah | **5783.9+-1244.0** | ~3120 | ~1800 | ~1700 | ~1670 | +| Hopper | **2609.3+-700.8** | ~2300 | ~2330 | ~2400 | ~1850 | +| Walker2d | 3588.5+-756.6 | **~4000** | ~3460 | ~3510 | ~1230 | +| Swimmer | 66.7+-99.1 | N | ~108 | ~111 | **~120** | +| Humanoid | **787.1+-193.5** | N | N | N | N | +| Reacher | **-4.1+-0.3** | ~-5 | ~-7 | ~-6 | N | +| InvertedPendulum | **1000.0+-0.0** | N | **~1000** | ~940 | N | +| InvertedDoublePendulum | **9231.3+-270.4** | N | ~8000 | ~7350 | N | | Environment | Tianshou (3M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) | -| :--------------------: | :----------------: | :----------------------------------------------------------: | -| Ant | **4079.3+-880.2** | ~3000 | -| HalfCheetah | **7337.4+-1508.2** | ~3130 | -| Hopper | **3127.7+-413.0** | ~2460 | -| Walker2d | **4895.6+-704.3** | ~2600 | -| Swimmer | 81.4+-96.0 | **~120** | -| Humanoid | **1359.7+-572.7** | N | -| Reacher | **-3.7+-0.3** | N | -| InvertedPendulum | **1000.0+-0.0** | N | -| InvertedDoublePendulum | **9231.3+-270.4** | N | +| :--------------------: | :----------------: | :----------------------------------------------------------------------------------------: | +| Ant | **4079.3+-880.2** | ~3000 | +| HalfCheetah | **7337.4+-1508.2** | ~3130 | +| Hopper | **3127.7+-413.0** | ~2460 | +| Walker2d | **4895.6+-704.3** | ~2600 | +| Swimmer | 81.4+-96.0 | **~120** | +| Humanoid | **1359.7+-572.7** | N | +| Reacher | **-3.7+-0.3** | N | +| InvertedPendulum | **1000.0+-0.0** | N | +| InvertedDoublePendulum | **9231.3+-270.4** | N | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for PPO + 1. 
Following [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990) Sec 3.5, we use the "recompute advantage" strategy, which contributes a lot to our SOTA benchmark. However, I personally don't quite agree with their explanation about why "recompute advantage" helps. They stated that it's because the old strategy "makes it impossible to compute advantages as the temporal structure is broken", but PPO's update equation is designed to learn from slightly-outdated advantages. I think the only reason "recompute advantage" works is that it updates the critic several times rather than just once per update, which leads to a better value function estimation. 2. We have done full-scale ablation studies of the PPO algorithm's hyperparameters. Here are our findings: In Mujoco settings, `value-clip` and `norm-adv` may help a little bit in some games (e.g. `norm-adv` helps stabilize training in InvertedPendulum-v2), but they make no difference to overall performance. So in our benchmark we do not use such tricks. We validated that setting `ent-coef` to 0.0 rather than 0.01 increases overall performance in mujoco environments. `max-grad-norm` still offers no help for the PPO algorithm, but we still keep it for consistency. 3. [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990)'s work indicates that using `gae-lambda` 0.9 and changing the policy network's width based on which game you play (e.g. use [16, 16] `hidden-sizes` for the `actor` network in HalfCheetah and [256, 256] for Ant) may help boost performance. Our ablation studies say otherwise: both options may lead to equal or lower performance overall in our experiments. We are not confident about this claim because we didn't change the learning rate and other maybe-correlated factors in our experiments. So if you want, you can still give it a try. -4. `batch-size` 128 and 64 (default) work equally well. Changing `training-num` alone slightly (maybe in range [8, 128]) won't affect performance. For bound action method, both `clip` and `tanh` work quite well. +4. `batch-size` 128 and 64 (default) work equally well. Changing `training-num` alone slightly (maybe in range [8, 128]) won't affect performance. For bound action method, both `clip` and `tanh` work quite well. 5. In OpenAI's implementation of PPO, they multiply the value loss by a factor of 0.5 for no good reason (see this [issue](https://github.com/openai/baselines/issues/445#issuecomment-777988738)). We do not do so and therefore make our `vf-coef` 0.25 (half of the standard 0.5). However, since the value loss is only used to optimize the `critic` network, setting a different `vf-coef` should in theory make no difference when using the Adam optimizer.
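Hint 1 attributes the benefit of "recompute advantage" to re-estimating GAE with the freshly updated critic before every repeat pass over the batch. The snippet below is a generic sketch of that idea (the standard GAE recursion, not Tianshou's internal implementation); the toy arrays and the commented training loop around it are assumptions for illustration only.

```python
# Generic GAE recursion; re-running this with the current critic before each
# repeat pass over the batch is the "recompute advantage" strategy from hint 1.
import numpy as np

def gae(rewards, values, next_values, dones, gamma=0.99, gae_lambda=0.95):
    """Generalized advantage estimation over one trajectory segment."""
    advantages = np.zeros_like(rewards)
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        last_adv = delta + gamma * gae_lambda * not_done * last_adv
        advantages[t] = last_adv
    return advantages

# Toy data just to make the sketch runnable.
rng = np.random.default_rng(0)
T = 8
rewards, values, next_values = rng.normal(size=(3, T))
dones = np.zeros(T)
dones[-1] = 1.0

# Schematic update loop: recompute advantages with the *current* critic each pass.
# for _ in range(update_step_num_repetitions):
#     values, next_values = critic(obs), critic(next_obs)   # fresh value estimates
#     advantages = gae(rewards, values, next_values, dones)
#     ...clipped-surrogate policy update and value loss on this batch...
print(gae(rewards, values, next_values, dones))
```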
- + ### TRPO | Environment | Tianshou (1M) | [ACKTR paper](https://arxiv.org/pdf/1708.05144.pdf) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [Spinning Up (Tensorflow)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | -| :--------------------: | :---------------: | :-------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | -| Ant | **2866.7±707.9** | ~0 | N | N | ~150 | -| HalfCheetah | **4471.2±804.9** | ~400 | ~0 | ~1350 | ~850 | -| Hopper | 2046.0±1037.9 | ~1400 | ~2100 | **~2200** | ~1200 | -| Walker2d | **3826.7±782.7** | ~550 | ~1100 | ~2350 | ~600 | -| Swimmer | 40.9±19.6 | ~40 | **~121** | ~95 | ~85 | -| Humanoid | **810.1±126.1** | N | N | N | N | -| Reacher | **-5.1±0.8** | -8 | ~-115 | **~-5** | N | -| InvertedPendulum | **1000.0±0.0** | **~1000** | **~1000** | ~910 | N | -| InvertedDoublePendulum | **8435.2±1073.3** | ~800 | ~200 | ~7000 | N | +| :--------------------: | :---------------: | :-------------------------------------------------: | :-----------------------------------------------: | :-----------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------: | +| Ant | **2866.7±707.9** | ~0 | N | N | ~150 | +| HalfCheetah | **4471.2±804.9** | ~400 | ~0 | ~1350 | ~850 | +| Hopper | 2046.0±1037.9 | ~1400 | ~2100 | **~2200** | ~1200 | +| Walker2d | **3826.7±782.7** | ~550 | ~1100 | ~2350 | ~600 | +| Swimmer | 40.9±19.6 | ~40 | **~121** | ~95 | ~85 | +| Humanoid | **810.1±126.1** | N | N | N | N | +| Reacher | **-5.1±0.8** | -8 | ~-115 | **~-5** | N | +| InvertedPendulum | **1000.0±0.0** | **~1000** | **~1000** | ~910 | N | +| InvertedDoublePendulum | **8435.2±1073.3** | ~800 | ~200 | ~7000 | N | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for TRPO + 1. We have tried `step-per-collect` in (80, 1024, 2048, 4096), and `training-num` in (4, 16, 32, 64), and found out 1024 for `step-per-collect` (same as OpenAI Baselines) and smaller `training-num` (below 16) are good choices. Set `training-num` to 4 is actually better but we still use 16 considering the boost of training speed. 2. Advantage normalization is a standard trick in TRPO, but we found it of minor help, just like in PPO. -3. Larger `optim-critic-iters` (than 5, as used in OpenAI Baselines) helps in most environments. Smaller lr and lr\_decay strategy also help a tiny little bit for performance. +3. Larger `optim-critic-iters` (than 5, as used in OpenAI Baselines) helps in most environments. Smaller lr and lr_decay strategy also help a tiny little bit for performance. 4. `gae-lambda` 0.98 and 0.95 work equally well. 5. We use GAE returns (GAE advantage + value) as the target of critic network when updating, while people usually tend to use reward to go (lambda = 0.) as target. We found that they work equally well although using GAE returns is a little bit inaccurate (biased) by math. 6. Empirically, Swimmer-v3 usually requires larger bootstrap lengths and learning rate. Humanoid-v3 and InvertedPendulum-v2, however, are on the opposite. @@ -302,33 +307,36 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for NPG + 1. 
All shared hyperparameters are exactly the same as TRPO, regarding how similar these two algorithms are. 2. We found different games in Mujoco may require quite different `actor-step-size`: Reacher/Swimmer are insensitive to step-size in range (0.1~1.0), while InvertedDoublePendulum / InvertedPendulum / Humanoid are quite sensitive to step size, and even 0.1 is too large. Other games may require `actor-step-size` in range (0.1~0.4), but aren't that sensitive in general. ## Others ### HER -| Environment | DDPG without HER | DDPG with HER | -| :--------------------: | :--------------: | :--------------: | -| FetchReach | -49.9±0.2. | **-17.6±21.7** | + +| Environment | DDPG without HER | DDPG with HER | +| :---------: | :--------------: | :------------: | +| FetchReach | -49.9±0.2. | **-17.6±21.7** | #### Hints for HER -1. The HER technique is proposed for solving task-based environments, so it cannot be compared with non-task-based mujoco benchmarks. The environment used in this evaluation is ``FetchReach-v3`` which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics). -2. Simple hyperparameters optimizations are done for both settings, DDPG with and without HER. However, since *DDPG without HER* failed in every experiment, the best hyperparameters for *DDPG with HER* are used in the evaluation of both settings. -3. The scores are the mean reward ± 1 standard deviation of 16 seeds. The minimum reward for ``FetchReach-v3`` is -50 which we can imply that *DDPG without HER* performs as good as a random policy. *DDPG with HER* although has a better mean reward, the standard deviation is quite high. This is because in this setting, the agent will either fail completely (-50 reward) or successfully learn the task (close to 0 reward). This means that the agent successfully learned in about 70% of the 16 seeds. + +1. The HER technique is proposed for solving task-based environments, so it cannot be compared with non-task-based mujoco benchmarks. The environment used in this evaluation is `FetchReach-v3` which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics). +2. Simple hyperparameters optimizations are done for both settings, DDPG with and without HER. However, since _DDPG without HER_ failed in every experiment, the best hyperparameters for _DDPG with HER_ are used in the evaluation of both settings. +3. The scores are the mean reward ± 1 standard deviation of 16 seeds. The minimum reward for `FetchReach-v3` is -50 which we can imply that _DDPG without HER_ performs as good as a random policy. _DDPG with HER_ although has a better mean reward, the standard deviation is quite high. This is because in this setting, the agent will either fail completely (-50 reward) or successfully learn the task (close to 0 reward). This means that the agent successfully learned in about 70% of the 16 seeds. ## Note -[1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures. +[1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures. 
-[2] Pretrained agents, detailed graphs (single agent, single game) and log details can all be found at [Google Drive](https://drive.google.com/drive/folders/1IycImzTmWcyEeD38viea5JHoboC4zmNP?usp=share_link). +[2] Pretrained agents, detailed graphs (single agent, single game) and log details can all be found at [Google Drive](https://drive.google.com/drive/folders/1IycImzTmWcyEeD38viea5JHoboC4zmNP?usp=share_link). -[3] We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but it's not often the case with other benchmarks. Please check for details yourself in the original paper. (Different version's outcomes are usually similar, though) +[3] We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but this is often not the case for other benchmarks. Please check the details yourself in the original paper. (Different versions' outcomes are usually similar, though) -[4] ~ means the number is approximated from the graph because accurate numbers is not provided in the paper. N means graphs not provided. +[4] ~ means the number is approximated from the graph because accurate numbers are not provided in the paper. N means graphs not provided. -[5] Reward metric: The meaning of the table value is the max average return over 10 trails (different seeds) ± a single standard deviation over trails. Each trial is averaged on another 10 test seeds. Only the first 1M steps data will be considered, if not otherwise stated. The shaded region on the graph also represents a single standard deviation. It is the same as [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34). +[5] Reward metric: The table value is the max average return over 10 trials (different seeds) ± a single standard deviation over trials. Each trial is averaged over another 10 test seeds. Only the first 1M steps of data are considered, if not otherwise stated. The shaded region on the graph also represents a single standard deviation. This is the same as the [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34). -[6] In TD3 paper, shaded region represents only half of standard deviation. +[6] In the TD3 paper, the shaded region represents only half a standard deviation. -[7] Comparing Tianshou's REINFORCE algorithm with SpinningUp's VPG is quite unfair because SpinningUp's VPG uses a generative advantage estimator (GAE) which requires a dnn value predictor (critic network), which makes so called "VPG" more like A2C (advantage actor critic) algorithm. Even so, you can see that we are roughly at-parity with each other even if tianshou's REINFORCE do not use a critic or GAE. +[7] Comparing Tianshou's REINFORCE algorithm with SpinningUp's VPG is not entirely fair, because SpinningUp's VPG uses a generalized advantage estimator (GAE), which requires a neural-network value predictor (critic network) and makes the so-called "VPG" more like an A2C (advantage actor-critic) algorithm. Even so, you can see that the two are roughly at parity even though Tianshou's REINFORCE does not use a critic or GAE.
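Note [5] describes how the benchmark table entries are computed. The following sketch shows one plausible reading of that metric (the maximum of the seed-averaged learning curve, with the spread taken across seeds at that point, as in the TD3 evaluation protocol); the array of test returns is fabricated purely for illustration.

```python
# One plausible implementation of the note [5] reward metric, on made-up data:
# rows = trials (seeds), columns = evaluation checkpoints within the first 1M steps.
import numpy as np

test_returns = np.random.default_rng(0).normal(loc=3000.0, scale=300.0, size=(10, 50))

mean_curve = test_returns.mean(axis=0)   # average return across the 10 trials
best = int(mean_curve.argmax())          # best point of the averaged curve
metric = mean_curve[best]                # max average return
spread = test_returns[:, best].std()     # a single standard deviation over trials
print(f"{metric:.1f}±{spread:.1f}")
```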
diff --git a/examples/mujoco/analysis.py b/examples/mujoco/analysis.py index b881cdd34..3bd40f4ad 100755 --- a/examples/mujoco/analysis.py +++ b/examples/mujoco/analysis.py @@ -89,7 +89,7 @@ def numerical_analysis(root_dir: str | PathLike, xlim: float, norm: bool = False default=1000000, help="x-axis limitation (default: 1000000)", ) - parser.add_argument("--root-dir", type=str) + parser.add_argument("--root_dir", type=str) parser.add_argument( "--norm", action="store_true", diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index ee9b76e75..05fef72c6 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -10,7 +10,6 @@ import numpy as np import torch - from tianshou.data import ( Collector, CollectStats, @@ -19,15 +18,17 @@ ReplayBuffer, VectorReplayBuffer, ) -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated +from tianshou.env.venvs import BaseVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import DDPGPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.algorithm import DDPG +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net, get_dict_state_decorator -from tianshou.utils.net.continuous import Actor, Critic -from tianshou.env.venvs import BaseVectorEnv +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import ActionSpaceInfo @@ -35,25 +36,25 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="FetchReach-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=100000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=3e-3) + parser.add_argument("--buffer_size", type=int, default=100000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--start-timesteps", type=int, default=25000) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--start_timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=1) - parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=512) - parser.add_argument("--replay-buffer", type=str, default="her", choices=["normal", "her"]) - parser.add_argument("--her-horizon", type=int, default=50) - parser.add_argument("--her-future-k", type=int, 
default=8) - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=1) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=512) + parser.add_argument("--replay_buffer", type=str, default="her", choices=["normal", "her"]) + parser.add_argument("--her_horizon", type=int, default=50) + parser.add_argument("--her_future_k", type=int, default=8) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -61,15 +62,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="HER-benchmark") + parser.add_argument("--wandb_project", type=str, default="HER-benchmark") parser.add_argument( "--watch", default=False, @@ -81,15 +82,15 @@ def get_args() -> argparse.Namespace: def make_fetch_env( task: str, - training_num: int, - test_num: int, + num_train_envs: int, + num_test_envs: int, ) -> tuple[gym.Env, BaseVectorEnv, BaseVectorEnv]: env = TruncatedAsTerminated(gym.make(task)) train_envs = ShmemVectorEnv( - [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)], + [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(num_train_envs)], ) test_envs = ShmemVectorEnv( - [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)], + [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(num_test_envs)], ) return env, train_envs, test_envs @@ -116,7 +117,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - env, train_envs, test_envs = make_fetch_env(args.task, args.training_num, args.test_num) + env, train_envs, test_envs = make_fetch_env(args.task, args.num_train_envs, args.num_test_envs) # The method HER works with goal-based environments if not isinstance(env.observation_space, gym.spaces.Dict): raise ValueError( @@ -153,13 +154,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, device=args.device, ) - actor = dict_state_dec(Actor)( + actor = dict_state_dec(ContinuousActorDeterministic)( net_a, args.action_shape, max_action=args.max_action, device=args.device, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = dict_state_dec(Net)( flat_state_shape, action_shape=args.action_shape, @@ -167,23 +168,26 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy: DDPGPolicy = DDPGPolicy( + critic = 
dict_state_dec(ContinuousCritic)(net_c, device=args.device).to(args.device) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) + policy = ContinuousDeterministicPolicy( actor=actor, - actor_optim=actor_optim, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + action_space=env.action_space, + ) + algorithm: DDPG = DDPG( + policy=policy, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector @@ -192,12 +196,12 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: buffer: VectorReplayBuffer | ReplayBuffer | HERReplayBuffer | HERVectorReplayBuffer if args.replay_buffer == "normal": - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) else: - if args.training_num > 1: + if args.num_train_envs > 1: buffer = HERVectorReplayBuffer( args.buffer_size, len(train_envs), @@ -212,36 +216,37 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: horizon=args.her_horizon, future_k=args.her_future_k, ) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) collector_stats.pprint_asdict() diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 817da079a..1abaa6123 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -10,41 +10,42 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR +from tianshou.algorithm import A2C +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import A2CPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--buffer_size", type=int, default=4096) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--lr", type=float, default=7e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=80) - parser.add_argument("--repeat-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=80) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) # batch-size >> step-per-collect means calculating all data in one singe forward. 
- parser.add_argument("--batch-size", type=int, default=None) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) # a2c special - parser.add_argument("--rew-norm", type=int, default=True) - parser.add_argument("--vf-coef", type=float, default=0.5) - parser.add_argument("--ent-coef", type=float, default=0.01) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) + parser.add_argument("--return_scaling", type=int, default=True) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -52,15 +53,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, @@ -70,12 +71,12 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_a2c(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -89,24 +90,21 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - actor = ActorProb( - net_a, - args.action_shape, + actor = ContinuousActorProbabilistic( + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) actor_critic = ActorCritic(actor, critic) torch.nn.init.constant_(actor.sigma_param, -0.5) @@ -123,57 +121,60 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.RMSprop( - 
actor_critic.parameters(), + optim = RMSpropOptimizerFactory( lr=args.lr, eps=1e-5, alpha=0.99, ) - lr_scheduler = None if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: A2CPolicy = A2CPolicy( + policy = ProbabilisticActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=True, + action_bound_method=args.bound_action_method, + action_space=env.action_space, + ) + algorithm: A2C = A2C( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, - action_space=env.action_space, + return_scaling=args.return_scaling, ) # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + algorithm.load_state_dict(ckpt["model"]) train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -196,34 +197,35 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) 
pprint.pprint(result) # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": - test_a2c() + main() diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index bba7f9e76..6922a1209 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -9,14 +9,14 @@ from torch import nn from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, ExperimentConfig, ) -from tianshou.highlevel.optim import OptimizerFactoryRMSprop -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear -from tianshou.highlevel.params.policy_params import A2CParams +from tianshou.highlevel.params.algorithm_params import A2CParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear +from tianshou.highlevel.params.optim import OptimizerFactoryFactoryRMSprop def main( @@ -27,13 +27,13 @@ def main( lr: float = 7e-4, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 80, - repeat_per_collect: int = 1, + epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 80, + update_step_num_repetitions: int = 1, batch_size: int = 16, - training_num: int = 16, - test_num: int = 10, - rew_norm: bool = True, + num_train_envs: int = 16, + num_test_envs: int = 10, + return_scaling: bool = True, vf_coef: float = 0.5, ent_coef: float = 0.01, gae_lambda: float = 0.95, @@ -43,37 +43,35 @@ def main( ) -> None: log_name = os.path.join(task, "a2c", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OnPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( - A2CExperimentBuilder(env_factory, experiment_config, sampling_config) + A2CExperimentBuilder(env_factory, experiment_config, training_config) .with_a2c_params( A2CParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - reward_normalization=rew_norm, + return_scaling=return_scaling, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, + optim=OptimizerFactoryFactoryRMSprop(eps=1e-5, alpha=0.99), lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) - .with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99)) .with_actor_factory_default(hidden_sizes, nn.Tanh, continuous_unbounded=True) .with_critic_factory_default(hidden_sizes, nn.Tanh) .build() diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index d85a14427..f3ae4e968 
100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -9,36 +9,38 @@ import torch from mujoco_env import make_mujoco_env +from tianshou.algorithm import DDPG +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DDPGPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--start-timesteps", type=int, default=25000) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--start_timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=1) - parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=1) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -46,15 +48,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, 
default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, @@ -64,12 +66,12 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_ddpg(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, obs_norm=False, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -83,45 +85,49 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorDeterministic( + preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action + ).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy: DDPGPolicy = DDPGPolicy( + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) + policy = ContinuousDeterministicPolicy( actor=actor, - actor_optim=actor_optim, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + action_space=env.action_space, + ) + algorithm: DDPG = DDPG( + policy=policy, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) @@ -146,33 +152,34 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = 
OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": - test_ddpg() + main() diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index daa936533..414faa145 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -7,13 +7,13 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( DDPGExperimentBuilder, ExperimentConfig, ) +from tianshou.highlevel.params.algorithm_params import DDPGParams from tianshou.highlevel.params.noise import MaxActionScaledGaussian -from tianshou.highlevel.params.policy_params import DDPGParams def main( @@ -28,26 +28,25 @@ def main( exploration_noise: float = 0.1, start_timesteps: int = 25000, epoch: int = 200, - step_per_epoch: int = 5000, - step_per_collect: int = 1, + epoch_num_steps: int = 5000, + collection_step_num_env_steps: int = 1, update_per_step: int = 1, n_step: int = 1, batch_size: int = 256, - training_num: int = 1, - test_num: int = 10, + num_train_envs: int = 1, + num_test_envs: int = 10, ) -> None: log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OffPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_per_step=update_per_step, - repeat_per_collect=None, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_gradient_steps_per_sample=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, ) @@ -55,7 +54,7 @@ def main( env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( - DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) + DDPGExperimentBuilder(env_factory, experiment_config, training_config) .with_ddpg_params( DDPGParams( actor_lr=actor_lr, @@ -63,7 +62,7 @@ def main( gamma=gamma, tau=tau, exploration_noise=MaxActionScaledGaussian(exploration_noise), - estimation_step=n_step, + 
n_step_return_horizon=n_step, ), ) .with_actor_factory_default(hidden_sizes) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 9416376a1..6beac506d 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -10,24 +10,25 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR +from tianshou.algorithm import NPG +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import NPGPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) + parser.add_argument("--buffer_size", type=int, default=4096) parser.add_argument( - "--hidden-sizes", + "--hidden_sizes", type=int, nargs="*", default=[64, 64], @@ -35,37 +36,37 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=1024) - parser.add_argument("--repeat-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1024) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) # batch-size >> step-per-collect means calculating all data in one singe forward. 
- parser.add_argument("--batch-size", type=int, default=None) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) # npg special - parser.add_argument("--rew-norm", type=int, default=True) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--norm-adv", type=int, default=1) - parser.add_argument("--optim-critic-iters", type=int, default=20) - parser.add_argument("--actor-step-size", type=float, default=0.1) + parser.add_argument("--advantage_normalization", type=int, default=1) + parser.add_argument("--optim_critic_iters", type=int, default=20) + parser.add_argument("--trust_region_size", type=float, default=0.1) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, @@ -75,12 +76,12 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_npg(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -94,24 +95,21 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - actor = ActorProb( - net_a, - args.action_shape, + actor = ContinuousActorProbabilistic( + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): @@ -126,51 +124,55 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - 
optim = torch.optim.Adam(critic.parameters(), lr=args.lr) - lr_scheduler = None + optim = AdamOptimizerFactory(lr=args.lr) if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: NPGPolicy = NPGPolicy( + policy = ProbabilisticActorPolicy( actor=actor, - critic=critic, - optim=optim, dist_fn=dist, - discount_factor=args.gamma, - gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, action_scaling=True, action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, action_space=env.action_space, - advantage_normalization=args.norm_adv, + ) + algorithm: NPG = NPG( + policy=policy, + critic=critic, + optim=optim, + gamma=args.gamma, + gae_lambda=args.gae_lambda, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, optim_critic_iters=args.optim_critic_iters, - actor_step_size=args.actor_step_size, + trust_region_size=args.trust_region_size, ) # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + algorithm.load_state_dict(ckpt["model"]) train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -193,34 +195,35 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, 
+ save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": - test_npg() + main() diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index a231e1b21..93dcfcc42 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -9,13 +9,13 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, NPGExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear -from tianshou.highlevel.params.policy_params import NPGParams +from tianshou.highlevel.params.algorithm_params import NPGParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( @@ -26,50 +26,48 @@ def main( lr: float = 1e-3, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 1024, - repeat_per_collect: int = 1, + epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 1024, + update_step_num_repetitions: int = 1, batch_size: int = 16, - training_num: int = 16, - test_num: int = 10, - rew_norm: bool = True, + num_train_envs: int = 16, + num_test_envs: int = 10, + return_scaling: bool = True, gae_lambda: float = 0.95, bound_action_method: Literal["clip", "tanh"] = "clip", lr_decay: bool = True, - norm_adv: bool = True, + advantage_normalization: bool = True, optim_critic_iters: int = 20, - actor_step_size: float = 0.1, + trust_region_size: float = 0.1, ) -> None: log_name = os.path.join(task, "npg", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OnPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( - NPGExperimentBuilder(env_factory, experiment_config, sampling_config) + NPGExperimentBuilder(env_factory, experiment_config, training_config) .with_npg_params( NPGParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - reward_normalization=rew_norm, - advantage_normalization=norm_adv, + return_scaling=return_scaling, + advantage_normalization=advantage_normalization, optim_critic_iters=optim_critic_iters, - actor_step_size=actor_step_size, + trust_region_size=trust_region_size, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git 
a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 965ec7739..06b0ac904 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -10,46 +10,47 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR +from tianshou.algorithm import PPO +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PPOPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--buffer_size", type=int, default=4096) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=10) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument("--update_step_num_repetitions", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=10) # ppo special - parser.add_argument("--rew-norm", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) # In theory, `vf-coef` will not make any difference if using Adam optimizer. 
- parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=0) - parser.add_argument("--recompute-adv", type=int, default=1) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) + parser.add_argument("--advantage_normalization", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=1) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -57,15 +58,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, @@ -75,12 +76,12 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_ppo(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -94,24 +95,21 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - actor = ActorProb( - net_a, - args.action_shape, + actor = ContinuousActorProbabilistic( + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) actor_critic = ActorCritic(actor, critic) torch.nn.init.constant_(actor.sigma_param, -0.5) @@ -128,57 +126,61 @@ def test_ppo(args: argparse.Namespace = get_args()) -> 
None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) - lr_scheduler = None if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: PPOPolicy = PPOPolicy( + policy = ProbabilisticActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=True, + action_bound_method=args.bound_action_method, + action_space=env.action_space, + ) + algorithm: PPO = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, - action_space=env.action_space, + return_scaling=args.return_scaling, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + algorithm.load_state_dict(ckpt["model"]) train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -201,34 +203,35 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + 
epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": - test_ppo() + main() diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 973a822a6..d402044cd 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -9,13 +9,13 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear -from tianshou.highlevel.params.policy_params import PPOParams +from tianshou.highlevel.params.algorithm_params import PPOParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( @@ -26,13 +26,13 @@ def main( lr: float = 3e-4, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 2048, - repeat_per_collect: int = 10, + epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 2048, + update_step_num_repetitions: int = 10, batch_size: int = 64, - training_num: int = 10, - test_num: int = 10, - rew_norm: bool = True, + num_train_envs: int = 10, + num_test_envs: int = 10, + return_scaling: bool = True, vf_coef: float = 0.25, ent_coef: float = 0.0, gae_lambda: float = 0.95, @@ -42,44 +42,42 @@ def main( eps_clip: float = 0.2, dual_clip: float | None = None, value_clip: bool = False, - norm_adv: bool = False, + advantage_normalization: bool = False, recompute_adv: bool = True, ) -> None: log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OnPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( - PPOExperimentBuilder(env_factory, experiment_config, sampling_config) + PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - reward_normalization=rew_norm, + return_scaling=return_scaling, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, value_clip=value_clip, - advantage_normalization=norm_adv, + advantage_normalization=advantage_normalization, eps_clip=eps_clip, dual_clip=dual_clip, 
recompute_advantage=recompute_adv, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index 47ccc9ae2..2efc462cd 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -22,14 +22,14 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.evaluation.launcher import RegisteredExpLauncher from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, ) from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear -from tianshou.highlevel.params.policy_params import PPOParams +from tianshou.highlevel.params.algorithm_params import PPOParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear log = logging.getLogger(__name__) @@ -58,16 +58,16 @@ def main( experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False) - sampling_config = SamplingConfig( - num_epochs=1, - step_per_epoch=5000, + training_config = OnPolicyTrainingConfig( + max_epochs=1, + epoch_num_steps=5000, batch_size=64, num_train_envs=5, num_test_envs=5, - num_test_episodes=5, + test_step_num_episodes=5, buffer_size=4096, - step_per_collect=2048, - repeat_per_collect=1, + collection_step_num_env_steps=2048, + update_step_num_repetitions=1, ) env_factory = MujocoEnvFactory(task, obs_norm=True) @@ -90,13 +90,13 @@ def main( raise ValueError(f"Unknown logger type: {logger_type}") experiment_collection = ( - PPOExperimentBuilder(env_factory, experiment_config, sampling_config) + PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( - discount_factor=0.99, + gamma=0.99, gae_lambda=0.95, action_bound_method="clip", - reward_normalization=True, + return_scaling=True, ent_coef=0.0, vf_coef=0.25, max_grad_norm=0.5, @@ -106,7 +106,7 @@ def main( dual_clip=None, recompute_advantage=True, lr=3e-4, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config), + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config), ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 61d85ae1c..91f1c56ae 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -9,40 +9,43 @@ import torch from mujoco_env import make_mujoco_env +from tianshou.algorithm import REDQ +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.redq import REDQPolicy +from tianshou.algorithm.modelfree.sac import AutoAlpha +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import REDQPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import 
EnsembleLinear, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--ensemble-size", type=int, default=10) - parser.add_argument("--subset-size", type=int, default=2) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--ensemble_size", type=int, default=10) + parser.add_argument("--subset_size", type=int, default=2) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", default=False, action="store_true") - parser.add_argument("--alpha-lr", type=float, default=3e-4) - parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument("--auto_alpha", default=False, action="store_true") + parser.add_argument("--alpha_lr", type=float, default=3e-4) + parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=20) - parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--target-mode", type=str, choices=("min", "mean"), default="min") - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=20) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--target_mode", type=str, choices=("min", "mean"), default="min") + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -50,15 +53,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + 
parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, @@ -68,12 +71,12 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_redq(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, obs_norm=False, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -86,15 +89,14 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( - net_a, - args.action_shape, - device=args.device, + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProbabilistic( + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) def linear(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(args.ensemble_size, x, y) @@ -104,26 +106,28 @@ def linear(x: int, y: int) -> EnsembleLinear: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, linear_layer=linear, ) - critics = Critic( - net_c, - device=args.device, + critics = ContinuousCritic( + preprocess_net=net_c, linear_layer=linear, flatten_input=False, ).to(args.device) - critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr) + critics_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: target_entropy = -np.prod(env.action_space.shape) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) - policy: REDQPolicy = REDQPolicy( + policy = REDQPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm: REDQ = REDQ( + policy=policy, + policy_optim=actor_optim, critic=critics, critic_optim=critics_optim, ensemble_size=args.ensemble_size, @@ -131,25 +135,24 @@ def linear(x: int, y: int) -> EnsembleLinear: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, actor_delay=args.update_per_step, target_mode=args.target_mode, - action_space=env.action_space, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + 
test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) @@ -174,33 +177,34 @@ def linear(x: int, y: int) -> EnsembleLinear: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": - test_redq() + main() diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 90f6ef318..deb7270da 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -8,13 +8,13 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, REDQExperimentBuilder, ) +from tianshou.highlevel.params.algorithm_params import REDQParams from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault -from tianshou.highlevel.params.policy_params import REDQParams def main( @@ -33,27 +33,26 @@ def main( alpha_lr: float = 3e-4, start_timesteps: int = 10000, epoch: int = 200, - step_per_epoch: int = 5000, - step_per_collect: int = 1, + epoch_num_steps: int = 5000, + collection_step_num_env_steps: int = 1, update_per_step: int = 20, n_step: int = 1, batch_size: int = 256, target_mode: Literal["mean", "min"] = "min", - training_num: int = 1, - test_num: int = 10, + num_train_envs: int = 1, + num_test_envs: int = 10, ) -> None: log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OffPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_per_step=update_per_step, - repeat_per_collect=None, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_gradient_steps_per_sample=update_per_step, 
start_timesteps=start_timesteps, start_timesteps_random=True, ) @@ -61,7 +60,7 @@ def main( env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( - REDQExperimentBuilder(env_factory, experiment_config, sampling_config) + REDQExperimentBuilder(env_factory, experiment_config, training_config) .with_redq_params( REDQParams( actor_lr=actor_lr, @@ -69,7 +68,7 @@ def main( gamma=gamma, tau=tau, alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha, - estimation_step=n_step, + n_step_return_horizon=n_step, target_mode=target_mode, subset_size=subset_size, ensemble_size=ensemble_size, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 95391e1ea..81f3e527f 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -10,38 +10,39 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR +from tianshou.algorithm import Reinforce +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PGPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.net.continuous import ContinuousActorProbabilistic def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--buffer_size", type=int, default=4096) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) # batch-size >> step-per-collect means calculating all data in one singe forward. - parser.add_argument("--batch-size", type=int, default=None) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) # reinforce special - parser.add_argument("--rew-norm", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) # "clip" option also works well. 
- parser.add_argument("--action-bound-method", type=str, default="tanh") - parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--action_bound_method", type=str, default="tanh") + parser.add_argument("--lr_decay", type=int, default=True) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -49,15 +50,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, @@ -67,12 +68,12 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_reinforce(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -86,16 +87,14 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - actor = ActorProb( - net_a, - args.action_shape, + actor = ContinuousActorProbabilistic( + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in actor.modules(): @@ -111,46 +110,50 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(actor.parameters(), lr=args.lr) - lr_scheduler = None + optim = AdamOptimizerFactory(lr=args.lr) if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: PGPolicy = PGPolicy( + policy = ProbabilisticActorPolicy( actor=actor, - optim=optim, dist_fn=dist, action_space=env.action_space, - discount_factor=args.gamma, - reward_normalization=args.rew_norm, action_scaling=True, action_bound_method=args.action_bound_method, - lr_scheduler=lr_scheduler, + ) + algorithm: Reinforce = Reinforce( + policy=policy, + optim=optim, + gamma=args.gamma, + return_standardization=args.return_scaling, ) # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + 
algorithm.load_state_dict(ckpt["model"]) train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -173,34 +176,35 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": - test_reinforce() + main() diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index f3e8821ae..5edf6fc55 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -9,13 +9,13 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, - PGExperimentBuilder, + ReinforceExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear -from tianshou.highlevel.params.policy_params import PGParams +from tianshou.highlevel.params.algorithm_params import ReinforceParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( @@ -26,42 +26,40 @@ def main( lr: float = 1e-3, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 2048, - repeat_per_collect: int = 1, + epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 2048, + update_step_num_repetitions: int = 1, batch_size: int | None = None, - training_num: int = 10, - test_num: int = 10, - rew_norm: bool = True, + num_train_envs: int = 10, + num_test_envs: int = 10, + return_scaling: bool = True, action_bound_method: Literal["clip", "tanh"] = "tanh", lr_decay: bool = True, ) -> None: log_name = os.path.join(task, "reinforce", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OnPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( - PGExperimentBuilder(env_factory, experiment_config, sampling_config) - .with_pg_params( - PGParams( - discount_factor=gamma, + ReinforceExperimentBuilder(env_factory, experiment_config, training_config) + .with_reinforce_params( + ReinforceParams( + gamma=gamma, action_bound_method=action_bound_method, - reward_normalization=rew_norm, + return_standardization=return_scaling, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 151237580..16db8ebf5 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -9,37 +9,39 @@ import torch from mujoco_env import make_mujoco_env +from tianshou.algorithm import SAC +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from 
tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import SACPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", default=False, action="store_true") - parser.add_argument("--alpha-lr", type=float, default=3e-4) - parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument("--auto_alpha", default=False, action="store_true") + parser.add_argument("--alpha_lr", type=float, default=3e-4) + parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=1) - parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=1) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -47,15 +49,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, 
default="mujoco.benchmark") parser.add_argument( "--watch", default=False, @@ -65,12 +67,12 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_sac(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, obs_norm=False, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -83,43 +85,44 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( - net_a, - args.action_shape, - device=args.device, + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProbabilistic( + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: target_entropy = -np.prod(env.action_space.shape) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) - policy: SACPolicy = SACPolicy( + policy = SACPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm: SAC = SAC( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, @@ -127,23 +130,22 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, 
test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) @@ -168,33 +170,34 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": - test_sac() + main() diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 5b2e1519b..84b51ed25 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -7,13 +7,13 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, SACExperimentBuilder, ) +from tianshou.highlevel.params.algorithm_params import SACParams from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault -from tianshou.highlevel.params.policy_params import SACParams def main( @@ -30,25 +30,25 @@ def main( alpha_lr: float = 3e-4, start_timesteps: int = 10000, epoch: int = 200, - step_per_epoch: int = 5000, - step_per_collect: int = 1, + epoch_num_steps: int = 5000, + collection_step_num_env_steps: int = 1, update_per_step: int = 1, n_step: int = 1, batch_size: int = 256, - training_num: int = 1, - test_num: int = 10, + num_train_envs: int = 1, + num_test_envs: int = 10, ) -> None: log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, - num_train_envs=training_num, - num_test_envs=test_num, + training_config = OffPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, batch_size=batch_size, - step_per_collect=step_per_collect, - update_per_step=update_per_step, + collection_step_num_env_steps=collection_step_num_env_steps, + 
update_step_num_gradient_steps_per_sample=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, ) @@ -56,13 +56,13 @@ def main( env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( - SACExperimentBuilder(env_factory, experiment_config, sampling_config) + SACExperimentBuilder(env_factory, experiment_config, training_config) .with_sac_params( SACParams( tau=tau, gamma=gamma, alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha, - estimation_step=n_step, + n_step_return_horizon=n_step, actor_lr=actor_lr, critic1_lr=critic_lr, critic2_lr=critic_lr, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index a5e3e8cf6..030cfa05a 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -9,39 +9,41 @@ import torch from mujoco_env import make_mujoco_env +from tianshou.algorithm import TD3 +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import TD3Policy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=3e-4) - parser.add_argument("--critic-lr", type=float, default=3e-4) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=3e-4) + parser.add_argument("--critic_lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--policy-noise", type=float, default=0.2) - parser.add_argument("--noise-clip", type=float, default=0.5) - parser.add_argument("--update-actor-freq", type=int, default=2) - parser.add_argument("--start-timesteps", type=int, default=25000) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--policy_noise", type=float, default=0.2) + parser.add_argument("--noise_clip", type=float, default=0.5) + parser.add_argument("--update_actor_freq", type=int, default=2) + parser.add_argument("--start_timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=1) - parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=256) - 
parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=1) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -49,15 +51,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, @@ -67,12 +69,12 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_td3(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, obs_norm=False, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -88,60 +90,63 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorDeterministic( + preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action + ).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy: TD3Policy = TD3Policy( + policy = ContinuousDeterministicPolicy( actor=actor, - actor_optim=actor_optim, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + action_space=env.action_space, + ) + 
algorithm: TD3 = TD3( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) @@ -166,33 +171,34 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) if __name__ == "__main__": - test_td3() + main() diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 8ca54d591..56898319e 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -8,16 +8,16 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, TD3ExperimentBuilder, ) +from tianshou.highlevel.params.algorithm_params import TD3Params from tianshou.highlevel.params.env_param import MaxActionScaled from tianshou.highlevel.params.noise import ( MaxActionScaledGaussian, ) -from tianshou.highlevel.params.policy_params import TD3Params def main( @@ -35,25 +35,25 @@ def main( update_actor_freq: int = 2, start_timesteps: int = 25000, epoch: int = 200, - step_per_epoch: int = 5000, - step_per_collect: int = 1, - update_per_step: int = 1, + epoch_num_steps: int = 5000, + collection_step_num_env_steps: int = 1, + update_step_num_gradient_steps_per_sample: int = 1, n_step: int = 1, batch_size: int = 256, - training_num: int = 1, - test_num: int = 10, + num_train_envs: int = 1, + num_test_envs: int = 10, ) -> None: log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, - num_train_envs=training_num, - num_test_envs=test_num, + training_config = OffPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, batch_size=batch_size, - step_per_collect=step_per_collect, - update_per_step=update_per_step, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_gradient_steps_per_sample=update_step_num_gradient_steps_per_sample, start_timesteps=start_timesteps, start_timesteps_random=True, ) @@ -61,12 +61,12 @@ def main( env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( - TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) + TD3ExperimentBuilder(env_factory, experiment_config, training_config) .with_td3_params( TD3Params( tau=tau, gamma=gamma, - estimation_step=n_step, + n_step_return_horizon=n_step, update_actor_freq=update_actor_freq, noise_clip=MaxActionScaled(noise_clip), policy_noise=MaxActionScaled(policy_noise), diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 9405b2440..59c6aa005 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -10,24 +10,25 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR +from tianshou.algorithm import TRPO +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from 
tianshou.policy import TRPOPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) + parser.add_argument("--buffer_size", type=int, default=4096) parser.add_argument( - "--hidden-sizes", + "--hidden_sizes", type=int, nargs="*", default=[64, 64], @@ -35,40 +36,40 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=1024) - parser.add_argument("--repeat-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1024) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) # batch-size >> step-per-collect means calculating all data in one singe forward. - parser.add_argument("--batch-size", type=int, default=None) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) # trpo special - parser.add_argument("--rew-norm", type=int, default=True) - parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--return_scaling", type=int, default=True) + parser.add_argument("--gae_lambda", type=float, default=0.95) # TODO tanh support - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--norm-adv", type=int, default=1) - parser.add_argument("--optim-critic-iters", type=int, default=20) - parser.add_argument("--max-kl", type=float, default=0.01) - parser.add_argument("--backtrack-coeff", type=float, default=0.8) - parser.add_argument("--max-backtracks", type=int, default=10) + parser.add_argument("--advantage_normalization", type=int, default=1) + parser.add_argument("--optim_critic_iters", type=int, default=20) + parser.add_argument("--max_kl", type=float, default=0.01) + parser.add_argument("--backtrack_coeff", type=float, default=0.8) + parser.add_argument("--max_backtracks", type=int, default=10) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) 
parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, @@ -82,8 +83,8 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -97,24 +98,21 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - actor = ActorProb( - net_a, - args.action_shape, + actor = ContinuousActorProbabilistic( + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): @@ -129,31 +127,35 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(critic.parameters(), lr=args.lr) - lr_scheduler = None + optim = AdamOptimizerFactory(lr=args.lr) if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: TRPOPolicy = TRPOPolicy( + policy = ProbabilisticActorPolicy( actor=actor, - critic=critic, - optim=optim, dist_fn=dist, - discount_factor=args.gamma, - gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, action_scaling=True, action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, action_space=env.action_space, - advantage_normalization=args.norm_adv, + ) + algorithm: TRPO = TRPO( + policy=policy, + critic=critic, + optim=optim, + gamma=args.gamma, + gae_lambda=args.gae_lambda, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, backtrack_coeff=args.backtrack_coeff, @@ -163,19 +165,19 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + algorithm.load_state_dict(ckpt["model"]) train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) # collector buffer: VectorReplayBuffer | ReplayBuffer - if 
args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -198,32 +200,33 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) if not args.watch: # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 4dfc39185..73c501ae8 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -9,13 +9,13 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, TRPOExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear -from tianshou.highlevel.params.policy_params import TRPOParams +from tianshou.highlevel.params.algorithm_params import TRPOParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( @@ -26,17 +26,17 @@ def main( lr: float = 1e-3, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 1024, - repeat_per_collect: int = 1, + epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 1024, + update_step_num_repetitions: int = 1, batch_size: int = 16, - training_num: int = 16, - test_num: int = 10, - rew_norm: bool = True, + num_train_envs: int = 16, + num_test_envs: int = 10, + return_scaling: bool = True, gae_lambda: float = 0.95, bound_action_method: Literal["clip", "tanh"] = "clip", lr_decay: bool = True, - norm_adv: bool = True, + advantage_normalization: bool = True, optim_critic_iters: int = 20, max_kl: float = 0.01, backtrack_coeff: float = 0.8, @@ -44,36 +44,34 @@ def main( ) -> None: log_name = os.path.join(task, "trpo", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + training_config = OnPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( - TRPOExperimentBuilder(env_factory, experiment_config, sampling_config) + TRPOExperimentBuilder(env_factory, experiment_config, training_config) .with_trpo_params( TRPOParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - reward_normalization=rew_norm, - advantage_normalization=norm_adv, + return_standardization=return_scaling, + advantage_normalization=advantage_normalization, optim_critic_iters=optim_critic_iters, max_kl=max_kl, backtrack_coeff=backtrack_coeff, max_backtracks=max_backtracks, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/plotter.py b/examples/mujoco/plotter.py index 5e2f9e016..cb64efc27 100755 --- a/examples/mujoco/plotter.py +++ 
b/examples/mujoco/plotter.py @@ -180,13 +180,13 @@ def plot_figure( if __name__ == "__main__": parser = argparse.ArgumentParser(description="plotter") parser.add_argument( - "--fig-length", + "--fig_length", type=int, default=6, help="matplotlib figure length (default: 6)", ) parser.add_argument( - "--fig-width", + "--fig_width", type=int, default=6, help="matplotlib figure width (default: 6)", @@ -212,7 +212,7 @@ def plot_figure( parser.add_argument("--xlabel", default="Timesteps", help="matplotlib figure xlabel") parser.add_argument("--ylabel", default="Episode Reward", help="matplotlib figure ylabel") parser.add_argument( - "--shaded-std", + "--shaded_std", action="store_true", help="shaded region corresponding to standard deviation of the group", ) @@ -227,35 +227,35 @@ def plot_figure( help="whether to share y axis within multiple sub-figures", ) parser.add_argument( - "--legend-outside", + "--legend_outside", action="store_true", help="place the legend outside of the figure", ) parser.add_argument("--xlim", type=int, default=None, help="x-axis limitation (default: None)") - parser.add_argument("--root-dir", default="./", help="root dir (default: ./)") + parser.add_argument("--root_dir", default="./", help="root dir (default: ./)") parser.add_argument( - "--file-pattern", + "--file_pattern", type=str, default=r".*/test_rew_\d+seeds.csv$", help="regular expression to determine whether or not to include target csv " "file, default to including all test_rew_{num}seeds.csv file under rootdir", ) parser.add_argument( - "--group-pattern", + "--group_pattern", type=str, default=r"(/|^)\w*?\-v(\d|$)", help="regular expression to group files in sub-figure, default to grouping " 'according to env_name dir, "" means no grouping', ) parser.add_argument( - "--legend-pattern", + "--legend_pattern", type=str, default=r".*", help="regular expression to extract legend from csv file path, default to " "using file path as legend name.", ) parser.add_argument("--show", action="store_true", help="show figure") - parser.add_argument("--output-path", type=str, help="figure save path", default="./figure.png") + parser.add_argument("--output_path", type=str, help="figure save path", default="./figure.png") parser.add_argument("--dpi", type=int, default=200, help="figure dpi (default: 200)") args = parser.parse_args() file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern)) diff --git a/examples/mujoco/run_experiments.sh b/examples/mujoco/run_experiments.sh deleted file mode 100755 index b175fe7c9..000000000 --- a/examples/mujoco/run_experiments.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -LOGDIR="results" -TASK=$1 -ALGO=$2 - -echo "Experiments started." -for seed in $(seq 0 9) -do - python mujoco_${ALGO}.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1 & -done -echo "Experiments ended." 
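# The mujoco_trpo.py hunks above illustrate the full v2 recipe for an on-policy
# example: build a Policy (action selection only), wrap it in an Algorithm
# (which owns the optimizer factory and the learning hyperparameters), and run
# training via a trainer-params object instead of instantiating a trainer and
# calling .run(). Below is a condensed, approximate sketch assembled from names
# that appear in this diff; the task name, the hyperparameter values and the
# exact import locations of TRPO and ProbabilisticActorPolicy are assumptions
# inferred from neighbouring hunks, not taken verbatim from the repository.
import gymnasium as gym
import torch
from torch.distributions import Distribution, Independent, Normal

from tianshou.algorithm import TRPO  # import path assumed, by analogy with DiscreteBCQ/CQL below
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy  # path assumed, cf. DiscreteActorPolicy
from tianshou.algorithm.optim import AdamOptimizerFactory
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import OnPolicyTrainerParams
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic

task = "Pendulum-v1"  # placeholder environment, not from the repository
env = gym.make(task)
train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(4)])
test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(2)])

# networks: a stochastic actor and a V-critic, as in the mujoco_trpo.py hunks above
net_a = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], activation=torch.nn.Tanh)
actor = ContinuousActorProbabilistic(
    preprocess_net=net_a, action_shape=env.action_space.shape, unbounded=True
)
net_c = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], activation=torch.nn.Tanh)
critic = ContinuousCritic(preprocess_net=net_c)

def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)

# v2 split: the Policy only defines how actions are produced; the Algorithm
# receives the optimizer factory (AdamOptimizerFactory replaces torch.optim.Adam
# and is no longer constructed from an explicit parameter list).
policy = ProbabilisticActorPolicy(
    actor=actor,
    dist_fn=dist,
    action_scaling=True,
    action_bound_method="clip",
    action_space=env.action_space,
)
algorithm = TRPO(
    policy=policy,
    critic=critic,
    optim=AdamOptimizerFactory(lr=1e-3),
    gamma=0.99,
    gae_lambda=0.95,
    return_scaling=True,
    advantage_normalization=True,
    optim_critic_iters=20,
    max_kl=0.01,
    backtrack_coeff=0.8,
)

train_collector = Collector[CollectStats](
    algorithm, train_envs, VectorReplayBuffer(4096, len(train_envs))
)
test_collector = Collector[CollectStats](algorithm, test_envs)

result = algorithm.run_training(
    OnPolicyTrainerParams(
        train_collector=train_collector,
        test_collector=test_collector,
        max_epochs=5,
        epoch_num_steps=5000,
        collection_step_num_env_steps=1024,
        update_step_num_repetitions=1,
        test_step_num_episodes=2,
        batch_size=None,  # None processes each collection step in one forward pass, as above
        test_in_train=False,
    )
)
print(result)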
diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py index e0db8162b..3777ce98c 100755 --- a/examples/mujoco/tools.py +++ b/examples/mujoco/tools.py @@ -128,11 +128,11 @@ def merge_csv( help="Re-generate all csv files instead of using existing one.", ) parser.add_argument( - "--remove-zero", + "--remove_zero", action="store_true", help="Remove the data point of env_step == 0.", ) - parser.add_argument("--root-dir", type=str) + parser.add_argument("--root_dir", type=str) args = parser.parse_args() csv_files = convert_tfevents_to_csv(args.root_dir, args.refresh) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 62c205076..8a5ec1ecf 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -11,60 +11,61 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import DQN -from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer +from tianshou.algorithm import DiscreteBCQ +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteBCQPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OfflineTrainer -from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.discrete import Actor +from tianshou.trainer import OfflineTrainerParams +from tianshou.utils.net.discrete import DiscreteActor def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.001) + parser.add_argument("--eps_test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=6.25e-5) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--target-update-freq", type=int, default=8000) - parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) - parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--target_update_freq", type=int, default=8000) + parser.add_argument("--unlikely_action_threshold", type=float, default=0.3) + parser.add_argument("--imitation_logits_penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--update-per-epoch", type=int, default=10000) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--update_per_epoch", type=int, default=10000) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--frames_stack", type=int, 
default=4) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", + "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) - parser.add_argument("--buffer-from-rl-unplugged", action="store_true", default=False) + parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, @@ -73,13 +74,13 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, args.seed, 1, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -96,46 +97,45 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: assert args.state_shape is not None assert len(args.state_shape) == 3 c, h, w = args.state_shape - feature_net = DQN( - c, - h, - w, - args.action_shape, - device=args.device, + feature_net = DQNet( + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, ).to(args.device) - policy_net = Actor( - feature_net, - args.action_shape, - device=args.device, + policy_net = DiscreteActor( + preprocess_net=feature_net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax_output=False, ).to(args.device) - imitation_net = Actor( - feature_net, - args.action_shape, - device=args.device, + imitation_net = DiscreteActor( + preprocess_net=feature_net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax_output=False, ).to(args.device) - actor_critic = ActorCritic(policy_net, imitation_net) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - # define policy - policy: DiscreteBCQPolicy = DiscreteBCQPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + # define policy and algorithm + policy = DiscreteBCQPolicy( model=policy_net, imitator=imitation_net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, - estimation_step=args.n_step, - target_update_freq=args.target_update_freq, - eval_eps=args.eps_test, unlikely_action_threshold=args.unlikely_action_threshold, + eps_inference=args.eps_test, + ) + algorithm: DiscreteBCQ = DiscreteBCQ( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, + target_update_freq=args.target_update_freq, imitation_logits_penalty=args.imitation_logits_penalty, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + 
algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: @@ -155,7 +155,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -178,7 +178,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -191,29 +191,30 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.update_per_epoch, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_discrete_bcq(get_args()) + main(get_args()) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 436145c90..620b97882 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -12,14 +12,16 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import QRDQN -from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer +from tianshou.algorithm import DiscreteCQL +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import QRDQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteCQLPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo @@ -27,44 +29,43 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", 
type=int, default=1) - parser.add_argument("--target-update-freq", type=int, default=500) - parser.add_argument("--min-q-weight", type=float, default=10.0) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--target_update_freq", type=int, default=500) + parser.add_argument("--min_q_weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--update-per-epoch", type=int, default=10000) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--update_per_epoch", type=int, default=10000) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", + "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) - parser.add_argument("--buffer-from-rl-unplugged", action="store_true", default=False) + parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, @@ -73,13 +74,13 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, args.seed, 1, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -97,29 +98,31 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = QRDQN( + net = QRDQNet( c=c, h=h, w=w, action_shape=args.action_shape, num_quantiles=args.num_quantiles, - device=args.device, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # define policy - policy: DiscreteCQLPolicy = DiscreteCQLPolicy( + policy = QRDQNPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, + ) + algorithm: DiscreteCQL = DiscreteCQL( + policy=policy, + optim=optim, + gamma=args.gamma, num_quantiles=args.num_quantiles, - 
estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, min_q_weight=args.min_q_weight, ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: @@ -139,7 +142,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -162,7 +165,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -171,33 +174,33 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.update_per_epoch, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_discrete_cql(get_args()) + main(get_args()) diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 565908559..bd2ff45a6 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -11,16 +11,17 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import DQN -from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer +from tianshou.algorithm import DiscreteCRR +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteCRRPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OfflineTrainer -from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.trainer import OfflineTrainerParams +from 
tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -30,42 +31,42 @@ def get_args() -> argparse.Namespace: parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--policy-improvement-mode", type=str, default="exp") - parser.add_argument("--ratio-upper-bound", type=float, default=20.0) + parser.add_argument("--policy_improvement_mode", type=str, default="exp") + parser.add_argument("--ratio_upper_bound", type=float, default=20.0) parser.add_argument("--beta", type=float, default=1.0) - parser.add_argument("--min-q-weight", type=float, default=10.0) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--min_q_weight", type=float, default=10.0) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--update-per-epoch", type=int, default=10000) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--update_per_epoch", type=int, default=10000) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", + "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) - parser.add_argument("--buffer-from-rl-unplugged", action="store_true", default=False) + parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, @@ -74,13 +75,13 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, args.seed, 1, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -98,36 +99,35 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: assert args.state_shape is not None assert len(args.state_shape) == 3 
c, h, w = args.state_shape - feature_net = DQN( - c, - h, - w, - args.action_shape, - device=args.device, + feature_net = DQNet( + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, ).to(args.device) - actor = Actor( - feature_net, - args.action_shape, + actor = DiscreteActor( + preprocess_net=feature_net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax_output=False, ).to(args.device) - critic = Critic( - feature_net, + critic = DiscreteCritic( + preprocess_net=feature_net, hidden_sizes=args.hidden_sizes, last_size=int(np.prod(args.action_shape)), - device=args.device, ).to(args.device) - actor_critic = ActorCritic(actor, critic) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - # define policy - policy: DiscreteCRRPolicy = DiscreteCRRPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + # define policy and algorithm + policy = DiscreteActorPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: DiscreteCRR = DiscreteCRR( + policy=policy, critic=critic, optim=optim, - action_space=env.action_space, - discount_factor=args.gamma, + gamma=args.gamma, policy_improvement_mode=args.policy_improvement_mode, ratio_upper_bound=args.ratio_upper_bound, beta=args.beta, @@ -136,7 +136,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: @@ -156,7 +156,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -179,7 +179,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -191,29 +191,30 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.update_per_epoch, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_discrete_crr(get_args()) + main(get_args()) diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 
16b7cdec3..0819aed6c 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -10,14 +10,18 @@ import numpy as np import torch -from examples.atari.atari_network import DQN -from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.imitation.imitation_base import ( + ImitationPolicy, + OfflineImitationLearning, +) +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import ImitationPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo @@ -27,35 +31,35 @@ def get_args() -> argparse.Namespace: parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--update-per-epoch", type=int, default=10000) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--update_per_epoch", type=int, default=10000) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", + "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) - parser.add_argument("--buffer-from-rl-unplugged", action="store_true", default=False) + parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, @@ -70,7 +74,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, 1, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -87,13 +91,17 @@ def test_il(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = DQN(c, h, w, args.action_shape, device=args.device).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + net = 
DQNet(c=c, h=h, w=w, action_shape=args.action_shape).to(args.device) + optim = AdamOptimizerFactory(lr=args.lr) # define policy - policy: ImitationPolicy = ImitationPolicy(actor=net, optim=optim, action_space=env.action_space) + policy = ImitationPolicy(actor=net, action_space=env.action_space) + algorithm: OfflineImitationLearning = OfflineImitationLearning( + policy=policy, + optim=optim, + ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: @@ -113,7 +121,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -136,7 +144,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -148,25 +156,26 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: watch() sys.exit(0) - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.update_per_epoch, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) watch() diff --git a/examples/offline/convert_rl_unplugged_atari.py b/examples/offline/convert_rl_unplugged_atari.py index 1afd721a5..d999ad330 100755 --- a/examples/offline/convert_rl_unplugged_atari.py +++ b/examples/offline/convert_rl_unplugged_atari.py @@ -266,25 +266,25 @@ def main(args: Namespace) -> None: parser = ArgumentParser(usage=__doc__) parser.add_argument("--task", required=True, help="Name of the Atari game.") parser.add_argument( - "--run-id", + "--run_id", type=int, default=1, help="Run id to download and convert. Value in [1..5].", ) parser.add_argument( - "--shard-id", + "--shard_id", type=int, default=0, help="Shard id to download and convert. 
Value in [0..99].", ) - parser.add_argument("--total-num-shards", type=int, default=100, help="Total number of shards.") + parser.add_argument("--total_num_shards", type=int, default=100, help="Total number of shards.") parser.add_argument( - "--dataset-dir", + "--dataset_dir", default=os.path.expanduser("~/.rl_unplugged/datasets"), help="Directory for converted hdf5 files.", ) parser.add_argument( - "--cache-dir", + "--cache_dir", default=os.path.expanduser("~/.rl_unplugged/cache"), help="Directory for downloaded original datasets.", ) diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 08b38ded6..1fd1c51be 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -11,14 +11,16 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl +from tianshou.algorithm import BCQ +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.imitation.bcq import BCQPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.policy import BCQPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net -from tianshou.utils.net.continuous import VAE, Critic, Perturbation +from tianshou.utils.net.continuous import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo @@ -26,23 +28,23 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2") - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) - parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) + parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) - parser.add_argument("--vae-hidden-sizes", type=int, nargs="*", default=[512, 512]) + parser.add_argument("--vae_hidden_sizes", type=int, nargs="*", default=[512, 512]) # default to 2 * action_dim - 
parser.add_argument("--latent-dim", type=int) + parser.add_argument("--latent_dim", type=int) parser.add_argument("--gamma", default=0.99) parser.add_argument("--tau", default=0.005) # Weighting for Clipped Double Q-learning in BCQ @@ -54,15 +56,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, @@ -89,7 +91,7 @@ def test_bcq() -> None: print("Max_action", args.max_action) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -101,38 +103,34 @@ def test_bcq() -> None: input_dim=args.state_dim + args.action_dim, output_dim=args.action_dim, hidden_sizes=args.hidden_sizes, - device=args.device, ) - actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to( + actor = Perturbation(preprocess_net=net_a, max_action=args.max_action, phi=args.phi).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae # output_dim = 0, so the last Module in the encoder is ReLU vae_encoder = MLP( input_dim=args.state_dim + args.action_dim, hidden_sizes=args.vae_hidden_sizes, - device=args.device, ) if not args.latent_dim: args.latent_dim = args.action_dim * 2 @@ -140,41 +138,41 @@ def test_bcq() -> None: input_dim=args.state_dim + args.latent_dim, output_dim=args.action_dim, hidden_sizes=args.vae_hidden_sizes, - device=args.device, ) vae = VAE( - vae_encoder, - vae_decoder, + encoder=vae_encoder, + decoder=vae_decoder, hidden_dim=args.vae_hidden_sizes[-1], latent_dim=args.latent_dim, max_action=args.max_action, - device=args.device, ).to(args.device) - vae_optim = torch.optim.Adam(vae.parameters()) + vae_optim = AdamOptimizerFactory() - policy: BCQPolicy = BCQPolicy( + policy = BCQPolicy( actor_perturbation=actor, - actor_perturbation_optim=actor_optim, + action_space=env.action_space, critic=critic1, + vae=vae, + ) + algorithm: BCQ = 
BCQ( + policy=policy, + actor_perturbation_optim=actor_optim, critic_optim=critic1_optim, - action_space=env.action_space, critic2=critic2, critic2_optim=critic2_optim, - vae=vae, vae_optim=vae_optim, - device=args.device, gamma=args.gamma, tau=args.tau, lmbda=args.lmbda, - ) + ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -198,31 +196,32 @@ def test_bcq() -> None: ) logger.load(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") - policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector[CollectStats](policy, env) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) - # trainer - result = OfflineTrainer( - policy=policy, - buffer=replay_buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainerParams( + buffer=replay_buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) else: watch() @@ -230,7 +229,7 @@ def watch() -> None: # Let's watch its performance! 
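# The construction above, gathered into one place for readability: in v2 the
# BCQPolicy keeps only what is needed to act (the perturbation actor, the first
# critic and the VAE), while the BCQ algorithm receives the optimizer factories
# and the learning hyperparameters. This fragment reuses the actor, critic,
# VAE and optimizer-factory variables defined earlier in d4rl_bcq.py above; it
# is not standalone.
policy = BCQPolicy(
    actor_perturbation=actor,
    action_space=env.action_space,
    critic=critic1,
    vae=vae,
)
algorithm = BCQ(
    policy=policy,
    actor_perturbation_optim=actor_optim,
    critic_optim=critic1_optim,
    critic2=critic2,
    critic2_optim=critic2_optim,
    vae_optim=vae_optim,
    gamma=args.gamma,
    tau=args.tau,
    lmbda=args.lmbda,
).to(args.device)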
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 5b68edf9e..51e965193 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -11,14 +11,16 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl +from tianshou.algorithm import CQL +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.policy import CQLPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -37,32 +39,32 @@ def get_args() -> argparse.Namespace: help="The random seed to use.", ) parser.add_argument( - "--expert-data-task", + "--expert_data_task", type=str, default="hopper-expert-v2", help="The name of the OpenAI Gym environment to use for expert data collection.", ) parser.add_argument( - "--buffer-size", + "--buffer_size", type=int, default=1000000, help="The size of the replay buffer.", ) parser.add_argument( - "--hidden-sizes", + "--hidden_sizes", type=int, nargs="*", default=[256, 256], help="The list of hidden sizes for the neural networks.", ) parser.add_argument( - "--actor-lr", + "--actor_lr", type=float, default=1e-4, help="The learning rate for the actor network.", ) parser.add_argument( - "--critic-lr", + "--critic_lr", type=float, default=3e-4, help="The learning rate for the critic network.", @@ -74,25 +76,25 @@ def get_args() -> argparse.Namespace: help="The weight of the entropy term in the loss function.", ) parser.add_argument( - "--auto-alpha", + "--auto_alpha", default=True, action="store_true", help="Whether to use automatic entropy tuning.", ) parser.add_argument( - "--alpha-lr", + "--alpha_lr", type=float, default=1e-4, help="The learning rate for the entropy tuning.", ) parser.add_argument( - "--cql-alpha-lr", + "--cql_alpha_lr", type=float, default=3e-4, help="The learning rate for the CQL entropy tuning.", ) parser.add_argument( - "--start-timesteps", + "--start_timesteps", type=int, default=10000, help="The number of timesteps before starting to train.", @@ -104,19 +106,19 @@ def get_args() -> argparse.Namespace: help="The number of epochs to train for.", ) parser.add_argument( - "--step-per-epoch", + "--epoch_num_steps", type=int, default=5000, help="The number of steps per epoch.", ) parser.add_argument( - "--n-step", + "--n_step", type=int, default=3, help="The number of steps to use for N-step TD learning.", ) parser.add_argument( - "--batch-size", + "--batch_size", type=int, default=256, help="The batch size for training.", @@ -134,13 +136,13 @@ def get_args() -> argparse.Namespace: help="The temperature for the Boltzmann policy.", ) parser.add_argument( - "--cql-weight", + "--cql_weight", type=float, default=1.0, help="The 
weight of the CQL loss term.", ) parser.add_argument( - "--with-lagrange", + "--with_lagrange", type=bool, default=True, help="Whether to use the Lagrange multiplier for CQL.", @@ -152,20 +154,20 @@ def get_args() -> argparse.Namespace: help="Whether to use calibration for CQL.", ) parser.add_argument( - "--lagrange-threshold", + "--lagrange_threshold", type=float, default=10.0, help="The Lagrange multiplier threshold for CQL.", ) parser.add_argument("--gamma", type=float, default=0.99, help="The discount factor") parser.add_argument( - "--eval-freq", + "--eval_freq", type=int, default=1, help="The frequency of evaluation.", ) parser.add_argument( - "--test-num", + "--num_test_envs", type=int, default=10, help="The number of episodes to evaluate for.", @@ -189,13 +191,13 @@ def get_args() -> argparse.Namespace: help="The device to train on (cpu or cuda).", ) parser.add_argument( - "--resume-path", + "--resume_path", type=str, default=None, help="The path to the checkpoint to resume from.", ) parser.add_argument( - "--resume-id", + "--resume_id", type=str, default=None, help="The ID of the checkpoint to resume from.", @@ -206,7 +208,7 @@ def get_args() -> argparse.Namespace: default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, @@ -235,7 +237,7 @@ def test_cql() -> None: print("Max_action", args.max_action) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -247,16 +249,14 @@ def test_cql() -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) - actor = ActorProb( - net_a, + actor = ContinuousActorProbabilistic( + preprocess_net=net_a, action_shape=args.action_shape, - device=args.device, unbounded=True, conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network net_c1 = Net( @@ -264,32 +264,33 @@ def test_cql() -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = Critic(net_c1, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: target_entropy = -args.action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, 
alpha_optim).to(args.device) - policy: CQLPolicy = CQLPolicy( + policy = SACPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm: CQL = CQL( + policy=policy, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, - action_space=env.action_space, critic2=critic2, critic2_optim=critic2_optim, calibrated=args.calibrated, @@ -303,16 +304,15 @@ def test_cql() -> None: lagrange_threshold=args.lagrange_threshold, min_action=args.min_action, max_action=args.max_action, - device=args.device, - ) + ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -336,31 +336,32 @@ def test_cql() -> None: ) logger.load(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") - policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector[CollectStats](policy, env) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) - # trainer - result = OfflineTrainer( - policy=policy, - buffer=replay_buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainerParams( + buffer=replay_buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) else: watch() @@ -368,7 +369,7 @@ def watch() -> None: # Let's watch its performance! 
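# In the d4rl_cql.py hunks above, automatic entropy tuning no longer passes a
# raw (target_entropy, log_alpha, alpha_optim) tuple built around
# torch.optim.Adam([log_alpha]); the three values are wrapped in an AutoAlpha
# object whose optimizer is, like the other optimizers in these examples,
# created from a factory. A minimal standalone sketch using the import shown in
# that file (the target-entropy value is illustrative):
from tianshou.algorithm.modelfree.sac import AutoAlpha
from tianshou.algorithm.optim import AdamOptimizerFactory

target_entropy = -3.0  # typically -action_dim, as in the example above
alpha = AutoAlpha(target_entropy, 0.0, AdamOptimizerFactory(lr=1e-4))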
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index e1e71fd82..3b05e5f1a 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -11,14 +11,18 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.imitation.imitation_base import ( + ImitationPolicy, + OfflineImitationLearning, +) +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.policy import ImitationPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor +from tianshou.utils.net.continuous import ContinuousActorDeterministic from tianshou.utils.space_info import SpaceInfo @@ -26,13 +30,13 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2") - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument("--gamma", default=0.99) @@ -41,15 +45,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, @@ -75,7 +79,7 @@ def test_il() -> None: args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed 
np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -86,31 +90,32 @@ def test_il() -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) - actor = Actor( - net, + actor = ContinuousActorDeterministic( + preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) - optim = torch.optim.Adam(actor.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) - policy: ImitationPolicy = ImitationPolicy( + policy = ImitationPolicy( actor=actor, - optim=optim, action_space=env.action_space, action_scaling=True, action_bound_method="clip", ) + algorithm: OfflineImitationLearning = OfflineImitationLearning( + policy=policy, + optim=optim, + ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -134,31 +139,32 @@ def test_il() -> None: ) logger.load(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") - policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector[CollectStats](policy, env) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) - # trainer - result = OfflineTrainer( - policy=policy, - buffer=replay_buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainerParams( + buffer=replay_buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) else: watch() @@ -166,7 +172,7 @@ def watch() -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index f4e8b38c2..d0a4d42ae 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -11,15 +11,17 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer +from tianshou.algorithm import TD3BC +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs from tianshou.exploration import GaussianNoise -from tianshou.policy import TD3BCPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -27,27 +29,27 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2") - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=3e-4) - parser.add_argument("--critic-lr", type=float, default=3e-4) + parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=3e-4) + parser.add_argument("--critic_lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--alpha", type=float, default=2.5) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--policy-noise", type=float, default=0.2) - parser.add_argument("--noise-clip", type=float, default=0.5) - parser.add_argument("--update-actor-freq", type=int, default=2) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--policy_noise", type=float, default=0.2) + parser.add_argument("--noise_clip", type=float, default=0.5) + parser.add_argument("--update_actor_freq", type=int, default=2) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--gamma", type=float, default=0.99) - 
parser.add_argument("--norm-obs", type=int, default=1) + parser.add_argument("--norm_obs", type=int, default=1) - parser.add_argument("--eval-freq", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--eval_freq", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument( @@ -55,15 +57,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, @@ -91,7 +93,7 @@ def test_td3_bc() -> None: print("Max_action", args.max_action) test_envs: BaseVectorEnv - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) if args.norm_obs: test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) @@ -103,17 +105,15 @@ def test_td3_bc() -> None: # model # actor network net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) - actor = Actor( - net_a, + actor = ContinuousActorDeterministic( + preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network net_c1 = Net( @@ -121,45 +121,46 @@ def test_td3_bc() -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy: TD3BCPolicy = TD3BCPolicy( + policy = ContinuousDeterministicPolicy( actor=actor, - actor_optim=actor_optim, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + action_space=env.action_space, + ) + algorithm: TD3BC = TD3BC( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, alpha=args.alpha, - estimation_step=args.n_step, 
- action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -183,15 +184,15 @@ def test_td3_bc() -> None: ) logger.load(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") - policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector[CollectStats](policy, env) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: @@ -199,18 +200,19 @@ def watch() -> None: if args.norm_obs: replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer) test_envs.set_obs_rms(obs_rms) - # trainer - result = OfflineTrainer( - policy=policy, - buffer=replay_buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainerParams( + buffer=replay_buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) else: watch() @@ -218,7 +220,7 @@ def watch() -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/vizdoom/README.md b/examples/vizdoom/README.md index ca151f19b..3a46aaf86 100644 --- a/examples/vizdoom/README.md +++ b/examples/vizdoom/README.md @@ -39,13 +39,13 @@ D4 can reach 700+ reward. Here is the result: To evaluate an agent's performance: ```bash -python3 vizdoom_c51.py --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} +python3 vizdoom_c51.py --num_test_envs 100 --resume_path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} ``` To save `.lmp` files for recording: ```bash -python3 vizdoom_c51.py --save-lmp --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} +python3 vizdoom_c51.py --save_lmp --num_test_envs 100 --resume_path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} ``` it will store `lmp` file in `lmps/` directory.
To watch these `lmp` files (for example, d3 lmp): diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 2869acd1a..f8dcc8816 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -133,12 +133,12 @@ def make_vizdoom_env( res: tuple[int], save_lmp: bool = False, seed: int | None = None, - training_num: int = 10, - test_num: int = 10, + num_train_envs: int = 10, + num_test_envs: int = 10, ) -> tuple[Env, ShmemVectorEnv, ShmemVectorEnv]: cpu_count = os.cpu_count() if cpu_count is not None: - test_num = min(cpu_count - 1, test_num) + num_test_envs = min(cpu_count - 1, num_test_envs) if envpool is not None: task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1" lmp_save_dir = "lmps/" if save_lmp else "" @@ -154,7 +154,7 @@ def make_vizdoom_env( frame_skip=frame_skip, stack_num=res[0], seed=seed, - num_envs=training_num, + num_envs=num_train_envs, reward_config=reward_config, use_combined_action=True, max_episode_steps=2625, @@ -166,7 +166,7 @@ def make_vizdoom_env( stack_num=res[0], lmp_save_dir=lmp_save_dir, seed=seed, - num_envs=test_num, + num_envs=num_test_envs, reward_config=reward_config, use_combined_action=True, max_episode_steps=2625, @@ -176,10 +176,10 @@ def make_vizdoom_env( cfg_path = f"maps/{task}.cfg" env = Env(cfg_path, frame_skip, res) train_envs = ShmemVectorEnv( - [lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)], + [lambda: Env(cfg_path, frame_skip, res) for _ in range(num_train_envs)], ) test_envs = ShmemVectorEnv( - [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(test_num)], + [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(num_test_envs)], ) train_envs.seed(seed) test_envs.seed(seed) diff --git a/examples/vizdoom/network.py b/examples/vizdoom/network.py deleted file mode 120000 index a0c543acb..000000000 --- a/examples/vizdoom/network.py +++ /dev/null @@ -1 +0,0 @@ -../atari/atari_network.py \ No newline at end of file diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 0c23e6e8e..941a61c3c 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -7,37 +7,39 @@ import numpy as np import torch from env import make_vizdoom_env -from network import C51 +from tianshou.algorithm import C51 +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import C51Net from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51Policy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="D1_basic") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=2000000) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=2000000) 
parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", type=float, default=10.0) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=300) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -45,17 +47,17 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--skip-num", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--skip_num", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="vizdoom.benchmark") + parser.add_argument("--wandb_project", type=str, default="vizdoom.benchmark") parser.add_argument( "--watch", default=False, @@ -63,12 +65,12 @@ def get_args() -> argparse.Namespace: help="watch the play of pre-trained policy only", ) parser.add_argument( - "--save-lmp", + "--save_lmp", default=False, action="store_true", help="save lmp file for replay whole episode", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() @@ -80,8 +82,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: (args.frames_stack, 84, 84), args.save_lmp, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, ) args.state_shape = env.observation_space.shape args.action_shape = env.action_space.shape or env.action_space.n @@ -92,23 +94,29 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = C51(*args.state_shape, args.action_shape, 
args.num_atoms, args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: C51Policy = C51Policy( + c, h, w = args.state_shape + net = C51Net(c=c, h=h, w=w, action_shape=args.action_shape, num_atoms=args.num_atoms) + optim = AdamOptimizerFactory(lr=args.lr) + # define policy and algorithm + policy = C51Policy( model=net, - optim=optim, - discount_factor=args.gamma, action_space=env.action_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: C51 = C51( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM @@ -120,8 +128,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -144,7 +152,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -158,17 +166,13 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -179,7 +183,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -187,7 +193,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -196,25 +202,25 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - 
train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 6d4f55d14..499f548a1 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -7,45 +7,50 @@ import numpy as np import torch from env import make_vizdoom_env -from network import DQN from torch.distributions import Categorical -from torch.optim.lr_scheduler import LambdaLR +from tianshou.algorithm import PPO +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import ICMPolicy, PPOPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer -from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule +from tianshou.trainer import OnPolicyTrainerParams +from tianshou.utils.net.discrete import ( + DiscreteActor, + DiscreteCritic, + IntrinsicCuriosityModule, +) def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="D1_basic") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.00002) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=300) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=1000) - parser.add_argument("--repeat-per-collect", type=int, default=4) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--hidden-size", type=int, default=512) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--rew-norm", type=int, default=False) - parser.add_argument("--vf-coef", type=float, default=0.5) - 
parser.add_argument("--ent-coef", type=float, default=0.01) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1000) + parser.add_argument("--update_step_num_repetitions", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--hidden_size", type=int, default=512) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--return_scaling", type=int, default=False) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--lr_decay", type=int, default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) + parser.add_argument("--advantage_normalization", type=int, default=1) + parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -53,17 +58,17 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--skip-num", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--skip_num", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="vizdoom.benchmark") + parser.add_argument("--wandb_project", type=str, default="vizdoom.benchmark") parser.add_argument( "--watch", default=False, @@ -71,26 +76,26 @@ def get_args() -> argparse.Namespace: help="watch the play of pre-trained policy only", ) parser.add_argument( - "--save-lmp", + "--save_lmp", default=False, action="store_true", help="save lmp file for replay whole episode", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) parser.add_argument( - "--icm-lr-scale", + "--icm_lr_scale", type=float, default=0.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--icm-reward-scale", + "--icm_reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--icm-forward-loss-weight", + "--icm_forward_loss_weight", type=float, 
default=0.2, help="weight for the forward model loss in ICM", @@ -106,11 +111,11 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: (args.frames_stack, 84, 84), args.save_lmp, args.seed, - args.training_num, - args.test_num, + args.num_train_envs, + args.num_test_envs, ) args.state_shape = env.observation_space.shape - args.action_shape = env.action_space.shape or env.action_space.n + args.action_shape = env.action_space.n # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -118,77 +123,84 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = DQN( - *args.state_shape, - args.action_shape, - device=args.device, + c, h, w = args.state_shape + net = DQNet( + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, - output_dim=args.hidden_size, + output_dim_added_layer=args.hidden_size, ) - actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) - critic = Critic(net, device=args.device) - optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) + critic = DiscreteCritic(preprocess_net=net) + optim = AdamOptimizerFactory(lr=args.lr) - lr_scheduler = None if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + ) + ) - # define policy def dist(logits: torch.Tensor) -> Categorical: return Categorical(logits=logits) - policy: PPOPolicy = PPOPolicy( + # define policy and algorithm + policy = ProbabilisticActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=False, + action_space=env.action_space, + ) + algorithm: PPO | ICMOnPolicyWrapper + algorithm = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - action_scaling=False, - lr_scheduler=lr_scheduler, - action_space=env.action_space, + return_scaling=args.return_scaling, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: - feature_net = DQN( - *args.state_shape, - args.action_shape, - device=args.device, + c, h, w = args.state_shape + feature_net = DQNet( + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, - output_dim=args.hidden_size, + output_dim_added_layer=args.hidden_size, ) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( - feature_net.net, - feature_dim, - action_dim, - device=args.device, + feature_net=feature_net.net, + feature_dim=feature_dim, + action_dim=action_dim, ) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy: ICMPolicy = ICMPolicy( # type: ignore[no-redef] - policy=policy, + icm_optim = 
AdamOptimizerFactory(lr=args.lr) + algorithm = ICMOnPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM @@ -200,8 +212,8 @@ def dist(logits: torch.Tensor) -> Categorical: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -224,7 +236,7 @@ def dist(logits: torch.Tensor) -> Categorical: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -245,7 +257,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -253,7 +267,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -262,23 +276,25 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) watch() diff --git a/poetry.lock b/poetry.lock index 
2cd8e2ed7..c009dd5ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -6,6 +6,7 @@ version = "2.0.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "absl-py-2.0.0.tar.gz", hash = "sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5"}, {file = "absl_py-2.0.0-py3-none-any.whl", hash = "sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3"}, @@ -17,6 +18,7 @@ version = "0.0.4" description = "A collection of accessible pygments styles" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "accessible-pygments-0.0.4.tar.gz", hash = "sha256:e7b57a9b15958e9601c7e9eb07a440c813283545a20973f2574a5f453d0e953e"}, {file = "accessible_pygments-0.0.4-py2.py3-none-any.whl", hash = "sha256:416c6d8c1ea1c5ad8701903a20fcedf953c6e720d64f33dc47bfb2d3f2fa4e8d"}, @@ -31,6 +33,8 @@ version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, @@ -45,6 +49,7 @@ version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "alabaster-0.7.13-py3-none-any.whl", hash = "sha256:1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3"}, {file = "alabaster-0.7.13.tar.gz", hash = "sha256:a27a4a084d5e690e16e01e03ad2b2e552c61a65469419b907243193de1a84ae2"}, @@ -56,6 +61,8 @@ version = "0.8.1" description = "The Arcade Learning Environment (ALE) - a platform for AI research." 
optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "ale_py-0.8.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:b2aa2f69a4169742800615970efe6914fa856e33eaf7fa9133c0e06a617a80e2"}, {file = "ale_py-0.8.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f2f6b92c8fd6189654979bbf0b305dbe0ecf82176c47f244d8c1cbc36286b89"}, @@ -91,6 +98,7 @@ version = "4.0.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "anyio-4.0.0-py3-none-any.whl", hash = "sha256:cfdb2b588b9fc25ede96d8db56ed50848b0b649dca3dd1df0b11f683bb9e0b5f"}, {file = "anyio-4.0.0.tar.gz", hash = "sha256:f7ed51751b2c2add651e5747c891b47e26d2a21be5d32d9311dfe9692f3e5d7a"}, @@ -102,7 +110,7 @@ sniffio = ">=1.1" [package.extras] doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17) ; python_version < \"3.12\" and platform_python_implementation == \"CPython\" and platform_system != \"Windows\""] trio = ["trio (>=0.22)"] [[package]] @@ -111,6 +119,7 @@ version = "1.4.1" description = "Handy tools for working with URLs and APIs." optional = false python-versions = ">=3.6.1" +groups = ["dev"] files = [ {file = "apeye-1.4.1-py3-none-any.whl", hash = "sha256:44e58a9104ec189bf42e76b3a7fe91e2b2879d96d48e9a77e5e32ff699c9204e"}, {file = "apeye-1.4.1.tar.gz", hash = "sha256:14ea542fad689e3bfdbda2189a354a4908e90aee4bf84c15ab75d68453d76a36"}, @@ -132,6 +141,7 @@ version = "1.1.4" description = "Core (offline) functionality for the apeye library." 
optional = false python-versions = ">=3.6.1" +groups = ["dev"] files = [ {file = "apeye_core-1.1.4-py3-none-any.whl", hash = "sha256:084bc696448d3ac428fece41c1f2eb08fa9d9ce1d1b2f4d43187e3def4528a60"}, {file = "apeye_core-1.1.4.tar.gz", hash = "sha256:72bb89fed3baa647cb81aa28e1d851787edcbf9573853b5d2b5f87c02f50eaf5"}, @@ -147,6 +157,8 @@ version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" optional = false python-versions = "*" +groups = ["dev"] +markers = "platform_system == \"Darwin\" or sys_platform == \"darwin\"" files = [ {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"}, {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"}, @@ -158,6 +170,8 @@ version = "5.3.1" description = "ARCH for Python" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "arch-5.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:75fa6f9386ecc2df81bcbf5d055a290a697482ca51e0b3459dab183d288993cb"}, {file = "arch-5.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f9c9220d331618322517e0f2b3b3529f9c51f5e5a891441da4a107fd2d6d7fce"}, @@ -197,6 +211,7 @@ version = "23.1.0" description = "Argon2 for Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "argon2_cffi-23.1.0-py3-none-any.whl", hash = "sha256:c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea"}, {file = "argon2_cffi-23.1.0.tar.gz", hash = "sha256:879c3e79a2729ce768ebb7d36d4609e3a78a4ca2ec3a9f12286ca057e3d0db08"}, @@ -217,6 +232,7 @@ version = "21.2.0" description = "Low-level CFFI bindings for Argon2" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"}, {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"}, @@ -254,6 +270,7 @@ version = "1.3.0" description = "Better dates & times for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "arrow-1.3.0-py3-none-any.whl", hash = "sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80"}, {file = "arrow-1.3.0.tar.gz", hash = "sha256:d4540617648cb5f895730f1ad8c82a65f2dad0166f57b75f3ca54759c4d67a85"}, @@ -273,6 +290,7 @@ version = "2.4.1" description = "Annotate AST trees with source code positions" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24"}, {file = "asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0"}, @@ -282,8 +300,8 @@ files = [ six = ">=1.12.0" [package.extras] -astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"] -test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] +astroid = ["astroid (>=1,<2) ; python_version < \"3\"", "astroid (>=2,<4) ; python_version >= \"3\""] +test = ["astroid (>=1,<2) ; python_version < \"3\"", "astroid (>=2,<4) ; python_version >= \"3\"", "pytest"] [[package]] name = "async-lru" @@ -291,6 +309,7 @@ version = "2.0.4" description = "Simple LRU cache for asyncio" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "async-lru-2.0.4.tar.gz", hash = 
"sha256:b8a59a5df60805ff63220b2a0c5b5393da5521b113cd5465a44eb037d81a5627"}, {file = "async_lru-2.0.4-py3-none-any.whl", hash = "sha256:ff02944ce3c288c5be660c42dbcca0742b32c3b279d6dceda655190240b99224"}, @@ -302,6 +321,7 @@ version = "23.1.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, @@ -312,7 +332,7 @@ cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] dev = ["attrs[docs,tests]", "pre-commit"] docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-no-zope = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.1.1) ; platform_python_implementation == \"CPython\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version < \"3.11\"", "pytest-xdist[psutil]"] [[package]] name = "autodocsumm" @@ -320,6 +340,7 @@ version = "0.2.11" description = "Extended sphinx autodoc including automatic autosummaries" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "autodocsumm-0.2.11-py3-none-any.whl", hash = "sha256:f1d0a623bf1ad64d979a9e23fd360d1fb1b8f869beaf3197f711552cddc174e2"}, {file = "autodocsumm-0.2.11.tar.gz", hash = "sha256:183212bd9e9f3b58a96bb21b7958ee4e06224107aa45b2fd894b61b83581b9a9"}, @@ -334,6 +355,7 @@ version = "2.0.4" description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "autopep8-2.0.4-py2.py3-none-any.whl", hash = "sha256:067959ca4a07b24dbd5345efa8325f5f58da4298dab0dde0443d5ed765de80cb"}, {file = "autopep8-2.0.4.tar.gz", hash = "sha256:2913064abd97b3419d1cc83ea71f042cb821f87e45b9c88cad5ad3c4ea87fe0c"}, @@ -348,6 +370,8 @@ version = "0.4.2" description = "Automated installation of Atari ROMs for Gym/ALE-Py" optional = true python-versions = ">=3.6" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "AutoROM-0.4.2-py3-none-any.whl", hash = "sha256:719c9d363ef08391fdb7003d70df235b68f36de628d289a946c4a59a3adefa13"}, {file = "AutoROM-0.4.2.tar.gz", hash = "sha256:b426a39bc0ee3781c7791f28963a9b2e4385b6421eeaf2f368edc00c761d428a"}, @@ -368,6 +392,8 @@ version = "0.6.1" description = "Automated installation of Atari ROMs for Gym/ALE-Py" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "AutoROM.accept-rom-license-0.6.1.tar.gz", hash = "sha256:0c905a708d634a076f686802f672817d3585259ce3be0bde8713a4fb59e3159e"}, ] @@ -385,6 +411,7 @@ version = "2.13.1" description = "Internationalization utilities" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "Babel-2.13.1-py3-none-any.whl", hash = "sha256:7077a4984b02b6727ac10f1f7294484f737443d7e2e66c5e4380e41a3ae0b4ed"}, {file = "Babel-2.13.1.tar.gz", hash = "sha256:33e0952d7dd6374af8dbf6768cc4ddf3ccfefc244f9986d4074704f2fbd18900"}, @@ -402,6 +429,7 @@ version = "4.12.2" description = "Screen-scraping library" optional = false 
python-versions = ">=3.6.0" +groups = ["dev"] files = [ {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"}, @@ -420,6 +448,7 @@ version = "23.11.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"}, {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"}, @@ -462,6 +491,7 @@ version = "6.1.0" description = "An easy safelist-based HTML-sanitizing tool." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "bleach-6.1.0-py3-none-any.whl", hash = "sha256:3225f354cfc436b9789c66c4ee030194bee0568fbf9cbdad3bc8b5c26c5f12b6"}, {file = "bleach-6.1.0.tar.gz", hash = "sha256:0a31f1837963c41d46bbf1331b8778e1308ea0791db03cc4e7357b97cf42a8fe"}, @@ -480,6 +510,8 @@ version = "2.3.5" description = "Python Box2D" optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"box2d\"" files = [ {file = "box2d-py-2.3.5.tar.gz", hash = "sha256:b37dc38844bcd7def48a97111d2b082e4f81cca3cece7460feb3eacda0da2207"}, {file = "box2d_py-2.3.5-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:287aa54005c0644b47bf7ad72966e4068d66e56bcf8458f5b4a653ffe42a2618"}, @@ -494,6 +526,7 @@ version = "0.13.1" description = "httplib2 caching for requests" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "cachecontrol-0.13.1-py3-none-any.whl", hash = "sha256:95dedbec849f46dda3137866dc28b9d133fc9af55f5b805ab1291833e4457aa4"}, {file = "cachecontrol-0.13.1.tar.gz", hash = "sha256:f012366b79d2243a6118309ce73151bf52a38d4a5dac8ea57f09bd29087e506b"}, @@ -515,6 +548,7 @@ version = "5.3.2" description = "Extensible memoizing collections and decorators" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"}, {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, @@ -526,6 +560,7 @@ version = "2024.7.4" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, @@ -537,6 +572,7 @@ version = "1.16.0" description = "Foreign Function Interface for Python calling C code." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"}, {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"}, @@ -601,6 +637,7 @@ version = "3.4.0" description = "Validate configuration and produce human readable error messages." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, @@ -612,6 +649,7 @@ version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7.0" +groups = ["main", "dev"] files = [ {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, @@ -711,10 +749,12 @@ version = "8.1.7" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, ] +markers = {main = "extra == \"atari\""} [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} @@ -725,6 +765,7 @@ version = "3.0.0" description = "Pickler class to extend the standard pickle.Pickler functionality" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, @@ -736,10 +777,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "platform_system == \"Windows\"", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "comm" @@ -747,6 +790,7 @@ version = "0.2.0" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "comm-0.2.0-py3-none-any.whl", hash = "sha256:2da8d9ebb8dd7bfc247adaff99f24dce705638a8042b85cb995066793e391001"}, {file = "comm-0.2.0.tar.gz", hash = "sha256:a517ea2ca28931c7007a7a99c562a0fa5883cfb48963140cf642c41c948498be"}, @@ -764,6 +808,7 @@ version = "1.2.1" description = "Python library for calculating contours of 2D quadrilateral grids" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"}, {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"}, @@ -827,6 +872,7 @@ version = "7.3.2" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "coverage-7.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d872145f3a3231a5f20fd48500274d7df222e291d90baa2026cc5152b7ce86bf"}, {file = "coverage-7.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:310b3bb9c91ea66d59c53fa4989f57d2436e08f18fb2f421a1b0b6b8cc7fffda"}, @@ -883,7 +929,7 @@ files = [ ] [package.extras] -toml = ["tomli"] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] [[package]] name = "cssutils" @@ -891,6 +937,7 @@ version = "2.9.0" description = "A CSS Cascading Style Sheets library for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "cssutils-2.9.0-py3-none-any.whl", hash = "sha256:f8b013169e281c0c6083207366c5005f5dd4549055f7aba840384fb06a78745c"}, {file = "cssutils-2.9.0.tar.gz", hash = "sha256:89477b3d17d790e97b9fb4def708767061055795aae6f7c82ae32e967c9be4cd"}, @@ -898,7 +945,7 @@ files = [ [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["cssselect", "importlib-resources", "jaraco.test (>=5.1)", "lxml", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +testing = ["cssselect", "importlib-resources ; python_version < \"3.9\"", "jaraco.test (>=5.1)", "lxml ; python_version < \"3.11\"", "pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-ruff"] [[package]] name = "cycler" @@ -906,6 +953,7 @@ version = "0.12.1" description = "Composable style cycles" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"}, {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, @@ -921,6 +969,7 @@ version = "3.0.8" description = "The Cython compiler for writing C extensions in the Python language." 
optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["main"] files = [ {file = "Cython-3.0.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a846e0a38e2b24e9a5c5dc74b0e54c6e29420d88d1dafabc99e0fc0f3e338636"}, {file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45523fdc2b78d79b32834cc1cc12dc2ca8967af87e22a3ee1bff20e77c7f5520"}, @@ -988,6 +1037,7 @@ version = "1.8.0" description = "An implementation of the Debug Adapter Protocol for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "debugpy-1.8.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7fb95ca78f7ac43393cd0e0f2b6deda438ec7c5e47fa5d38553340897d2fbdfb"}, {file = "debugpy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef9ab7df0b9a42ed9c878afd3eaaff471fce3fa73df96022e1f5c9f8f8c87ada"}, @@ -1015,6 +1065,7 @@ version = "5.1.1" description = "Decorators for Humans" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, @@ -1026,6 +1077,7 @@ version = "7.0.1" description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"}, {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"}, @@ -1044,6 +1096,7 @@ version = "0.7.1" description = "XML bomb protection for Python stdlib modules" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, @@ -1055,6 +1108,7 @@ version = "0.3.0.post1" description = "A μ-library for constructing cascading style sheets from Python dictionaries." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "dict2css-0.3.0.post1-py3-none-any.whl", hash = "sha256:f006a6b774c3e31869015122ae82c491fd25e7de4a75607a62aa3e798f837e0d"}, {file = "dict2css-0.3.0.post1.tar.gz", hash = "sha256:89c544c21c4ca7472c3fffb9d37d3d926f606329afdb751dc1de67a411b70719"}, @@ -1070,6 +1124,7 @@ version = "0.3.7" description = "Distribution utilities" optional = false python-versions = "*" +groups = ["main", "dev"] files = [ {file = "distlib-0.3.7-py2.py3-none-any.whl", hash = "sha256:2e24928bc811348f0feb63014e97aaae3037f2cf48712d51ae61df7fd6075057"}, {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, @@ -1081,6 +1136,8 @@ version = "1.6" description = "A Python interface for Reinforcement Learning environments." 
optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "dm-env-1.6.tar.gz", hash = "sha256:a436eb1c654c39e0c986a516cee218bea7140b510fceff63f97eb4fcff3d93de"}, {file = "dm_env-1.6-py3-none-any.whl", hash = "sha256:0eabb6759dd453b625e041032f7ae0c1e87d4eb61b6a96b9ca586483837abf29"}, @@ -1097,6 +1154,8 @@ version = "0.1.8" description = "Tree is a library for working with nested data structures." optional = true python-versions = "*" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"}, {file = "dm_tree-0.1.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60"}, @@ -1152,6 +1211,7 @@ version = "0.4.0" description = "Python bindings for the docker credentials store API" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"}, {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"}, @@ -1166,6 +1226,8 @@ version = "0.15" description = "Parse Python docstrings in reST, Google and Numpydoc format" optional = true python-versions = ">=3.6,<4.0" +groups = ["main"] +markers = "extra == \"argparse\" or extra == \"eval\"" files = [ {file = "docstring_parser-0.15-py3-none-any.whl", hash = "sha256:d1679b86250d269d06a99670924d6bce45adc00b08069dae8c47d98e89b667a9"}, {file = "docstring_parser-0.15.tar.gz", hash = "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682"}, @@ -1177,6 +1239,7 @@ version = "0.20.1" description = "Docutils -- Python Documentation Utilities" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "docutils-0.20.1-py3-none-any.whl", hash = "sha256:96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6"}, {file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"}, @@ -1188,6 +1251,7 @@ version = "3.7.0" description = "Helpful functions for Python 🐍 🛠️" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "domdf_python_tools-3.7.0-py3-none-any.whl", hash = "sha256:7b4d1c3bdb7402b872d43953824bf921ae2e52f893adbe5c0052a21a6efa2fe4"}, {file = "domdf_python_tools-3.7.0.tar.gz", hash = "sha256:df1af9a91649af0fb2a4e7b3a4b0a0936e4f78389dd7280dd6fd2f53a339ca71"}, @@ -1207,6 +1271,8 @@ version = "0.8.4" description = "\"C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.\"" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "envpool-0.8.4-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:9c6a1af66c8a18d798b3069e8eee4cde2e5942af22b25d058189714f2630b024"}, {file = "envpool-0.8.4-cp311-cp311-manylinux_2_24_x86_64.whl", hash = "sha256:2407294307a3e20c18787bb836a94cc0649e708b04d8a8200be674f5fc46f3b4"}, @@ -1231,13 +1297,14 @@ version = "2.0.1" description = "Get the currently executing AST node of a frame, and other information" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "executing-2.0.1-py2.py3-none-any.whl", hash = 
"sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"}, {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"}, ] [package.extras] -tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""] [[package]] name = "farama-notifications" @@ -1245,28 +1312,19 @@ version = "0.0.4" description = "Notifications for all Farama Foundation maintained libraries." optional = false python-versions = "*" +groups = ["main"] files = [ {file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"}, {file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"}, ] -[[package]] -name = "fasteners" -version = "0.19" -description = "A python package that provides useful locks" -optional = true -python-versions = ">=3.6" -files = [ - {file = "fasteners-0.19-py3-none-any.whl", hash = "sha256:758819cb5d94cdedf4e836988b74de396ceacb8e2794d21f82d131fd9ee77237"}, - {file = "fasteners-0.19.tar.gz", hash = "sha256:b4f37c3ac52d8a445af3a66bce57b33b5e90b97c696b7b984f530cf8f0ded09c"}, -] - [[package]] name = "fastjsonschema" version = "2.19.0" description = "Fastest Python implementation of JSON schema" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "fastjsonschema-2.19.0-py3-none-any.whl", hash = "sha256:b9fd1a2dd6971dbc7fee280a95bd199ae0dd9ce22beb91cc75e9c1c528a5170e"}, {file = "fastjsonschema-2.19.0.tar.gz", hash = "sha256:e25df6647e1bc4a26070b700897b07b542ec898dd4f1f6ea013e7f6a88417225"}, @@ -1281,6 +1339,7 @@ version = "3.13.1" description = "A platform independent file lock." 
optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, @@ -1289,7 +1348,7 @@ files = [ [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] -typing = ["typing-extensions (>=4.8)"] +typing = ["typing-extensions (>=4.8) ; python_version < \"3.11\""] [[package]] name = "fonttools" @@ -1297,6 +1356,7 @@ version = "4.51.0" description = "Tools to manipulate font files" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74"}, {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308"}, @@ -1343,18 +1403,18 @@ files = [ ] [package.extras] -all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] +all = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\"", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0) ; python_version <= \"3.12\"", "xattr ; sys_platform == \"darwin\"", "zopfli (>=0.1.4)"] graphite = ["lz4 (>=1.7.4.2)"] -interpolatable = ["munkres", "pycairo", "scipy"] +interpolatable = ["munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\""] lxml = ["lxml (>=4.0)"] pathops = ["skia-pathops (>=0.5.0)"] plot = ["matplotlib"] repacker = ["uharfbuzz (>=0.23.0)"] symfont = ["sympy"] -type1 = ["xattr"] +type1 = ["xattr ; sys_platform == \"darwin\""] ufo = ["fs (>=2.2.0,<3)"] -unicode = ["unicodedata2 (>=15.1.0)"] -woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] +unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""] +woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"] [[package]] name = "fqdn" @@ -1362,6 +1422,7 @@ version = "1.5.1" description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" optional = false python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" +groups = ["dev"] files = [ {file = "fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014"}, {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, @@ -1373,6 +1434,8 @@ version = "1.4.0" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false python-versions = ">=3.8" 
+groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"}, {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"}, @@ -1443,6 +1506,7 @@ version = "2023.10.0" description = "File-system specification" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "fsspec-2023.10.0-py3-none-any.whl", hash = "sha256:346a8f024efeb749d2a5fca7ba8854474b1ff9af7c3faaf636a4548781136529"}, {file = "fsspec-2023.10.0.tar.gz", hash = "sha256:330c66757591df346ad3091a53bd907e15348c2ba17d63fd54f5c39c4457d2a5"}, @@ -1478,6 +1542,7 @@ version = "4.0.11" description = "Git Object Database" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"}, {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"}, @@ -1492,6 +1557,7 @@ version = "3.1.41" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "GitPython-3.1.41-py3-none-any.whl", hash = "sha256:c36b6634d069b3f719610175020a9aed919421c87552185b085e04fbbdb10b7c"}, {file = "GitPython-3.1.41.tar.gz", hash = "sha256:ed66e624884f76df22c8e16066d567aaa5a37d5b5fa19db2c6df6f7156db9048"}, @@ -1501,7 +1567,7 @@ files = [ gitdb = ">=4.0.1,<5" [package.extras] -test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "sumtypes"] +test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "sumtypes"] [[package]] name = "glfw" @@ -1509,6 +1575,8 @@ version = "2.6.5" description = "A ctypes-based wrapper for GLFW3." 
optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "glfw-2.6.5-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_10_6_intel.whl", hash = "sha256:57d00367f8dc31b898a47ab22849bab9f87dff4b4c7a56d16d9a7158cda96c19"}, {file = "glfw-2.6.5-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_11_0_arm64.whl", hash = "sha256:a1a132e7d6f78ae7f32957b56de2fd996d2a416f9520adb40345cc9cf744d277"}, @@ -1530,6 +1598,7 @@ version = "2.23.4" description = "Google Authentication Library" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "google-auth-2.23.4.tar.gz", hash = "sha256:79905d6b1652187def79d491d6e23d0cbb3a21d3c7ba0dbaa9c8a01906b13ff3"}, {file = "google_auth-2.23.4-py2.py3-none-any.whl", hash = "sha256:d4bbc92fe4b8bfd2f3e8d88e5ba7085935da208ee38a134fc280e7ce682a05f2"}, @@ -1553,6 +1622,7 @@ version = "1.1.0" description = "Google Authentication Library" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "google-auth-oauthlib-1.1.0.tar.gz", hash = "sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb"}, {file = "google_auth_oauthlib-1.1.0-py2.py3-none-any.whl", hash = "sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12"}, @@ -1571,6 +1641,8 @@ version = "3.0.1" description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\"" files = [ {file = "greenlet-3.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f89e21afe925fcfa655965ca8ea10f24773a1791400989ff32f467badfe4a064"}, {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28e89e232c7593d33cac35425b58950789962011cc274aa43ef8865f2e11f46d"}, @@ -1641,6 +1713,7 @@ version = "1.59.3" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "grpcio-1.59.3-cp310-cp310-linux_armv7l.whl", hash = "sha256:aca028a6c7806e5b61e5f9f4232432c52856f7fcb98e330b20b6bc95d657bdcc"}, {file = "grpcio-1.59.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:19ad26a7967f7999c8960d2b9fe382dae74c55b0c508c613a6c2ba21cddf2354"}, @@ -1707,6 +1780,8 @@ version = "0.26.2" description = "Gym: A universal API for reinforcement learning environments" optional = true python-versions = ">=3.6" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "gym-0.26.2.tar.gz", hash = "sha256:e0d882f4b54f0c65f203104c24ab8a38b039f1289986803c7d02cdbe214fbcc4"}, ] @@ -1734,6 +1809,8 @@ version = "0.0.8" description = "Notices for gym" optional = true python-versions = "*" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "gym-notices-0.0.8.tar.gz", hash = "sha256:ad25e200487cafa369728625fe064e88ada1346618526102659b4640f2b4b911"}, {file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"}, @@ -1745,6 +1822,7 @@ version = "0.28.1" description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "gymnasium-0.28.1-py3-none-any.whl", hash = "sha256:7bc9a5bce1022f997d1dbc152fc91d1ac977bad9cc7794cdc25437010867cabf"}, {file = "gymnasium-0.28.1.tar.gz", hash = "sha256:4c2c745808792c8f45c6e88ad0a5504774394e0c126f6e3db555e720d3da6f24"}, @@ -1776,6 +1854,8 @@ version = "1.2.3" description = "Robotics environments for the Gymnasium repo." optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"robotics\"" files = [ {file = "gymnasium-robotics-1.2.3.tar.gz", hash = "sha256:b01eb9df74c0041e559e1251442ba1a59174bfc71a1c58519724d76df803c0b6"}, {file = "gymnasium_robotics-1.2.3-py3-none-any.whl", hash = "sha256:9c3cd7bcc7ac7a0efca03d5685a01686661c7fa678e34adfe4e15044580e7180"}, @@ -1799,6 +1879,7 @@ version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, @@ -1810,6 +1891,7 @@ version = "3.10.0" description = "Read and write HDF5 files from Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "h5py-3.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b963fb772964fc1d1563c57e4e2e874022ce11f75ddc6df1a626f42bd49ab99f"}, {file = "h5py-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:012ab448590e3c4f5a8dd0f3533255bc57f80629bf7c5054cf4c87b30085063c"}, @@ -1847,6 +1929,7 @@ version = "1.1" description = "HTML parser based on the WHATWG HTML specification" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "html5lib-1.1-py2.py3-none-any.whl", hash = "sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d"}, {file = "html5lib-1.1.tar.gz", hash = "sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f"}, @@ -1857,10 +1940,10 @@ six = ">=1.9" webencodings = "*" [package.extras] -all = ["chardet (>=2.2)", "genshi", "lxml"] +all = ["chardet (>=2.2)", "genshi", "lxml ; platform_python_implementation == \"CPython\""] chardet = ["chardet (>=2.2)"] genshi = ["genshi"] -lxml = ["lxml"] +lxml = ["lxml ; platform_python_implementation == \"CPython\""] [[package]] name = "httpcore" @@ -1868,6 +1951,7 @@ version = "1.0.5" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, @@ -1889,6 +1973,7 @@ version = "0.27.2" description = "The next generation HTTP client." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, @@ -1902,7 +1987,7 @@ idna = "*" sniffio = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -1914,6 +1999,7 @@ version = "2.5.32" description = "File identification library for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "identify-2.5.32-py2.py3-none-any.whl", hash = "sha256:0b7656ef6cba81664b783352c73f8c24b39cf82f926f78f4550eda928e5e0545"}, {file = "identify-2.5.32.tar.gz", hash = "sha256:5d9979348ec1a21c768ae07e0a652924538e8bce67313a73cb0f681cf08ba407"}, @@ -1928,6 +2014,7 @@ version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.5" +groups = ["main", "dev"] files = [ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, @@ -1939,6 +2026,8 @@ version = "2.33.1" description = "Library for reading and writing a wide range of image, video, scientific, and volumetric data formats." optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "imageio-2.33.1-py3-none-any.whl", hash = "sha256:c5094c48ccf6b2e6da8b4061cd95e1209380afafcbeae4a4e280938cce227e1d"}, {file = "imageio-2.33.1.tar.gz", hash = "sha256:78722d40b137bd98f5ec7312119f8aea9ad2049f76f434748eb306b6937cc1ce"}, @@ -1971,6 +2060,7 @@ version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b"}, {file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"}, @@ -1982,6 +2072,7 @@ version = "6.8.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"}, {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"}, @@ -1993,7 +2084,7 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +testing = ["flufl.flake8", "importlib-resources (>=1.3) ; python_version < \"3.9\"", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != 
\"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-perf (>=0.9.2)", "pytest-ruff"] [[package]] name = "importlib-resources" @@ -2001,6 +2092,8 @@ version = "6.1.1" description = "Read resources from Python packages" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "importlib_resources-6.1.1-py3-none-any.whl", hash = "sha256:e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6"}, {file = "importlib_resources-6.1.1.tar.gz", hash = "sha256:3893a00122eafde6894c59914446a512f728a0c1a45f9bb9b63721b6bacf0b4a"}, @@ -2008,7 +2101,7 @@ files = [ [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff", "zipp (>=3.17)"] +testing = ["pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-ruff", "zipp (>=3.17)"] [[package]] name = "iniconfig" @@ -2016,6 +2109,7 @@ version = "2.0.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -2027,6 +2121,7 @@ version = "6.26.0" description = "IPython Kernel for Jupyter" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "ipykernel-6.26.0-py3-none-any.whl", hash = "sha256:3ba3dc97424b87b31bb46586b5167b3161b32d7820b9201a9e698c71e271602c"}, {file = "ipykernel-6.26.0.tar.gz", hash = "sha256:553856658eb8430bbe9653ea041a41bff63e9606fc4628873fc92a6cf3abd404"}, @@ -2060,6 +2155,7 @@ version = "8.17.2" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "ipython-8.17.2-py3-none-any.whl", hash = "sha256:1e4d1d666a023e3c93585ba0d8e962867f7a111af322efff6b9c58062b3e5444"}, {file = "ipython-8.17.2.tar.gz", hash = "sha256:126bb57e1895594bb0d91ea3090bbd39384f6fe87c3d57fd558d0670f50339bb"}, @@ -2096,6 +2192,7 @@ version = "8.1.1" description = "Jupyter interactive widgets" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ipywidgets-8.1.1-py3-none-any.whl", hash = "sha256:2b88d728656aea3bbfd05d32c747cfd0078f9d7e159cf982433b58ad717eed7f"}, {file = "ipywidgets-8.1.1.tar.gz", hash = "sha256:40211efb556adec6fa450ccc2a77d59ca44a060f4f9f136833df59c9f538e6e8"}, @@ -2117,6 +2214,7 @@ version = "20.11.0" description = "Operations with ISO 8601 durations" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "isoduration-20.11.0-py3-none-any.whl", hash = "sha256:b2904c2a4228c3d44f409c8ae8e2370eb21a26f7ac2ec5446df141dde3452042"}, {file = "isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9"}, @@ -2131,6 +2229,7 @@ version = "1.0.0" description = "Common backend for Jax or Numpy." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "jax-jumpy-1.0.0.tar.gz", hash = "sha256:195fb955cc4c2b7f0b1453e3cb1fb1c414a51a407ffac7a51e69a73cb30d59ad"}, {file = "jax_jumpy-1.0.0-py3-none-any.whl", hash = "sha256:ab7e01454bba462de3c4d098e3e585c302a8f06bc36d9182ab4e7e4aa7067c5e"}, @@ -2149,6 +2248,7 @@ version = "0.19.1" description = "An autocompletion tool for Python that can be used for text editors." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, @@ -2168,6 +2268,7 @@ version = "3.1.4" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, @@ -2185,6 +2286,8 @@ version = "1.4.0" description = "Lightweight pipelining with Python functions" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "joblib-1.4.0-py3-none-any.whl", hash = "sha256:42942470d4062537be4d54c83511186da1fc14ba354961a2114da91efa9a4ed7"}, {file = "joblib-1.4.0.tar.gz", hash = "sha256:1eb0dc091919cd384490de890cb5dfd538410a6d4b3b54eef09fb8c50b409b1c"}, @@ -2196,6 +2299,7 @@ version = "0.9.14" description = "A Python implementation of the JSON5 data format." optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "json5-0.9.14-py2.py3-none-any.whl", hash = "sha256:740c7f1b9e584a468dbb2939d8d458db3427f2c93ae2139d05f47e453eae964f"}, {file = "json5-0.9.14.tar.gz", hash = "sha256:9ed66c3a6ca3510a976a9ef9b8c0787de24802724ab1860bc0153c7fdd589b02"}, @@ -2210,6 +2314,8 @@ version = "4.27.0" description = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables." 
optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"argparse\" or extra == \"eval\"" files = [ {file = "jsonargparse-4.27.0-py3-none-any.whl", hash = "sha256:a6378bc8b7bbe38b708f090b10ea8431216e71f8b2eea1f9a4f095ae4abd0f2e"}, {file = "jsonargparse-4.27.0.tar.gz", hash = "sha256:6ac791cd7913cff34ad2dbd3ed0431f9e327af0be926332ac060bd5b13d353f2"}, @@ -2225,7 +2331,7 @@ coverage = ["jsonargparse[test-no-urls]", "pytest-cov (>=4.0.0)"] dev = ["build (>=0.10.0)", "jsonargparse[coverage]", "jsonargparse[doc]", "jsonargparse[mypy]", "jsonargparse[test]", "pre-commit (>=2.19.0)", "tox (>=3.25.0)"] doc = ["Sphinx (>=1.7.9)", "autodocsumm (>=0.1.10)", "sphinx-autodoc-typehints (>=1.19.5)", "sphinx-rtd-theme (>=1.2.2)"] fsspec = ["fsspec (>=0.8.4)"] -jsonnet = ["jsonnet (>=0.13.0)", "jsonnet-binary (>=0.17.0)"] +jsonnet = ["jsonnet (>=0.13.0) ; os_name == \"posix\"", "jsonnet-binary (>=0.17.0) ; os_name != \"posix\""] jsonschema = ["jsonschema (>=3.2.0)"] maintainer = ["bump2version (>=0.5.11)", "twine (>=4.0.2)"] omegaconf = ["omegaconf (>=2.1.1)"] @@ -2234,7 +2340,7 @@ ruyaml = ["ruyaml (>=0.20.0)"] signatures = ["docstring-parser (>=0.15)", "jsonargparse[typing-extensions]", "typeshed-client (>=2.1.0)"] test = ["attrs (>=22.2.0)", "jsonargparse[test-no-urls]", "pydantic (>=2.3.0)", "responses (>=0.12.0)", "types-PyYAML (>=6.0.11)", "types-requests (>=2.28.9)"] test-no-urls = ["pytest (>=6.2.5)", "pytest-subtests (>=0.8.0)"] -typing-extensions = ["typing-extensions (>=3.10.0.0)"] +typing-extensions = ["typing-extensions (>=3.10.0.0) ; python_version < \"3.10\""] urls = ["requests (>=2.18.4)"] [[package]] @@ -2243,6 +2349,7 @@ version = "2.4" description = "Identify specific nodes in a JSON document (RFC 6901)" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" +groups = ["dev"] files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, @@ -2254,6 +2361,7 @@ version = "4.20.0" description = "An implementation of JSON Schema validation for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jsonschema-4.20.0-py3-none-any.whl", hash = "sha256:ed6231f0429ecf966f5bc8dfef245998220549cbbcf140f913b7464c52c3b6b3"}, {file = "jsonschema-4.20.0.tar.gz", hash = "sha256:4f614fd46d8d61258610998997743ec5492a648b33cf478c1ddc23ed4598a5fa"}, @@ -2283,6 +2391,7 @@ version = "2023.11.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jsonschema_specifications-2023.11.1-py3-none-any.whl", hash = "sha256:f596778ab612b3fd29f72ea0d990393d0540a5aab18bf0407a46632eab540779"}, {file = "jsonschema_specifications-2023.11.1.tar.gz", hash = "sha256:c9b234904ffe02f079bf91b14d79987faa685fd4b39c377a0996954c0090b9ca"}, @@ -2297,6 +2406,7 @@ version = "1.0.0" description = "Jupyter metapackage. Install all the Jupyter components in one go." 
optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "jupyter-1.0.0-py2.py3-none-any.whl", hash = "sha256:5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78"}, {file = "jupyter-1.0.0.tar.gz", hash = "sha256:d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f"}, @@ -2317,6 +2427,7 @@ version = "1.0.0" description = "Build a book with Jupyter Notebooks and Sphinx." optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "jupyter_book-1.0.0-py3-none-any.whl", hash = "sha256:18238f1e7e1d425731e60ab509a7da878dd6db88b7d77bcfab4690361b72e1be"}, {file = "jupyter_book-1.0.0.tar.gz", hash = "sha256:539c5d0493546200d9de27bd4b5f77eaea03115f8937f825d4ff82b3801a987e"}, @@ -2354,6 +2465,7 @@ version = "0.6.1" description = "A defined interface for working with a cache of jupyter notebooks." optional = false python-versions = "~=3.8" +groups = ["dev"] files = [ {file = "jupyter-cache-0.6.1.tar.gz", hash = "sha256:26f83901143edf4af2f3ff5a91e2d2ad298e46e2cee03c8071d37a23a63ccbfc"}, {file = "jupyter_cache-0.6.1-py3-none-any.whl", hash = "sha256:2fce7d4975805c77f75bdfc1bc2e82bc538b8e5b1af27f2f5e06d55b9f996a82"}, @@ -2381,6 +2493,7 @@ version = "8.6.0" description = "Jupyter protocol implementation and client libraries" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_client-8.6.0-py3-none-any.whl", hash = "sha256:909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99"}, {file = "jupyter_client-8.6.0.tar.gz", hash = "sha256:0642244bb83b4764ae60d07e010e15f0e2d275ec4e918a8f7b80fbbef3ca60c7"}, @@ -2395,7 +2508,7 @@ traitlets = ">=5.3" [package.extras] docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] -test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] +test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko ; sys_platform == \"win32\"", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] [[package]] name = "jupyter-console" @@ -2403,6 +2516,7 @@ version = "6.6.3" description = "Jupyter terminal console" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "jupyter_console-6.6.3-py3-none-any.whl", hash = "sha256:309d33409fcc92ffdad25f0bcdf9a4a9daa61b6f341177570fdac03de5352485"}, {file = "jupyter_console-6.6.3.tar.gz", hash = "sha256:566a4bf31c87adbfadf22cdf846e3069b59a71ed5da71d6ba4d8aaad14a53539"}, @@ -2427,6 +2541,7 @@ version = "5.5.0" description = "Jupyter core package. A base package on which Jupyter projects rely." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_core-5.5.0-py3-none-any.whl", hash = "sha256:e11e02cd8ae0a9de5c6c44abf5727df9f2581055afe00b22183f621ba3585805"}, {file = "jupyter_core-5.5.0.tar.gz", hash = "sha256:880b86053bf298a8724994f95e99b99130659022a4f7f45f563084b6223861d3"}, @@ -2447,6 +2562,7 @@ version = "0.9.0" description = "Jupyter Event System library" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_events-0.9.0-py3-none-any.whl", hash = "sha256:d853b3c10273ff9bc8bb8b30076d65e2c9685579db736873de6c2232dde148bf"}, {file = "jupyter_events-0.9.0.tar.gz", hash = "sha256:81ad2e4bc710881ec274d31c6c50669d71bbaa5dd9d01e600b56faa85700d399"}, @@ -2472,6 +2588,7 @@ version = "2.2.2" description = "Multi-Language Server WebSocket proxy for Jupyter Notebook/Lab server" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter-lsp-2.2.2.tar.gz", hash = "sha256:256d24620542ae4bba04a50fc1f6ffe208093a07d8e697fea0a8d1b8ca1b7e5b"}, {file = "jupyter_lsp-2.2.2-py3-none-any.whl", hash = "sha256:3b95229e4168355a8c91928057c1621ac3510ba98b2a925e82ebd77f078b1aa5"}, @@ -2486,6 +2603,7 @@ version = "2.11.2" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_server-2.11.2-py3-none-any.whl", hash = "sha256:0c548151b54bcb516ca466ec628f7f021545be137d01b5467877e87f6fff4374"}, {file = "jupyter_server-2.11.2.tar.gz", hash = "sha256:0c99f9367b0f24141e527544522430176613f9249849be80504c6d2b955004bb"}, @@ -2522,6 +2640,7 @@ version = "0.4.4" description = "A Jupyter Server Extension Providing Terminals." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_server_terminals-0.4.4-py3-none-any.whl", hash = "sha256:75779164661cec02a8758a5311e18bb8eb70c4e86c6b699403100f1585a12a36"}, {file = "jupyter_server_terminals-0.4.4.tar.gz", hash = "sha256:57ab779797c25a7ba68e97bcfb5d7740f2b5e8a83b5e8102b10438041a7eac5d"}, @@ -2541,6 +2660,7 @@ version = "4.2.5" description = "JupyterLab computational environment" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyterlab-4.2.5-py3-none-any.whl", hash = "sha256:73b6e0775d41a9fee7ee756c80f58a6bed4040869ccc21411dc559818874d321"}, {file = "jupyterlab-4.2.5.tar.gz", hash = "sha256:ae7f3a1b8cb88b4f55009ce79fa7c06f99d70cd63601ee4aa91815d054f46f75"}, @@ -2574,6 +2694,7 @@ version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "jupyterlab_pygments-0.2.2-py2.py3-none-any.whl", hash = "sha256:2405800db07c9f770863bcf8049a529c3dd4d3e28536638bd7c1c01d2748309f"}, {file = "jupyterlab_pygments-0.2.2.tar.gz", hash = "sha256:7405d7fde60819d905a9fa8ce89e4cd830e318cdad22a0030f7a901da705585d"}, @@ -2585,6 +2706,7 @@ version = "2.27.3" description = "A set of server components for JupyterLab and JupyterLab like applications." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4"}, {file = "jupyterlab_server-2.27.3.tar.gz", hash = "sha256:eb36caca59e74471988f0ae25c77945610b887f777255aa21f8065def9e51ed4"}, @@ -2610,6 +2732,7 @@ version = "3.0.9" description = "Jupyter interactive widgets for JupyterLab" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "jupyterlab_widgets-3.0.9-py3-none-any.whl", hash = "sha256:3cf5bdf5b897bf3bccf1c11873aa4afd776d7430200f765e0686bd352487b58d"}, {file = "jupyterlab_widgets-3.0.9.tar.gz", hash = "sha256:6005a4e974c7beee84060fdfba341a3218495046de8ae3ec64888e5fe19fdb4c"}, @@ -2621,6 +2744,7 @@ version = "1.4.5" description = "A fast implementation of the Cassowary constraint solver" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"}, {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"}, @@ -2734,6 +2858,7 @@ version = "2.0.1" description = "A lexer and codec to work with LaTeX code in Python." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "latexcodec-2.0.1-py2.py3-none-any.whl", hash = "sha256:c277a193638dc7683c4c30f6684e3db728a06efb0dc9cf346db8bd0aa6c5d271"}, {file = "latexcodec-2.0.1.tar.gz", hash = "sha256:2aa2551c373261cefe2ad3a8953a6d6533e68238d180eb4bb91d7964adb3fe9a"}, @@ -2748,6 +2873,7 @@ version = "2.0.2" description = "Links recognition library with FULL unicode support." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "linkify-it-py-2.0.2.tar.gz", hash = "sha256:19f3060727842c254c808e99d465c80c49d2c7306788140987a1a7a29b0d6ad2"}, {file = "linkify_it_py-2.0.2-py3-none-any.whl", hash = "sha256:a3a24428f6c96f27370d7fe61d2ac0be09017be5190d68d8658233171f1b6541"}, @@ -2768,6 +2894,7 @@ version = "0.43.0" description = "lightweight wrapper around basic LLVM functionality" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"}, {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"}, @@ -2798,6 +2925,7 @@ version = "3.5.1" description = "Python implementation of John Gruber's Markdown." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "Markdown-3.5.1-py3-none-any.whl", hash = "sha256:5874b47d4ee3f0b14d764324d2c94c03ea66bee56f2d929da9f2508d65e722dc"}, {file = "Markdown-3.5.1.tar.gz", hash = "sha256:b65d7beb248dc22f2e8a31fb706d93798093c308dc1aba295aedeb9d41a813bd"}, @@ -2813,6 +2941,7 @@ version = "3.0.0" description = "Python port of markdown-it. Markdown parsing, done right!" 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, @@ -2837,6 +2966,7 @@ version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, @@ -2906,6 +3036,7 @@ version = "3.8.4" description = "Python plotting package" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "matplotlib-3.8.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014"}, {file = "matplotlib-3.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106"}, @@ -2954,6 +3085,7 @@ version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"}, {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"}, @@ -2968,6 +3100,7 @@ version = "0.4.0" description = "Collection of plugins for markdown-it-py" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mdit_py_plugins-0.4.0-py3-none-any.whl", hash = "sha256:b51b3bb70691f57f974e257e367107857a93b36f322a9e6d44ca5bf28ec2def9"}, {file = "mdit_py_plugins-0.4.0.tar.gz", hash = "sha256:d8ab27e9aed6c38aa716819fedfde15ca275715955f8a185a8e1cf90fb1d2c1b"}, @@ -2987,6 +3120,7 @@ version = "0.1.2" description = "Markdown URL utilities" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, @@ -2998,6 +3132,7 @@ version = "3.0.2" description = "A sane and fast Markdown parser with useful plugins and renderers" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"}, {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, @@ -3009,6 +3144,7 @@ version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, @@ -3017,7 +3153,7 @@ files = [ [package.extras] develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] docs = ["sphinx"] -gmpy = ["gmpy2 (>=2.1.0a4)"] +gmpy = ["gmpy2 (>=2.1.0a4) ; 
platform_python_implementation != \"PyPy\""] tests = ["pytest (>=4.6)"] [[package]] @@ -3026,6 +3162,7 @@ version = "1.0.7" description = "MessagePack serializer" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "msgpack-1.0.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04ad6069c86e531682f9e1e71b71c1c3937d6014a7c3e9edd2aa81ad58842862"}, {file = "msgpack-1.0.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cca1b62fe70d761a282496b96a5e51c44c213e410a964bdffe0928e611368329"}, @@ -3091,6 +3228,8 @@ version = "2.3.7" description = "MuJoCo Physics Simulator" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"}, {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"}, @@ -3125,31 +3264,13 @@ glfw = "*" numpy = "*" pyopengl = "*" -[[package]] -name = "mujoco-py" -version = "2.1.2.14" -description = "" -optional = true -python-versions = ">=3.6" -files = [ - {file = "mujoco-py-2.1.2.14.tar.gz", hash = "sha256:eb5b14485acf80a3cf8c15f4b080c6a28a9f79e68869aa696d16cbd51ea7706f"}, - {file = "mujoco_py-2.1.2.14-py3-none-any.whl", hash = "sha256:37c0b41bc0153a8a0eb3663103a67c60f65467753f74e4ff6e68b879f3e3a71f"}, -] - -[package.dependencies] -cffi = ">=1.10" -Cython = ">=0.27.2" -fasteners = ">=0.15,<1.0" -glfw = ">=1.4.0" -imageio = ">=2.1.2" -numpy = ">=1.11" - [[package]] name = "mypy" version = "1.7.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mypy-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5da84d7bf257fd8f66b4f759a904fd2c5a765f70d8b52dde62b521972a0a2357"}, {file = "mypy-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3637c03f4025f6405737570d6cbfa4f1400eb3c649317634d273687a09ffc2f"}, @@ -3196,6 +3317,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -3207,6 +3329,7 @@ version = "1.0.0" description = "A Jupyter Notebook Sphinx reader built on top of the MyST markdown parser." optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "myst_nb-1.0.0-py3-none-any.whl", hash = "sha256:ee8febc6dd7d9e32bede0c66a9b962b2e2fdab697428ee9fbfd4919d82380911"}, {file = "myst_nb-1.0.0.tar.gz", hash = "sha256:9077e42a1c6b441ea55078506f83555dda5d6c816ef4930841d71d239e3e0c5e"}, @@ -3235,6 +3358,7 @@ version = "2.0.0" description = "An extended [CommonMark](https://spec.commonmark.org/) compliant parser," optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "myst_parser-2.0.0-py3-none-any.whl", hash = "sha256:7c36344ae39c8e740dad7fdabf5aa6fc4897a813083c6cc9990044eb93656b14"}, {file = "myst_parser-2.0.0.tar.gz", hash = "sha256:ea929a67a6a0b1683cdbe19b8d2e724cd7643f8aa3e7bb18dd65beac3483bead"}, @@ -3261,6 +3385,7 @@ version = "8.4.0" description = "Simple yet flexible natural sorting in Python." 
optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, @@ -3276,6 +3401,7 @@ version = "0.7.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." optional = false python-versions = ">=3.7.0" +groups = ["dev"] files = [ {file = "nbclient-0.7.4-py3-none-any.whl", hash = "sha256:c817c0768c5ff0d60e468e017613e6eae27b6fa31e43f905addd2d24df60c125"}, {file = "nbclient-0.7.4.tar.gz", hash = "sha256:d447f0e5a4cfe79d462459aec1b3dc5c2e9152597262be8ee27f7d4c02566a0d"}, @@ -3298,6 +3424,7 @@ version = "7.11.0" description = "Converting Jupyter Notebooks" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "nbconvert-7.11.0-py3-none-any.whl", hash = "sha256:d1d417b7f34a4e38887f8da5bdfd12372adf3b80f995d57556cb0972c68909fe"}, {file = "nbconvert-7.11.0.tar.gz", hash = "sha256:abedc01cf543177ffde0bfc2a69726d5a478f6af10a332fc1bf29fcb4f0cf000"}, @@ -3335,6 +3462,7 @@ version = "5.9.2" description = "The Jupyter Notebook format" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "nbformat-5.9.2-py3-none-any.whl", hash = "sha256:1c5172d786a41b82bcfd0c23f9e6b6f072e8fb49c39250219e4acfff1efe89e9"}, {file = "nbformat-5.9.2.tar.gz", hash = "sha256:5f98b5ba1997dff175e77e0c17d5c10a96eaed2cbd1de3533d1fc35d5e111192"}, @@ -3356,6 +3484,7 @@ version = "1.7.1" description = "Run any standard Python code quality tool on a Jupyter Notebook" optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "nbqa-1.7.1-py3-none-any.whl", hash = "sha256:77cdff622bfcf527bf260004449984edfb3624f6e065ac6bb35d64cddcdad483"}, {file = "nbqa-1.7.1.tar.gz", hash = "sha256:44f5b5000d6df438c4f1cba339e3ad80acc405e61f4500ac951fa36a177133f4"}, @@ -3376,6 +3505,7 @@ version = "0.6.1" description = "Strips outputs from Jupyter and IPython notebooks" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "nbstripout-0.6.1-py2.py3-none-any.whl", hash = "sha256:5ff6eb0debbcd656c4a64db8e082a24fabcfc753a9e8c9f6d786971e8f29e110"}, {file = "nbstripout-0.6.1.tar.gz", hash = "sha256:9065bcdd1488b386e4f3c081ffc1d48f4513a2f8d8bf4d0d9a28208c5dafe9d3"}, @@ -3390,6 +3520,7 @@ version = "1.5.8" description = "Patch asyncio to allow nested event loops" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "nest_asyncio-1.5.8-py3-none-any.whl", hash = "sha256:accda7a339a70599cb08f9dd09a67e0c2ef8d8d6f4c07f96ab203f2ae254e48d"}, {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"}, @@ -3401,6 +3532,7 @@ version = "3.2.1" description = "Python package for creating and manipulating graphs and networks" optional = false python-versions = ">=3.9" +groups = ["main", "dev"] files = [ {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, @@ -3419,6 +3551,7 @@ version = "1.8.0" description = "Node.js virtual environment builder" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +groups = ["dev"] files = [ {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = 
"sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, @@ -3433,6 +3566,7 @@ version = "7.2.2" description = "Jupyter Notebook - A web-based notebook environment for interactive computing" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "notebook-7.2.2-py3-none-any.whl", hash = "sha256:c89264081f671bc02eec0ed470a627ed791b9156cad9285226b31611d3e9fe1c"}, {file = "notebook-7.2.2.tar.gz", hash = "sha256:2ef07d4220421623ad3fe88118d687bc0450055570cdd160814a59cf3a1c516e"}, @@ -3448,7 +3582,7 @@ tornado = ">=6.2.0" [package.extras] dev = ["hatch", "pre-commit"] docs = ["myst-parser", "nbsphinx", "pydata-sphinx-theme", "sphinx (>=1.3.6)", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] -test = ["importlib-resources (>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.27.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"] +test = ["importlib-resources (>=5.0) ; python_version < \"3.10\"", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.27.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"] [[package]] name = "notebook-shim" @@ -3456,6 +3590,7 @@ version = "0.2.3" description = "A shim layer for notebook traits and config" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "notebook_shim-0.2.3-py3-none-any.whl", hash = "sha256:a83496a43341c1674b093bfcebf0fe8e74cbe7eda5fd2bbc56f8e39e1486c0c7"}, {file = "notebook_shim-0.2.3.tar.gz", hash = "sha256:f69388ac283ae008cd506dda10d0288b09a017d822d5e8c7129a152cbd3ce7e9"}, @@ -3473,6 +3608,7 @@ version = "0.60.0" description = "compiling Python code using LLVM" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"}, {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"}, @@ -3507,6 +3643,7 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -3544,6 +3681,8 @@ version = "12.1.3.1" description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, @@ -3555,6 +3694,8 @@ version = "12.1.105" description = "CUDA profiling tools runtime libs." 
optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, @@ -3566,6 +3707,8 @@ version = "12.1.105" description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, @@ -3577,6 +3720,8 @@ version = "12.1.105" description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, @@ -3588,6 +3733,8 @@ version = "8.9.2.26" description = "cuDNN runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, ] @@ -3601,6 +3748,8 @@ version = "11.0.2.54" description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, @@ -3612,6 +3761,8 @@ version = "10.3.2.106" description = "CURAND native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, @@ -3623,6 +3774,8 @@ version = "11.4.5.107" description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, 
@@ -3639,6 +3792,8 @@ version = "12.1.0.106" description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, @@ -3653,6 +3808,8 @@ version = "2.18.1" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, ] @@ -3663,6 +3820,8 @@ version = "12.3.101" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"}, @@ -3675,6 +3834,8 @@ version = "12.1.105" description = "NVIDIA Tools Extension" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, @@ -3686,6 +3847,7 @@ version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, @@ -3702,6 +3864,8 @@ version = "4.8.1.78" description = "Wrapper package for OpenCV python bindings." optional = true python-versions = ">=3.6" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "opencv-python-4.8.1.78.tar.gz", hash = "sha256:cc7adbbcd1112877a39274106cb2752e04984bc01a031162952e97450d6117f6"}, {file = "opencv_python-4.8.1.78-cp37-abi3-macosx_10_16_x86_64.whl", hash = "sha256:91d5f6f5209dc2635d496f6b8ca6573ecdad051a09e6b5de4c399b8e673c60da"}, @@ -3721,6 +3885,8 @@ version = "0.10.0" description = "Optimized PyTree Utilities." 
optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "optree-0.10.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ac2c0fa383f504f03887a0c0ffcb6a4187c43c8c99c32f52ff14e7eae2c8c69b"}, {file = "optree-0.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8fa16b16203938b7a9caa4603998d0968b408f7f3a1a9f7f84763802daf1cff0"}, @@ -3781,6 +3947,7 @@ version = "4.1.0" description = "An OrderedSet is a custom MutableSet that remembers its order, so that every" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"}, {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"}, @@ -3795,6 +3962,7 @@ version = "7.4.0" description = "A decorator to automatically detect mismatch when overriding a method." optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "overrides-7.4.0-py3-none-any.whl", hash = "sha256:3ad24583f86d6d7a49049695efe9933e67ba62f0c7625d53c59fa832ce4b8b7d"}, {file = "overrides-7.4.0.tar.gz", hash = "sha256:9502a3cca51f4fac40b5feca985b6703a5c1f6ad815588a7ca9e285b9dca6757"}, @@ -3806,6 +3974,7 @@ version = "23.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, @@ -3817,6 +3986,7 @@ version = "2.1.0" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "pandas-2.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:40dd20439ff94f1b2ed55b393ecee9cb6f3b08104c2c40b0cb7186a2f0046242"}, {file = "pandas-2.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d4f38e4fedeba580285eaac7ede4f686c6701a9e618d8a857b138a126d067f2f"}, @@ -3875,6 +4045,7 @@ version = "1.5.0" description = "Utilities for writing pandoc filters in python" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = "sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"}, {file = "pandocfilters-1.5.0.tar.gz", hash = "sha256:0b679503337d233b4339a817bfc8c50064e2eff681314376a47cb582305a7a38"}, @@ -3886,6 +4057,7 @@ version = "0.8.3" description = "A Python Parser" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, @@ -3901,6 +4073,7 @@ version = "0.2.1" description = "Bring colors to your terminal." 
optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364"}, {file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"}, @@ -3912,6 +4085,7 @@ version = "0.11.2" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, @@ -3923,6 +4097,7 @@ version = "0.1.2" description = "File system general utilities" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pathtools-0.1.2.tar.gz", hash = "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"}, ] @@ -3933,6 +4108,8 @@ version = "0.5.6" description = "A Python package for describing statistical models and for building design matrices." optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "patsy-0.5.6-py2.py3-none-any.whl", hash = "sha256:19056886fd8fa71863fa32f0eb090267f21fb74be00f19f5c70b2e9d76c883c6"}, {file = "patsy-0.5.6.tar.gz", hash = "sha256:95c6d47a7222535f84bff7f63d7303f2e297747a598db89cf5c67f0c0c7d2cdb"}, @@ -3951,6 +4128,7 @@ version = "1.24.2" description = "Gymnasium for multi-agent reinforcement learning." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pettingzoo-1.24.2-py3-none-any.whl", hash = "sha256:00268cf990d243654c2bbbbf8c88322c12b041dc0a879b74747f14ee8aa93dd6"}, {file = "pettingzoo-1.24.2.tar.gz", hash = "sha256:0a5856d47de78ab20feddfdac4940959dc892f6becc92107247b1c3a210c0984"}, @@ -3976,6 +4154,8 @@ version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." optional = false python-versions = "*" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, @@ -3990,6 +4170,7 @@ version = "10.2.0" description = "Python Imaging Library (Fork)" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pillow-10.2.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:7823bdd049099efa16e4246bdf15e5a13dbb18a51b68fa06d6c1d4d8b99a796e"}, {file = "pillow-10.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:83b2021f2ade7d1ed556bc50a399127d7fb245e725aa0113ebd05cfe88aaf588"}, @@ -4066,7 +4247,7 @@ docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline fpx = ["olefile"] mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] -typing = ["typing-extensions"] +typing = ["typing-extensions ; python_version < \"3.10\""] xmp = ["defusedxml"] [[package]] @@ -4075,6 +4256,8 @@ version = "2.6.2" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "sys_platform == \"win32\"" files = [ {file = "platformdirs-2.6.2-py3-none-any.whl", hash = "sha256:83c8f6d04389165de7c9b6f0c682439697887bca0aa2f1c87ef1826be3584490"}, {file = "platformdirs-2.6.2.tar.gz", hash = "sha256:e1fea1fe471b9ff8332e229df3cb7de4f53eeea4998d3b6bfff542115e998bd2"}, @@ -4090,6 +4273,8 @@ version = "3.11.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"}, {file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"}, @@ -4105,6 +4290,7 @@ version = "1.3.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, @@ -4120,6 +4306,7 @@ version = "0.20.0" description = "A task runner that works well with poetry." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "poethepoet-0.20.0-py3-none-any.whl", hash = "sha256:cb37be15f3895ccc65ddf188c2e3d8fb79e26cc9d469a6098cb1c6f994659f6f"}, {file = "poethepoet-0.20.0.tar.gz", hash = "sha256:ca5a2a955f52dfb0a53fad3c989ef0b69ce3d5ec0f6bfa9b1da1f9e32d262e20"}, @@ -4138,6 +4325,7 @@ version = "3.5.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pre_commit-3.5.0-py2.py3-none-any.whl", hash = "sha256:841dc9aef25daba9a0238cd27984041fa0467b4199fc4852e27950664919f660"}, {file = "pre_commit-3.5.0.tar.gz", hash = "sha256:5804465c675b659b0862f07907f96295d490822a450c4c40e747d0b1c6ebcb32"}, @@ -4156,6 +4344,7 @@ version = "0.18.0" description = "Python client for the Prometheus monitoring system." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "prometheus_client-0.18.0-py3-none-any.whl", hash = "sha256:8de3ae2755f890826f4b6479e5571d4f74ac17a81345fe69a6778fdb92579184"}, {file = "prometheus_client-0.18.0.tar.gz", hash = "sha256:35f7a8c22139e2bb7ca5a698e92d38145bc8dc74c1c0bf56f25cca886a764e17"}, @@ -4170,6 +4359,7 @@ version = "2.3" description = "Promises/A+ implementation for Python" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "promise-2.3.tar.gz", hash = "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0"}, ] @@ -4186,6 +4376,7 @@ version = "3.0.41" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" +groups = ["dev"] files = [ {file = "prompt_toolkit-3.0.41-py3-none-any.whl", hash = "sha256:f36fe301fafb7470e86aaf90f036eef600a3210be4decf461a5b1ca8403d3cb2"}, {file = "prompt_toolkit-3.0.41.tar.gz", hash = "sha256:941367d97fc815548822aa26c2a269fdc4eb21e9ec05fc5d447cf09bad5d75f0"}, @@ -4200,6 +4391,8 @@ version = "1.6.4" description = "A decorator for caching properties in classes (forked from cached-property)." 
optional = true python-versions = ">= 3.5" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "property-cached-1.6.4.zip", hash = "sha256:3e9c4ef1ed3653909147510481d7df62a3cfb483461a6986a6f1dcd09b2ebb73"}, {file = "property_cached-1.6.4-py2.py3-none-any.whl", hash = "sha256:135fc059ec969c1646424a0db15e7fbe1b5f8c36c0006d0b3c91ba568c11e7d8"}, @@ -4211,6 +4404,7 @@ version = "3.20.3" description = "Protocol Buffers" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "protobuf-3.20.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f4bd856d702e5b0d96a00ec6b307b0f51c1982c2bf9c0052cf9019e9a544ba99"}, {file = "protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9aae4406ea63d825636cc11ffb34ad3379335803216ee3a856787bcf5ccc751e"}, @@ -4242,6 +4436,7 @@ version = "5.9.6" description = "Cross-platform lib for process and system monitoring in Python." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +groups = ["dev"] files = [ {file = "psutil-5.9.6-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d"}, {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c"}, @@ -4262,7 +4457,7 @@ files = [ ] [package.extras] -test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] +test = ["enum34 ; python_version <= \"3.4\"", "ipaddress ; python_version < \"3.0\"", "mock ; python_version < \"3.0\"", "pywin32 ; sys_platform == \"win32\"", "wmi ; sys_platform == \"win32\""] [[package]] name = "ptyprocess" @@ -4270,6 +4465,8 @@ version = "0.7.0" description = "Run a subprocess in a pseudo terminal" optional = false python-versions = "*" +groups = ["dev"] +markers = "os_name != \"nt\" or sys_platform != \"win32\"" files = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, @@ -4281,6 +4478,7 @@ version = "0.2.2" description = "Safely evaluate AST nodes without side effects" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"}, @@ -4295,6 +4493,7 @@ version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["main"] files = [ {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"}, {file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"}, @@ -4306,6 +4505,7 @@ version = "0.3.0" description = "A collection of ASN.1-based protocols modules" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["main"] files = [ {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"}, {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"}, @@ 
-4320,6 +4520,7 @@ version = "0.24.0" description = "A BibTeX-compatible bibliography processor in Python" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*" +groups = ["dev"] files = [ {file = "pybtex-0.24.0-py2.py3-none-any.whl", hash = "sha256:e1e0c8c69998452fea90e9179aa2a98ab103f3eed894405b7264e517cc2fcc0f"}, {file = "pybtex-0.24.0.tar.gz", hash = "sha256:818eae35b61733e5c007c3fcd2cfb75ed1bc8b4173c1f70b56cc4c0802d34755"}, @@ -4339,6 +4540,7 @@ version = "1.0.3" description = "A docutils backend for pybtex." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pybtex-docutils-1.0.3.tar.gz", hash = "sha256:3a7ebdf92b593e00e8c1c538aa9a20bca5d92d84231124715acc964d51d93c6b"}, {file = "pybtex_docutils-1.0.3-py3-none-any.whl", hash = "sha256:8fd290d2ae48e32fcb54d86b0efb8d573198653c7e2447d5bec5847095f430b9"}, @@ -4354,6 +4556,8 @@ version = "3.2.5" description = "Official Python Interface for the Bullet Physics SDK specialized for Robotics Simulation and Reinforcement Learning" optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"pybullet\"" files = [ {file = "pybullet-3.2.5-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:4970aec0dd968924f6b1820655a20f80650da2f85ba38b641937c9701a8a2b14"}, {file = "pybullet-3.2.5-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b64e4523a11d03729035e0a5baa0ce4d2ca58de8d0a242c0b91e8253781b24c4"}, @@ -4371,6 +4575,7 @@ version = "2.11.1" description = "Python style guide checker" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"}, {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, @@ -4382,6 +4587,7 @@ version = "2.21" description = "C parser in Python" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, @@ -4393,6 +4599,7 @@ version = "0.14.3" description = "Bootstrap-based Sphinx theme from the PyData community" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pydata_sphinx_theme-0.14.3-py3-none-any.whl", hash = "sha256:b7e40cd75a20449adfe2d7525be379b9fe92f6d31e5233e449fa34ddcd4398d9"}, {file = "pydata_sphinx_theme-0.14.3.tar.gz", hash = "sha256:bd474f347895f3fc5b6ce87390af64330ee54f11ebf9660d5bc3f87d532d4e5c"}, @@ -4420,6 +4627,7 @@ version = "3.2.2" description = "Python bindings for the Enchant spellchecking system" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "pyenchant-3.2.2-py3-none-any.whl", hash = "sha256:5facc821ece957208a81423af7d6ec7810dad29697cb0d77aae81e4e11c8e5a6"}, {file = "pyenchant-3.2.2-py3-none-win32.whl", hash = "sha256:5a636832987eaf26efe971968f4d1b78e81f62bca2bde0a9da210c7de43c3bce"}, @@ -4433,6 +4641,7 @@ version = "2.5.2" description = "Python Game Development" optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "pygame-2.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a0769eb628c818761755eb0a0ca8216b95270ea8cbcbc82227e39ac9644643da"}, {file = "pygame-2.5.2-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:ed9a3d98adafa0805ccbaaff5d2996a2b5795381285d8437a4a5d248dbd12b4a"}, @@ -4499,13 +4708,14 @@ version = "2.17.1" description = "Pygments is a syntax highlighting package written in Python." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pygments-2.17.1-py3-none-any.whl", hash = "sha256:1b37f1b1e1bff2af52ecaf28cc601e2ef7077000b227a0675da25aef85784bc4"}, {file = "pygments-2.17.1.tar.gz", hash = "sha256:e45a0e74bf9c530f564ca81b8952343be986a29f6afe7f5ad95c5f06b7bdf5e8"}, ] [package.extras] -plugins = ["importlib-metadata"] +plugins = ["importlib-metadata ; python_version < \"3.8\""] windows-terminal = ["colorama (>=0.4.6)"] [[package]] @@ -4514,6 +4724,7 @@ version = "6.6.0" description = "Pymunk is a easy-to-use pythonic 2d physics library" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pymunk-6.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6da50dd97683337a290110d594fad07a75153d2d837b570ef972478d739c33f8"}, {file = "pymunk-6.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bcd7d16a2b4d51d45d6780a701f65c8d5b36fdf545c3f4738910da41e2a9c4ee"}, @@ -4585,6 +4796,8 @@ version = "3.1.7" description = "Standard OpenGL bindings for Python" optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "PyOpenGL-3.1.7-py3-none-any.whl", hash = "sha256:a6ab19cf290df6101aaf7470843a9c46207789855746399d0af92521a0a92b7a"}, {file = "PyOpenGL-3.1.7.tar.gz", hash = "sha256:eef31a3888e6984fd4d8e6c9961b184c9813ca82604d37fe3da80eb000a76c86"}, @@ -4596,6 +4809,7 @@ version = "3.1.2" description = "pyparsing module - Classes and methods to define and execute parsing grammars" optional = false python-versions = ">=3.6.8" +groups = ["main"] files = [ {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, @@ -4610,6 +4824,7 @@ version = "7.4.3" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, @@ -4630,6 +4845,7 @@ version = "4.1.0" description = "Pytest plugin for measuring coverage." 
optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, @@ -4648,6 +4864,7 @@ version = "2.8.2" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main", "dev"] files = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, @@ -4662,6 +4879,7 @@ version = "2.0.7" description = "A python library adding a json log formatter" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "python-json-logger-2.0.7.tar.gz", hash = "sha256:23e7ec02d34237c5aa1e29a070193a4ea87583bb4e7f8fd06d3de8264c4b2e1c"}, {file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"}, @@ -4673,6 +4891,7 @@ version = "2024.1" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, @@ -4684,6 +4903,8 @@ version = "306" description = "Python for Window Extensions" optional = false python-versions = "*" +groups = ["dev"] +markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\"" files = [ {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, @@ -4707,6 +4928,8 @@ version = "2.0.12" description = "Pseudo terminal support for Windows from Python." 
optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "os_name == \"nt\"" files = [ {file = "pywinpty-2.0.12-cp310-none-win_amd64.whl", hash = "sha256:21319cd1d7c8844fb2c970fb3a55a3db5543f112ff9cfcd623746b9c47501575"}, {file = "pywinpty-2.0.12-cp311-none-win_amd64.whl", hash = "sha256:853985a8f48f4731a716653170cd735da36ffbdc79dcb4c7b7140bce11d8c722"}, @@ -4722,6 +4945,7 @@ version = "6.0.1" description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, @@ -4775,6 +4999,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +markers = {main = "extra == \"argparse\" or extra == \"eval\""} [[package]] name = "pyzmq" @@ -4782,6 +5007,7 @@ version = "25.1.1" description = "Python bindings for 0MQ" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:381469297409c5adf9a0e884c5eb5186ed33137badcbbb0560b86e910a2f1e76"}, {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:955215ed0604dac5b01907424dfa28b40f2b2292d6493445dd34d0dfa72586a8"}, @@ -4887,6 +5113,7 @@ version = "5.5.1" description = "Jupyter Qt console" optional = false python-versions = ">= 3.8" +groups = ["dev"] files = [ {file = "qtconsole-5.5.1-py3-none-any.whl", hash = "sha256:8c75fa3e9b4ed884880ff7cea90a1b67451219279ec33deaee1d59e3df1a5d2b"}, {file = "qtconsole-5.5.1.tar.gz", hash = "sha256:a0e806c6951db9490628e4df80caec9669b65149c7ba40f9bf033c025a5b56bc"}, @@ -4912,6 +5139,7 @@ version = "2.4.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "QtPy-2.4.1-py3-none-any.whl", hash = "sha256:1c1d8c4fa2c884ae742b069151b0abe15b3f70491f3972698c683b8e38de839b"}, {file = "QtPy-2.4.1.tar.gz", hash = "sha256:a5a15ffd519550a1361bdc56ffc07fda56a6af7292f17c7b395d4083af632987"}, @@ -4929,6 +5157,8 @@ version = "2.8.0" description = "Ray provides a simple, universal API for building distributed applications." 
optional = false python-versions = "*" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "ray-2.8.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:34e0676a0dfa277efa688bccd83ecb7a799bc03078e5b1f1aa747fe9263175a8"}, {file = "ray-2.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:72c696c1b784c55f0ad107d55bb58ecef5d368176765cf44fed87e714538d708"}, @@ -4966,16 +5196,16 @@ pyyaml = "*" requests = "*" [package.extras] -air = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -all = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "dm-tree", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (!=1.56.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "gymnasium (==0.28.1)", "lz4", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml", "ray-cpp (==2.8.0)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +air = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +all = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "dm-tree", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (!=1.56.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "gymnasium (==0.28.1)", "lz4", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml", "ray-cpp (==2.8.0)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] client = ["grpcio (!=1.56.0)"] cpp = ["ray-cpp (==2.8.0)"] data = ["fsspec", "numpy (>=1.20)", "pandas (>=1.3)", "pyarrow (>=6.0.1)"] -default = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "virtualenv (>=20.0.24,<20.21.1)"] +default = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "virtualenv (>=20.0.24,<20.21.1)"] observability = ["opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk"] rllib = ["dm-tree", 
"fsspec", "gymnasium (==0.28.1)", "lz4", "pandas", "pyarrow (>=6.0.1)", "pyyaml", "requests", "rich", "scikit-image", "scipy", "tensorboardX (>=1.9)", "typer"] -serve = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +serve = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] train = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] tune = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] @@ -4985,6 +5215,7 @@ version = "0.31.0" description = "JSON Referencing + Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "referencing-0.31.0-py3-none-any.whl", hash = "sha256:381b11e53dd93babb55696c71cf42aef2d36b8a150c49bf0bc301e36d536c882"}, {file = "referencing-0.31.0.tar.gz", hash = "sha256:cc28f2c88fbe7b961a7817a0abc034c09a1e36358f82fedb4ffdf29a25398863"}, @@ -5000,6 +5231,7 @@ version = "2.32.0" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "requests-2.32.0-py3-none-any.whl", hash = "sha256:f2c3881dddb70d056c5bd7600a4fae312b2a300e39be6a118d30b90bd27262b5"}, {file = "requests-2.32.0.tar.gz", hash = "sha256:fa5490319474c82ef1d2c9bc459d3652e3ae4ef4c4ebdd18a21145a47ca4b6b8"}, @@ -5021,6 +5253,7 @@ version = "1.3.1" description = "OAuthlib authentication support for Requests." 
optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["main"] files = [ {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, @@ -5039,6 +5272,7 @@ version = "0.1.4" description = "A pure python RFC3339 validator" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "rfc3339_validator-0.1.4-py2.py3-none-any.whl", hash = "sha256:24f6ec1eda14ef823da9e36ec7113124b39c04d50a4d3d3a3c2859577e7791fa"}, {file = "rfc3339_validator-0.1.4.tar.gz", hash = "sha256:138a2abdf93304ad60530167e51d2dfb9549521a836871b88d7f4695d0022f6b"}, @@ -5053,6 +5287,7 @@ version = "0.1.1" description = "Pure python rfc3986 validator" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9"}, {file = "rfc3986_validator-0.1.1.tar.gz", hash = "sha256:3d44bde7921b3b9ec3ae4e3adca370438eccebc676456449b145d533b240d055"}, @@ -5064,6 +5299,8 @@ version = "1.2.0" description = "rliable: Reliable evaluation on reinforcement learning and machine learning benchmarks." optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "rliable-1.2.0.tar.gz", hash = "sha256:72789d9147d7c56e6efa812f9dffedcef44993a866ec08d75506ac7c1fe69cd5"}, ] @@ -5082,6 +5319,7 @@ version = "0.13.0" description = "Python bindings to Rust's persistent data structures (rpds)" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "rpds_py-0.13.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:1758197cc8d7ff383c07405f188253535b4aa7fa745cbc54d221ae84b18e0702"}, {file = "rpds_py-0.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:715df74cbcef4387d623c917f295352127f4b3e0388038d68fa577b4e4c6e540"}, @@ -5190,6 +5428,7 @@ version = "4.9" description = "Pure-Python RSA implementation" optional = false python-versions = ">=3.6,<4" +groups = ["main"] files = [ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, @@ -5204,6 +5443,7 @@ version = "0.18.5" description = "ruamel.yaml is a YAML parser/emitter that supports roundtrip preservation of comments, seq/map flow style, and map key order" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ruamel.yaml-0.18.5-py3-none-any.whl", hash = "sha256:a013ac02f99a69cdd6277d9664689eb1acba07069f912823177c5eced21a6ada"}, {file = "ruamel.yaml-0.18.5.tar.gz", hash = "sha256:61917e3a35a569c1133a8f772e1226961bf5a1198bea7e23f06a0841dea1ab0e"}, @@ -5222,6 +5462,8 @@ version = "0.2.8" description = "C version of reader, parser and emitter for ruamel.yaml derived from libyaml" optional = false python-versions = ">=3.6" +groups = ["dev"] +markers = "platform_python_implementation == \"CPython\" and python_version < \"3.13\"" files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = 
"sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, @@ -5281,6 +5523,7 @@ version = "0.0.285" description = "An extremely fast Python linter, written in Rust." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ruff-0.0.285-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:72a3a0936369b986b0e959f9090206ed3c18f9e5e439ea5b8e6867c6707aded5"}, {file = "ruff-0.0.285-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0d9ab6ad16742eb78919e0fba09f914f042409df40ad63423c34bb20d350162a"}, @@ -5307,6 +5550,7 @@ version = "1.11.4" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.9" +groups = ["main", "dev"] files = [ {file = "scipy-1.11.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc9a714581f561af0848e6b69947fda0614915f072dfd14142ed1bfe1b806710"}, {file = "scipy-1.11.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:cf00bd2b1b0211888d4dc75656c0412213a8b25e80d73898083f402b50f47e41"}, @@ -5349,6 +5593,8 @@ version = "0.13.2" description = "Statistical data visualization" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"}, {file = "seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"}, @@ -5370,15 +5616,16 @@ version = "1.8.2" description = "Send file to trash natively under Mac OS X, Windows and Linux" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" +groups = ["dev"] files = [ {file = "Send2Trash-1.8.2-py3-none-any.whl", hash = "sha256:a384719d99c07ce1eefd6905d2decb6f8b7ed054025bb0e618919f945de4f679"}, {file = "Send2Trash-1.8.2.tar.gz", hash = "sha256:c132d59fa44b9ca2b1699af5c86f57ce9f4c5eb56629d5d55fbb7a35f84e2312"}, ] [package.extras] -nativelib = ["pyobjc-framework-Cocoa", "pywin32"] -objc = ["pyobjc-framework-Cocoa"] -win32 = ["pywin32"] +nativelib = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\"", "pywin32 ; sys_platform == \"win32\""] +objc = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\""] +win32 = ["pywin32 ; sys_platform == \"win32\""] [[package]] name = "sensai-utils" @@ -5386,6 +5633,7 @@ version = "1.4.0" description = "Utilities from sensAI, the Python library for sensible AI" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "sensai_utils-1.4.0-py3-none-any.whl", hash = "sha256:ed6fc57552620e43b33cf364ea0bc0fd7df39391069dd7b621b113ef55547507"}, {file = "sensai_utils-1.4.0.tar.gz", hash = "sha256:2d32bdcc91fd1428c5cae0181e98623142d2d5f7e115e23d585a842dd9dc59ba"}, @@ -5400,6 +5648,7 @@ version = "2.8.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "sentry_sdk-2.8.0-py2.py3-none-any.whl", hash = "sha256:6051562d2cfa8087bb8b4b8b79dc44690f8a054762a29c07e22588b1f619bfb5"}, {file = "sentry_sdk-2.8.0.tar.gz", hash = "sha256:aa4314f877d9cd9add5a0c9ba18e3f27f99f7de835ce36bd150e48a41c7c646f"}, @@ -5450,6 +5699,7 @@ version = "1.3.3" description = "A Python module to customize the process title" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:897a73208da48db41e687225f355ce993167079eda1260ba5e13c4e53be7f754"}, {file = 
"setproctitle-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c331e91a14ba4076f88c29c777ad6b58639530ed5b24b5564b5ed2fd7a95452"}, @@ -5550,6 +5800,7 @@ version = "68.2.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "setuptools-68.2.2-py3-none-any.whl", hash = "sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a"}, {file = "setuptools-68.2.2.tar.gz", hash = "sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87"}, @@ -5557,7 +5808,7 @@ files = [ [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov ; platform_python_implementation != \"PyPy\"", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-ruff ; sys_platform != \"cygwin\"", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] @@ -5566,6 +5817,8 @@ version = "0.2.1" description = "API for converting popular non-gymnasium environments to a gymnasium compatible environment." optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "Shimmy-0.2.1-py3-none-any.whl", hash = "sha256:2d7d21c4ca679a64bb452e6a4232c6b0f5dba7589f5420454ddc1f0634334334"}, {file = "Shimmy-0.2.1.tar.gz", hash = "sha256:7b96915445ee5488dcb19ccf52ce5581d6f00cc5cf0e0dff36b16cd65bffcb75"}, @@ -5590,6 +5843,7 @@ version = "1.0.11" description = "A generator library for concise, unambiguous and URL-safe UUIDs." 
optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "shortuuid-1.0.11-py3-none-any.whl", hash = "sha256:27ea8f28b1bd0bf8f15057a3ece57275d2059d2b0bb02854f02189962c13b6aa"}, {file = "shortuuid-1.0.11.tar.gz", hash = "sha256:fc75f2615914815a8e4cb1501b3a513745cb66ef0fd5fc6fb9f8c3fa3481f789"}, @@ -5601,6 +5855,7 @@ version = "1.16.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +groups = ["main", "dev"] files = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, @@ -5612,6 +5867,7 @@ version = "5.0.1" description = "A pure Python implementation of a sliding window memory map manager" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"}, {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, @@ -5623,6 +5879,7 @@ version = "1.3.0" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, @@ -5634,6 +5891,7 @@ version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"}, {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, @@ -5645,6 +5903,7 @@ version = "2.5" description = "A modern CSS selector implementation for Beautiful Soup." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, @@ -5656,6 +5915,7 @@ version = "7.2.6" description = "Python documentation generator" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx-7.2.6-py3-none-any.whl", hash = "sha256:1e09160a40b956dc623c910118fa636da93bd3ca0b9876a7b3df90f07d691560"}, {file = "sphinx-7.2.6.tar.gz", hash = "sha256:9a5160e1ea90688d5963ba09a2dcd8bdd526620edbb65c328728f1b2228d5ab5"}, @@ -5690,6 +5950,7 @@ version = "1.19.1" description = "Type hints (PEP 484) support for the Sphinx autodoc extension" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinx_autodoc_typehints-1.19.1-py3-none-any.whl", hash = "sha256:9be46aeeb1b315eb5df1f3a7cb262149895d16c7d7dcd77b92513c3c3a1e85e6"}, {file = "sphinx_autodoc_typehints-1.19.1.tar.gz", hash = "sha256:6c841db55e0e9be0483ff3962a2152b60e79306f4288d8c4e7e86ac84486a5ea"}, @@ -5700,7 +5961,7 @@ Sphinx = ">=4.5" [package.extras] testing = ["covdefaults (>=2.2)", "coverage (>=6.3)", "diff-cover (>=6.4)", "nptyping (>=2.1.2)", "pytest (>=7.1)", "pytest-cov (>=3)", "sphobjinv (>=2)", "typing-extensions (>=4.1)"] -type-comments = ["typed-ast (>=1.5.2)"] +type-comments = ["typed-ast (>=1.5.2) ; python_version < \"3.8\""] [[package]] name = "sphinx-book-theme" @@ -5708,6 +5969,7 @@ version = "1.1.0" description = "A clean book theme for scientific explanations and documentation with Sphinx" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx_book_theme-1.1.0-py3-none-any.whl", hash = "sha256:088bc69d65fab8446adb8691ed61687f71bf7504c9740af68bc78cf936a26112"}, {file = "sphinx_book_theme-1.1.0.tar.gz", hash = "sha256:ad4f92998e53e24751ecd0978d3eb79fdaa59692f005b1b286ecdd6146ebc9c1"}, @@ -5728,6 +5990,7 @@ version = "0.0.3" description = "Add comments and annotation to your documentation." optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx-comments-0.0.3.tar.gz", hash = "sha256:00170afff27019fad08e421da1ae49c681831fb2759786f07c826e89ac94cf21"}, {file = "sphinx_comments-0.0.3-py3-none-any.whl", hash = "sha256:1e879b4e9bfa641467f83e3441ac4629225fc57c29995177d043252530c21d00"}, @@ -5747,6 +6010,7 @@ version = "0.5.2" description = "Add a copy button to each of your code cells." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinx-copybutton-0.5.2.tar.gz", hash = "sha256:4cf17c82fb9646d1bc9ca92ac280813a3b605d8c421225fd9913154103ee1fbd"}, {file = "sphinx_copybutton-0.5.2-py3-none-any.whl", hash = "sha256:fb543fd386d917746c9a2c50360c7905b605726b9355cd26e9974857afeae06e"}, @@ -5765,6 +6029,7 @@ version = "0.5.0" description = "A sphinx extension for designing beautiful, view size responsive web components." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "sphinx_design-0.5.0-py3-none-any.whl", hash = "sha256:1af1267b4cea2eedd6724614f19dcc88fe2e15aff65d06b2f6252cee9c4f4c1e"}, {file = "sphinx_design-0.5.0.tar.gz", hash = "sha256:e8e513acea6f92d15c6de3b34e954458f245b8e761b45b63950f65373352ab00"}, @@ -5788,6 +6053,7 @@ version = "1.0.1" description = "A sphinx extension that allows the site-map to be defined in a single YAML file." 
optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx_external_toc-1.0.1-py3-none-any.whl", hash = "sha256:d9e02d50731dee9697c1887e4f8b361e7b86d38241f0e66bd5a9f4096779646f"}, {file = "sphinx_external_toc-1.0.1.tar.gz", hash = "sha256:a7d2c63cc47ec688546443b28bc4ef466121827ef3dc7bb509de354bad4ea2e0"}, @@ -5809,6 +6075,7 @@ version = "0.2.0.post1" description = "Patches Jinja2 v3 to restore compatibility with earlier Sphinx versions." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "sphinx_jinja2_compat-0.2.0.post1-py3-none-any.whl", hash = "sha256:f9d329174bdde8db19dc12c62528367196eb2f6b46c91754eca604acd0c0f6ad"}, {file = "sphinx_jinja2_compat-0.2.0.post1.tar.gz", hash = "sha256:974289a12a9f402108dead621e9c15f7004e945d5cfcaea8d6419e94d3fa95a3"}, @@ -5824,6 +6091,7 @@ version = "1.0.0" description = "Latex specific features for jupyter book" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx_jupyterbook_latex-1.0.0-py3-none-any.whl", hash = "sha256:e0cd3e9e1c5af69136434e21a533343fdf013475c410a414d5b7b4922b4f3891"}, {file = "sphinx_jupyterbook_latex-1.0.0.tar.gz", hash = "sha256:f54c6674c13f1616f9a93443e98b9b5353f9fdda8e39b6ec552ccf0b3e5ffb62"}, @@ -5845,6 +6113,7 @@ version = "0.1.3" description = "Supporting continuous HTML section numbering" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx-multitoc-numbering-0.1.3.tar.gz", hash = "sha256:c9607671ac511236fa5d61a7491c1031e700e8d498c9d2418e6c61d1251209ae"}, {file = "sphinx_multitoc_numbering-0.1.3-py3-none-any.whl", hash = "sha256:33d2e707a9b2b8ad636b3d4302e658a008025106fe0474046c651144c26d8514"}, @@ -5864,6 +6133,7 @@ version = "1.5.0" description = "Sphinx directive to add unselectable prompt" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx_prompt-1.5.0-py3-none-any.whl", hash = "sha256:fa4e90d8088b5a996c76087d701fc7e31175f8b9dc4aab03a507e45051067162"}, ] @@ -5878,6 +6148,7 @@ version = "3.4.5" description = "Tabbed views for Sphinx" optional = false python-versions = "~=3.7" +groups = ["dev"] files = [ {file = "sphinx-tabs-3.4.5.tar.gz", hash = "sha256:ba9d0c1e3e37aaadd4b5678449eb08176770e0fc227e769b6ce747df3ceea531"}, {file = "sphinx_tabs-3.4.5-py3-none-any.whl", hash = "sha256:92cc9473e2ecf1828ca3f6617d0efc0aa8acb06b08c56ba29d1413f2f0f6cf09"}, @@ -5898,6 +6169,7 @@ version = "0.3.1" description = "Integrate interactive code blocks into your documentation with Thebe and Binder." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "sphinx_thebe-0.3.1-py3-none-any.whl", hash = "sha256:e7e7edee9f0d601c76bc70156c471e114939484b111dd8e74fe47ac88baffc52"}, {file = "sphinx_thebe-0.3.1.tar.gz", hash = "sha256:576047f45560e82f64aa5f15200b1eb094dcfe1c5b8f531a8a65bd208e25a493"}, @@ -5917,6 +6189,7 @@ version = "0.3.2" description = "Toggle page content and collapse admonitions in Sphinx." 
optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx-togglebutton-0.3.2.tar.gz", hash = "sha256:ab0c8b366427b01e4c89802d5d078472c427fa6e9d12d521c34fa0442559dc7a"}, {file = "sphinx_togglebutton-0.3.2-py3-none-any.whl", hash = "sha256:9647ba7874b7d1e2d43413d8497153a85edc6ac95a3fea9a75ef9c1e08aaae2b"}, @@ -5937,6 +6210,7 @@ version = "3.5.0" description = "Box of handy tools for Sphinx 🧰 📔" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinx_toolbox-3.5.0-py3-none-any.whl", hash = "sha256:20dfd3566717db6f2da7a400a54dc4b946f064fb31250fa44802d54cfb9b8a03"}, {file = "sphinx_toolbox-3.5.0.tar.gz", hash = "sha256:e5b5a7153f1997572d71a06aaf6cec225483492ec2c60097a84f15aad6df18b7"}, @@ -5971,6 +6245,7 @@ version = "1.0.7" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_applehelp-1.0.7-py3-none-any.whl", hash = "sha256:094c4d56209d1734e7d252f6e0b3ccc090bd52ee56807a5d9315b19c122ab15d"}, {file = "sphinxcontrib_applehelp-1.0.7.tar.gz", hash = "sha256:39fdc8d762d33b01a7d8f026a3b7d71563ea3b72787d5f00ad8465bd9d6dfbfa"}, @@ -5989,6 +6264,7 @@ version = "2.5.0" description = "Sphinx extension for BibTeX style citations." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "sphinxcontrib-bibtex-2.5.0.tar.gz", hash = "sha256:71b42e5db0e2e284f243875326bf9936aa9a763282277d75048826fef5b00eaa"}, {file = "sphinxcontrib_bibtex-2.5.0-py3-none-any.whl", hash = "sha256:748f726eaca6efff7731012103417ef130ecdcc09501b4d0c54283bf5f059f76"}, @@ -6006,6 +6282,7 @@ version = "1.0.5" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp documents" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_devhelp-1.0.5-py3-none-any.whl", hash = "sha256:fe8009aed765188f08fcaadbb3ea0d90ce8ae2d76710b7e29ea7d047177dae2f"}, {file = "sphinxcontrib_devhelp-1.0.5.tar.gz", hash = "sha256:63b41e0d38207ca40ebbeabcf4d8e51f76c03e78cd61abe118cf4435c73d4212"}, @@ -6024,6 +6301,7 @@ version = "2.0.4" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_htmlhelp-2.0.4-py3-none-any.whl", hash = "sha256:8001661c077a73c29beaf4a79968d0726103c5605e27db92b9ebed8bab1359e9"}, {file = "sphinxcontrib_htmlhelp-2.0.4.tar.gz", hash = "sha256:6c26a118a05b76000738429b724a0568dbde5b72391a688577da08f11891092a"}, @@ -6042,6 +6320,7 @@ version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8"}, {file = "sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178"}, @@ -6056,6 +6335,7 @@ version = "1.0.6" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp documents" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_qthelp-1.0.6-py3-none-any.whl", hash = "sha256:bf76886ee7470b934e363da7a954ea2825650013d367728588732c7350f49ea4"}, {file = "sphinxcontrib_qthelp-1.0.6.tar.gz", hash = "sha256:62b9d1a186ab7f5ee3356d906f648cacb7a6bdb94d201ee7adf26db55092982d"}, @@ -6074,6 +6354,7 
@@ version = "1.1.9" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_serializinghtml-1.1.9-py3-none-any.whl", hash = "sha256:9b36e503703ff04f20e9675771df105e58aa029cfcbc23b8ed716019b7416ae1"}, {file = "sphinxcontrib_serializinghtml-1.1.9.tar.gz", hash = "sha256:0c64ff898339e1fac29abd2bf5f11078f3ec413cfe9c046d3120d7ca65530b54"}, @@ -6092,6 +6373,7 @@ version = "8.0.0" description = "Sphinx spelling extension" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinxcontrib-spelling-8.0.0.tar.gz", hash = "sha256:199d0a16902ad80c387c2966dc9eb10f565b1fb15ccce17210402db7c2443e5c"}, {file = "sphinxcontrib_spelling-8.0.0-py3-none-any.whl", hash = "sha256:b27e0a16aef00bcfc888a6490dc3f16651f901dc475446c6882834278c8dc7b3"}, @@ -6110,6 +6392,7 @@ version = "2.0.23" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"}, {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"}, @@ -6197,6 +6480,7 @@ version = "0.6.3" description = "Extract data from python stack frames and tracebacks for informative displays" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, @@ -6216,6 +6500,8 @@ version = "0.14.0" description = "Statistical computations and models for Python" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "statsmodels-0.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16bfe0c96a53b20fa19067e3b6bd2f1d39e30d4891ea0d7bc20734a0ae95942d"}, {file = "statsmodels-0.14.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5a6a0a1a06ff79be8aa89c8494b33903442859add133f0dda1daf37c3c71682e"}, @@ -6257,7 +6543,7 @@ scipy = ">=1.4,<1.9.2 || >1.9.2" [package.extras] build = ["cython (>=0.29.26)"] -develop = ["colorama", "cython (>=0.29.26)", "cython (>=0.29.28,<3.0.0)", "flake8", "isort", "joblib", "matplotlib (>=3)", "oldest-supported-numpy (>=2022.4.18)", "pytest (>=7.0.1,<7.1.0)", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=7.0.0,<7.1.0)"] +develop = ["colorama", "cython (>=0.29.26)", "cython (>=0.29.28,<3.0.0)", "flake8", "isort", "joblib", "matplotlib (>=3)", "oldest-supported-numpy (>=2022.4.18)", "pytest (>=7.0.1,<7.1.0)", "pytest-randomly", "pytest-xdist", "pywinpty ; os_name == \"nt\"", "setuptools-scm[toml] (>=7.0.0,<7.1.0)"] docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"] [[package]] @@ -6266,6 +6552,8 @@ version = "4.2.0" description = "SWIG is a software development tool that connects programs written in C and C++ with a variety of high-level programming languages." 
optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"box2d\"" files = [ {file = "swig-4.2.0-py2.py3-none-macosx_10_9_universal2.whl", hash = "sha256:71bf282fb30aa179b870e29c8f4fe16b3404e8562377061f85d57a2ec1571d7c"}, {file = "swig-4.2.0-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:071c7a3af61c2c69d1e911c5428479a4536a8103623276847d8e55350da8cf05"}, @@ -6291,6 +6579,7 @@ version = "1.12" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, @@ -6305,6 +6594,7 @@ version = "0.9.0" description = "Pretty-print tabular data" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, @@ -6319,6 +6609,7 @@ version = "2.15.1" description = "TensorBoard lets you watch Tensors Flow" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "tensorboard-2.15.1-py3-none-any.whl", hash = "sha256:c46c1d1cf13a458c429868a78b2531d8ff5f682058d69ec0840b0bc7a38f1c0f"}, ] @@ -6343,6 +6634,7 @@ version = "0.7.2" description = "Fast data loading for TensorBoard" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, @@ -6355,6 +6647,7 @@ version = "0.18.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "terminado-0.18.0-py3-none-any.whl", hash = "sha256:87b0d96642d0fe5f5abd7783857b9cab167f221a39ff98e3b9619a788a3c0f2e"}, {file = "terminado-0.18.0.tar.gz", hash = "sha256:1ea08a89b835dd1b8c0c900d92848147cef2537243361b2e3f4dc15df9b6fded"}, @@ -6376,6 +6669,7 @@ version = "1.2.1" description = "A tiny CSS parser" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "tinycss2-1.2.1-py3-none-any.whl", hash = "sha256:2b80a96d41e7c3914b8cda8bc7f705a4d9c49275616e886103dd839dfc847847"}, {file = "tinycss2-1.2.1.tar.gz", hash = "sha256:8cff3a8f066c2ec677c06dbc7b45619804a6938478d9d73c284b29d14ecb0627"}, @@ -6394,6 +6688,7 @@ version = "5.2.0" description = "A wrapper around the stdlib `tokenize` which roundtrips." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tokenize_rt-5.2.0-py2.py3-none-any.whl", hash = "sha256:b79d41a65cfec71285433511b50271b05da3584a1da144a0752e9c621a285289"}, {file = "tokenize_rt-5.2.0.tar.gz", hash = "sha256:9fe80f8a5c1edad2d3ede0f37481cc0cc1538a2f442c9c2f9e4feacd2792d054"}, @@ -6405,6 +6700,7 @@ version = "2.0.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, @@ -6416,6 +6712,7 @@ version = "2.1.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" +groups = ["main"] files = [ {file = "torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:5ebc43f5355a9b7be813392b3fb0133991f0380f6f0fcc8218d5468dc45d1071"}, {file = "torch-2.1.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:84fefd63356416c0cd20578637ccdbb82164993400ed17b57c951dd6376dcee8"}, @@ -6469,6 +6766,7 @@ version = "6.4.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1"}, {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803"}, @@ -6489,6 +6787,7 @@ version = "4.66.3" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "tqdm-4.66.3-py3-none-any.whl", hash = "sha256:4f41d54107ff9a223dca80b53efe4fb654c67efaba7f47bada3ee9d50e05bd53"}, {file = "tqdm-4.66.3.tar.gz", hash = "sha256:23097a41eba115ba99ecae40d06444c15d1c0c698d527a01c6c8bd1c5d0647e5"}, @@ -6509,6 +6808,7 @@ version = "5.13.0" description = "Traitlets Python configuration system" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "traitlets-5.13.0-py3-none-any.whl", hash = "sha256:baf991e61542da48fe8aef8b779a9ea0aa38d8a54166ee250d5af5ecf4486619"}, {file = "traitlets-5.13.0.tar.gz", hash = "sha256:9b232b9430c8f57288c1024b34a8f0251ddcc47268927367a0dd3eeaca40deb5"}, @@ -6524,6 +6824,8 @@ version = "2.1.0" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, @@ -6549,6 +6851,8 @@ version = "4.24.0.4" description = "Typing stubs for protobuf" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "types-protobuf-4.24.0.4.tar.gz", hash = "sha256:57ab42cb171dfdba2c74bb5b50c250478538cc3c5ed95b8b368929ad0c9f90a5"}, {file = "types_protobuf-4.24.0.4-py3-none-any.whl", hash = 
"sha256:131ab7d0cbc9e444bc89c994141327dcce7bcaeded72b1acb72a94827eb9c7af"}, @@ -6560,6 +6864,7 @@ version = "2.8.19.14" description = "Typing stubs for python-dateutil" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "types-python-dateutil-2.8.19.14.tar.gz", hash = "sha256:1f4f10ac98bb8b16ade9dbee3518d9ace017821d94b057a425b069f834737f4b"}, {file = "types_python_dateutil-2.8.19.14-py3-none-any.whl", hash = "sha256:f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9"}, @@ -6571,6 +6876,7 @@ version = "2.31.0.20240311" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "types-requests-2.31.0.20240311.tar.gz", hash = "sha256:b1c1b66abfb7fa79aae09097a811c4aa97130eb8831c60e47aee4ca344731ca5"}, {file = "types_requests-2.31.0.20240311-py3-none-any.whl", hash = "sha256:47872893d65a38e282ee9f277a4ee50d1b28bd592040df7d1fdaffdf3779937d"}, @@ -6585,6 +6891,7 @@ version = "0.9.0.20240106" description = "Typing stubs for tabulate" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "types-tabulate-0.9.0.20240106.tar.gz", hash = "sha256:c9b6db10dd7fcf55bd1712dd3537f86ddce72a08fd62bb1af4338c7096ce947e"}, {file = "types_tabulate-0.9.0.20240106-py3-none-any.whl", hash = "sha256:0378b7b6fe0ccb4986299496d027a6d4c218298ecad67199bbd0e2d7e9d335a1"}, @@ -6596,6 +6903,7 @@ version = "4.8.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, @@ -6607,6 +6915,7 @@ version = "2024.1" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" +groups = ["main"] files = [ {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, @@ -6618,6 +6927,7 @@ version = "1.0.2" description = "Micro subset of unicode data files for linkify-it-py projects." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "uc-micro-py-1.0.2.tar.gz", hash = "sha256:30ae2ac9c49f39ac6dce743bd187fcd2b574b16ca095fa74cd9396795c954c54"}, {file = "uc_micro_py-1.0.2-py3-none-any.whl", hash = "sha256:8c9110c309db9d9e87302e2f4ad2c3152770930d88ab385cd544e7a7e75f3de0"}, @@ -6632,6 +6942,7 @@ version = "1.3.0" description = "RFC 6570 URI Template Processor" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "uri-template-1.3.0.tar.gz", hash = "sha256:0e00f8eb65e18c7de20d595a14336e9f337ead580c70934141624b6d1ffdacc7"}, {file = "uri_template-1.3.0-py3-none-any.whl", hash = "sha256:a44a133ea12d44a0c0f06d7d42a52d71282e77e2f937d8abd5655b8d56fc1363"}, @@ -6646,13 +6957,14 @@ version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -6663,6 +6975,8 @@ version = "20.16.3" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.6" +groups = ["main", "dev"] +markers = "sys_platform == \"win32\"" files = [ {file = "virtualenv-20.16.3-py2.py3-none-any.whl", hash = "sha256:4193b7bc8a6cd23e4eb251ac64f29b4398ab2c233531e66e40b19a6b7b0d30c1"}, {file = "virtualenv-20.16.3.tar.gz", hash = "sha256:d86ea0bb50e06252d79e6c241507cb904fcd66090c3271381372d6221a3970f9"}, @@ -6683,6 +6997,8 @@ version = "20.24.6" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "virtualenv-20.24.6-py3-none-any.whl", hash = "sha256:520d056652454c5098a00c0f073611ccbea4c79089331f60bf9d7ba247bb7381"}, {file = "virtualenv-20.24.6.tar.gz", hash = "sha256:02ece4f56fbf939dbbc33c0715159951d6bf14aaf5457b092e4548e1382455af"}, @@ -6695,7 +7011,7 @@ platformdirs = ">=3.9.1,<4" [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] [[package]] name = "vizdoom" @@ -6703,6 +7019,8 @@ version = "1.2.2" description = "ViZDoom is Doom-based AI Research Platform for Reinforcement Learning from Raw Visual Information." optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"vizdoom\"" files = [ {file = "vizdoom-1.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3e2f478e1728702f17b828de0e7ee6bf0e2809c1786ce21f69ce00e4a4da82e0"}, {file = "vizdoom-1.2.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:49180ed13d30109bcd99b38e6b923c5bd74e6bb364add8d46beb5cdf7405fe10"}, @@ -6738,6 +7056,7 @@ version = "0.12.21" description = "A CLI and library for interacting with the Weights and Biases API." 
optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "wandb-0.12.21-py2.py3-none-any.whl", hash = "sha256:150842447d355d90dc7f368b824951a625e5b2d1be355a00e99b11b73728bc1f"}, {file = "wandb-0.12.21.tar.gz", hash = "sha256:1975ff88c5024923c3321c93cfefb8d9b871543c0b009f34001bf0f31e444b04"}, @@ -6776,6 +7095,7 @@ version = "0.2.10" description = "Measures the displayed width of unicode strings in a terminal" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "wcwidth-0.2.10-py2.py3-none-any.whl", hash = "sha256:aec5179002dd0f0d40c456026e74a729661c9d468e1ed64405e3a6c2176ca36f"}, {file = "wcwidth-0.2.10.tar.gz", hash = "sha256:390c7454101092a6a5e43baad8f83de615463af459201709556b6e4b1c861f97"}, @@ -6787,6 +7107,7 @@ version = "1.13" description = "A library for working with the color formats defined by HTML and CSS." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "webcolors-1.13-py3-none-any.whl", hash = "sha256:29bc7e8752c0a1bd4a1f03c14d6e6a72e93d82193738fa860cbff59d0fcc11bf"}, {file = "webcolors-1.13.tar.gz", hash = "sha256:c225b674c83fa923be93d235330ce0300373d02885cef23238813b0d5668304a"}, @@ -6802,6 +7123,7 @@ version = "0.5.1" description = "Character encoding aliases for legacy web content" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"}, {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"}, @@ -6813,6 +7135,7 @@ version = "1.6.4" description = "WebSocket client for Python with low level API options" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "websocket-client-1.6.4.tar.gz", hash = "sha256:b3324019b3c28572086c4a319f91d1dcd44e6e11cd340232978c684a7650d0df"}, {file = "websocket_client-1.6.4-py3-none-any.whl", hash = "sha256:084072e0a7f5f347ef2ac3d8698a5e0b4ffbfcab607628cadabc650fc9a83a24"}, @@ -6829,6 +7152,7 @@ version = "3.0.6" description = "The comprehensive WSGI web application library." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"}, {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"}, @@ -6846,6 +7170,7 @@ version = "0.41.3" description = "A built-package format for Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "wheel-0.41.3-py3-none-any.whl", hash = "sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942"}, {file = "wheel-0.41.3.tar.gz", hash = "sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841"}, @@ -6860,6 +7185,7 @@ version = "4.0.9" description = "Jupyter interactive widgets for Jupyter Notebook" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "widgetsnbextension-4.0.9-py3-none-any.whl", hash = "sha256:91452ca8445beb805792f206e560c1769284267a30ceb1cec9f5bcc887d15175"}, {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"}, @@ -6871,6 +7197,7 @@ version = "3.19.1" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, @@ -6888,12 +7215,11 @@ classic-control = ["pygame"] envpool = ["envpool"] eval = ["docstring-parser", "joblib", "jsonargparse", "rliable", "scipy"] mujoco = ["imageio", "mujoco"] -mujoco-py = ["cython", "mujoco-py"] pybullet = ["pybullet"] robotics = ["gymnasium-robotics"] vizdoom = ["vizdoom"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.11" -content-hash = "bff3f4f8cc0d8196ea162a799472c7179486109d30968aa7d1b96b40016a459f" +content-hash = "575f58bac92d215908d074f946b8593cbefaf83f965beed396253d8d3f38eea7" diff --git a/pyproject.toml b/pyproject.toml index 8e9a5ad30..2b76846b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "tianshou" -version = "1.2.0" +version = "2.0.0b1" description = "A Library for Deep Reinforcement Learning" authors = ["TSAIL "] license = "MIT" @@ -64,7 +64,6 @@ joblib = { version = "*", optional = true } jsonargparse = {version = "^4.24.1", optional = true} # we need <3 b/c of https://github.com/Farama-Foundation/Gymnasium/issues/749 mujoco = { version = ">=2.1.5, <3", optional = true } -mujoco-py = { version = ">=2.1,<2.2", optional = true } opencv_python = { version = "*", optional = true } pybullet = { version = "*", optional = true } pygame = { version = ">=2.1.3", optional = true } @@ -79,7 +78,6 @@ atari = ["ale-py", "autorom", "opencv-python", "shimmy"] box2d = ["box2d-py", "pygame", "swig"] classic_control = ["pygame"] mujoco = ["mujoco", "imageio"] -mujoco_py = ["mujoco-py", "cython"] pybullet = ["pybullet"] envpool = ["envpool"] robotics = ["gymnasium-robotics"] @@ -182,6 +180,7 @@ ignore = [ "B027", # empty and non-abstract method in abstract class "D404", # It's fine to start with "This" in docstrings "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx + "COM812", # missing trailing comma: With this enabled, re-application of "poe format" chain can cause 
additional commas and subsequent reformatting "B023", # forbids function using loop variable without explicit binding ] unfixable = [ @@ -204,6 +203,7 @@ max-complexity = 20 "test/**" = ["D103"] "docs/**" = ["D103"] "examples/**" = ["D103"] +"__init__.py" = ["F401"] # do not remove "unused" imports (F401) from __init__.py files [tool.poetry-sort] move-optionals-to-bottom = true @@ -212,7 +212,7 @@ move-optionals-to-bottom = true PYDEVD_DISABLE_FILE_VALIDATION="1" # keep relevant parts in sync with pre-commit [tool.poe.tasks] # https://github.com/nat-n/poethepoet -test = "pytest test --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v --color=yes" +test = "pytest test" test-reduced = "pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes" _black_check = "black --check ." _ruff_check = "ruff check ." @@ -229,9 +229,9 @@ _autogen_rst = "python docs/autogen_rst.py" _sphinx_build = "sphinx-build -b html docs docs/_build -W --keep-going" _jb_generate_toc = "python docs/create_toc.py" _jb_generate_config = "jupyter-book config sphinx docs/" -doc-clean = "rm -rf docs/_build" +doc-clean = "rm -rf docs/_build docs/03_api" doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"] -doc-build = ["doc-generate-files", "_sphinx_build"] +doc-build = ["doc-clean", "doc-generate-files", "_sphinx_build"] _mypy = "mypy tianshou test examples" _mypy_nb = "nbqa mypy docs" type-check = ["_mypy", "_mypy_nb"] diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 1e4769bbf..b88aa1fca 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -7,6 +7,7 @@ import pytest import tqdm +from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go from tianshou.data import ( AsyncCollector, Batch, @@ -25,8 +26,6 @@ ) from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import BasePolicy, TrainingStats -from tianshou.policy.base import episode_mc_return_to_go try: import envpool @@ -34,7 +33,7 @@ envpool = None -class MaxActionPolicy(BasePolicy): +class MaxActionPolicy(Policy): def __init__( self, action_space: gym.spaces.Space | None = None, @@ -80,9 +79,6 @@ def forward( action_shape = self.action_shape if self.action_shape else len(batch.obs) return Batch(act=np.ones(action_shape), state=state) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: - raise NotImplementedError - @pytest.fixture() def collector_with_single_env() -> Collector[CollectStats]: diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 39cf9d3ea..32289bfec 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -11,16 +11,15 @@ from gymnasium.spaces import Box from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tianshou.algorithm.algorithm_base import Policy from tianshou.data import Batch, Collector, CollectStats from tianshou.data.types import ( ActBatchProtocol, BatchProtocol, ObsBatchProtocol, - RolloutBatchProtocol, ) from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type -from tianshou.policy import BasePolicy class DummyDataset(Dataset): @@ -204,7 +203,7 @@ class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv): pass -class AnyPolicy(BasePolicy): +class DummyPolicy(Policy): def __init__(self) -> None: 
super().__init__(action_space=Box(-1, 1, (1,))) @@ -216,9 +215,6 @@ def forward( ) -> ActBatchProtocol: return cast(ActBatchProtocol, Batch(act=np.stack([1] * len(batch)))) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> None: - pass - def _finite_env_factory(dataset: Dataset, num_replicas: int, rank: int) -> Callable[[], FiniteEnv]: return lambda: FiniteEnv(dataset, num_replicas, rank) @@ -247,7 +243,7 @@ def validate(self) -> None: def test_finite_dummy_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) - policy = AnyPolicy() + policy = DummyPolicy() test_collector = Collector[CollectStats](policy, envs, exploration_noise=True) test_collector.reset() @@ -263,7 +259,7 @@ def test_finite_dummy_vector_env() -> None: def test_finite_subproc_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) - policy = AnyPolicy() + policy = DummyPolicy() test_collector = Collector[CollectStats](policy, envs, exploration_noise=True) test_collector.reset() diff --git a/test/base/test_policy.py b/test/base/test_policy.py index b918194da..940ffd0b2 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -4,12 +4,17 @@ import torch from torch.distributions import Categorical, Distribution, Independent, Normal +from tianshou.algorithm import PPO +from tianshou.algorithm.algorithm_base import ( + RandomActionPolicy, + episode_mc_return_to_go, +) +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Batch -from tianshou.policy import BasePolicy, PPOPolicy -from tianshou.policy.base import RandomActionPolicy, episode_mc_return_to_go -from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic +from tianshou.utils.net.discrete import DiscreteActor obs_shape = (5,) @@ -26,14 +31,16 @@ def test_calculate_discounted_returns() -> None: @pytest.fixture(params=["continuous", "discrete"]) -def policy(request: pytest.FixtureRequest) -> PPOPolicy: +def algorithm(request: pytest.FixtureRequest) -> PPO: action_type = request.param action_space: gym.spaces.Box | gym.spaces.Discrete - actor: Actor | ActorProb + actor: DiscreteActor | ContinuousActorProbabilistic if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) - actor = ActorProb( - Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape), + actor = ContinuousActorProbabilistic( + preprocess_net=Net( + state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape + ), action_shape=action_space.shape, ) @@ -43,36 +50,41 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: elif action_type == "discrete": action_space = gym.spaces.Discrete(3) - actor = Actor( - Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n), + actor = DiscreteActor( + preprocess_net=Net( + state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n + ), action_shape=action_space.n, ) dist_fn = Categorical else: raise ValueError(f"Unknown action type: {action_type}") - critic = Critic( - Net(obs_shape, 
hidden_sizes=[64, 64]), + critic = ContinuousCritic( + preprocess_net=Net(state_shape=obs_shape, hidden_sizes=[64, 64]), ) - actor_critic = ActorCritic(actor, critic) - optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3) + optim = AdamOptimizerFactory(lr=1e-3) - policy: BasePolicy - policy = PPOPolicy( + algorithm: PPO + policy = ProbabilisticActorPolicy( actor=actor, - critic=critic, dist_fn=dist_fn, - optim=optim, action_space=action_space, action_scaling=False, ) - policy.eval() - return policy + algorithm = PPO( + policy=policy, + critic=critic, + optim=optim, + ) + algorithm.eval() + return algorithm class TestPolicyBasics: - def test_get_action(self, policy: PPOPolicy) -> None: + def test_get_action(self, algorithm: PPO) -> None: + policy = algorithm.policy policy.is_within_training_step = False sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 078893113..1e2b00dd2 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -3,9 +3,9 @@ import numpy as np import torch +from tianshou.algorithm import Algorithm from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import BasePolicy def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: @@ -21,7 +21,7 @@ def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: def test_episodic_returns(size: int = 2560) -> None: - fn = BasePolicy.compute_episodic_return + fn = Algorithm.compute_episodic_return buf = ReplayBuffer(20) batch = cast( RolloutBatchProtocol, @@ -215,7 +215,7 @@ def test_nstep_returns(size: int = 10000) -> None: # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=1) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=1) .pop("returns") .reshape(-1), ) @@ -223,7 +223,7 @@ def test_nstep_returns(size: int = 10000) -> None: r_ = compute_nstep_return_base(1, 0.1, buf, indices) assert np.allclose(returns, r_), (r_, returns) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, @@ -235,7 +235,7 @@ def test_nstep_returns(size: int = 10000) -> None: assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=2) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=2) .pop("returns") .reshape(-1), ) @@ -243,7 +243,7 @@ def test_nstep_returns(size: int = 10000) -> None: r_ = compute_nstep_return_base(2, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, @@ -255,7 +255,7 @@ def test_nstep_returns(size: int = 10000) -> None: assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) .pop("returns") .reshape(-1), ) @@ -263,7 +263,7 @@ def test_nstep_returns(size: int = 10000) -> None: r_ = compute_nstep_return_base(10, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + 
Algorithm.compute_nstep_return( batch, buf, indices, @@ -297,7 +297,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=1) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=1) .pop("returns") .reshape(-1), ) @@ -305,7 +305,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: r_ = compute_nstep_return_base(1, 0.1, buf, indices) assert np.allclose(returns, r_), (r_, returns) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, @@ -317,7 +317,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=2) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=2) .pop("returns") .reshape(-1), ) @@ -325,7 +325,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: r_ = compute_nstep_return_base(2, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, @@ -337,7 +337,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) .pop("returns") .reshape(-1), ) @@ -345,7 +345,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: r_ = compute_nstep_return_base(10, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, diff --git a/test/base/test_stats.py b/test/base/test_stats.py index 821152e83..2b5630465 100644 --- a/test/base/test_stats.py +++ b/test/base/test_stats.py @@ -5,9 +5,9 @@ import torch from torch.distributions import Categorical, Normal +from tianshou.algorithm.algorithm_base import TrainingStats, TrainingStatsWrapper from tianshou.data import Batch, CollectStats from tianshou.data.collector import CollectStepBatchProtocol, get_stddev_from_dist -from tianshou.policy.base import TrainingStats, TrainingStatsWrapper class DummyTrainingStatsWrapper(TrainingStatsWrapper): diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 8e44ad57b..6c992a165 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -8,7 +8,7 @@ from torch import nn from tianshou.exploration import GaussianNoise, OUNoise -from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd +from tianshou.utils import MovAvg, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic from tianshou.utils.torch_utils import create_uniform_action_dist, torch_train_mode @@ -52,10 +52,10 @@ def test_net() -> None: bsz = 64 # MLP data = torch.rand([bsz, 3]) - mlp = MLP(3, 6, hidden_sizes=[128]) + mlp = MLP(input_dim=3, output_dim=6, hidden_sizes=[128]) assert list(mlp(data).shape) == [bsz, 6] # output == 0 and len(hidden_sizes) == 0 means identity model - mlp = 
MLP(6, 0) + mlp = MLP(input_dim=6, output_dim=0) assert data.shape == mlp(data).shape # common net state_shape = (10, 2) @@ -63,8 +63,8 @@ def test_net() -> None: data = torch.rand([bsz, *state_shape]) expect_output_shape = [bsz, *action_shape] net = Net( - state_shape, - action_shape, + state_shape=state_shape, + action_shape=action_shape, hidden_sizes=[128, 128], norm_layer=torch.nn.LayerNorm, activation=None, @@ -74,20 +74,20 @@ def test_net() -> None: assert str(net).count("ReLU") == 0 Q_param = V_param = {"hidden_sizes": [128, 128]} net = Net( - state_shape, - action_shape, + state_shape=state_shape, + action_shape=action_shape, hidden_sizes=[128, 128], dueling_param=(Q_param, V_param), ) assert list(net(data)[0].shape) == expect_output_shape # concat - net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True) + net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], concat=True) data = torch.rand([bsz, int(np.prod(state_shape)) + int(np.prod(action_shape))]) expect_output_shape = [bsz, 128] assert list(net(data)[0].shape) == expect_output_shape net = Net( - state_shape, - action_shape, + state_shape=state_shape, + action_shape=action_shape, hidden_sizes=[128], concat=True, dueling_param=(Q_param, V_param), @@ -96,49 +96,16 @@ def test_net() -> None: # recurrent actor/critic data = torch.rand([bsz, *state_shape]).flatten(1) expect_output_shape = [bsz, *action_shape] - net = RecurrentActorProb(3, state_shape, action_shape) + net = RecurrentActorProb(layer_num=3, state_shape=state_shape, action_shape=action_shape) mu, sigma = net(data)[0] assert mu.shape == sigma.shape assert list(mu.shape) == [bsz, 5] - net = RecurrentCritic(3, state_shape, action_shape) + net = RecurrentCritic(layer_num=3, state_shape=state_shape, action_shape=action_shape) data = torch.rand([bsz, 8, int(np.prod(state_shape))]) act = torch.rand(expect_output_shape) assert list(net(data, act).shape) == [bsz, 1] -def test_lr_schedulers() -> None: - initial_lr_1 = 10.0 - step_size_1 = 1 - gamma_1 = 0.5 - net_1 = torch.nn.Linear(2, 3) - optim_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr_1) - sched_1 = torch.optim.lr_scheduler.StepLR(optim_1, step_size=step_size_1, gamma=gamma_1) - - initial_lr_2 = 5.0 - step_size_2 = 2 - gamma_2 = 0.3 - net_2 = torch.nn.Linear(3, 2) - optim_2 = torch.optim.Adam(net_2.parameters(), lr=initial_lr_2) - sched_2 = torch.optim.lr_scheduler.StepLR(optim_2, step_size=step_size_2, gamma=gamma_2) - schedulers = MultipleLRSchedulers(sched_1, sched_2) - for _ in range(10): - loss_1 = (torch.ones((1, 3)) - net_1(torch.ones((1, 2)))).sum() - optim_1.zero_grad() - loss_1.backward() - optim_1.step() - loss_2 = (torch.ones((1, 2)) - net_2(torch.ones((1, 3)))).sum() - optim_2.zero_grad() - loss_2.backward() - optim_2.step() - schedulers.step() - assert optim_1.state_dict()["param_groups"][0]["lr"] == ( - initial_lr_1 * gamma_1 ** (10 // step_size_1) - ) - assert optim_2.state_dict()["param_groups"][0]["lr"] == ( - initial_lr_2 * gamma_2 ** (10 // step_size_2) - ) - - def test_in_eval_mode() -> None: module = nn.Linear(3, 4) module.train() diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 1569c0df1..325b4ce23 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -7,41 +7,43 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DDPG +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy 
+from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import DDPGPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--actor-lr", type=float, default=1e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--actor_lr", type=float, default=1e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) + parser.add_argument("--exploration_noise", type=float, default=0.1) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=20000) - parser.add_argument("--step-per-collect", type=int, default=8) - parser.add_argument("--update-per-step", type=float, default=0.125) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=20000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=8) + parser.add_argument("--update_per_step", type=float, default=0.125) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--rew-norm", action="store_true", default=False) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--return_scaling", action="store_true", default=False) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, @@ -62,76 +64,82 @@ def test_ddpg(args: argparse.Namespace = get_args(), enable_assertions: bool = T args.task, env.spec.reward_threshold if env.spec else None, ) - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = 
DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorDeterministic( + preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action + ).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = Critic(net, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy: DDPGPolicy = DDPGPolicy( + critic = ContinuousCritic(preprocess_net=net).to(args.device) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) + policy = ContinuousDeterministicPolicy( actor=actor, - actor_optim=actor_optim, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + action_space=env.action_space, + ) + policy_optim = AdamOptimizerFactory(lr=args.actor_lr) + algorithm: DDPG = DDPG( + policy=policy, + policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) + # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ddpg") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 3b413eec4..e2ce35cd0 100644 --- a/test/continuous/test_npg.py +++ 
b/test/continuous/test_npg.py @@ -9,34 +9,37 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import NPG +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import NPGPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.npg import NPGTrainingStats -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=50000) + parser.add_argument("--buffer_size", type=int, default=50000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=50000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=2) # theoretically it should be 1 - parser.add_argument("--batch-size", type=int, default=99999) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=50000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument( + "--update_step_num_repetitions", type=int, default=2 + ) # theoretically it should be 1 + parser.add_argument("--batch_size", type=int, default=99999) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -45,11 +48,11 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # npg special - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) - parser.add_argument("--optim-critic-iters", type=int, default=5) - parser.add_argument("--actor-step-size", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--return_scaling", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) + parser.add_argument("--optim_critic_iters", type=int, default=5) + 
parser.add_argument("--trust_region_size", type=float, default=0.5) return parser.parse_known_args()[0] @@ -67,39 +70,37 @@ def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model net = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - actor = ActorProb(net, args.action_shape, unbounded=True, device=args.device).to(args.device) - critic = Critic( - Net( - args.state_shape, + actor = ContinuousActorProbabilistic( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) + critic = ContinuousCritic( + preprocess_net=Net( + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device, activation=nn.Tanh, ), - device=args.device, ).to(args.device) + # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(critic.parameters(), lr=args.lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -107,53 +108,58 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: NPGPolicy[NPGTrainingStats] = NPGPolicy( + policy = ProbabilisticActorPolicy( actor=actor, - critic=critic, - optim=optim, dist_fn=dist, - discount_factor=args.gamma, - reward_normalization=args.rew_norm, - advantage_normalization=args.norm_adv, - gae_lambda=args.gae_lambda, action_space=env.action_space, - optim_critic_iters=args.optim_critic_iters, - actor_step_size=args.actor_step_size, deterministic_eval=True, ) + algorithm: NPG = NPG( + policy=policy, + critic=critic, + optim=AdamOptimizerFactory(lr=args.lr), + gamma=args.gamma, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, + gae_lambda=args.gae_lambda, + optim_critic_iters=args.optim_critic_iters, + trust_region_size=args.trust_region_size, + ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "npg") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - 
test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index cbb8544ab..686fa86d7 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -8,34 +8,35 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import PPO +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PPOPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=150000) - parser.add_argument("--episode-per-collect", type=int, default=16) - parser.add_argument("--repeat-per-collect", type=int, default=2) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=150000) + parser.add_argument("--collection_step_num_episodes", type=int, default=16) + parser.add_argument("--update_step_num_repetitions", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", 
type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -44,18 +45,18 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--return_scaling", type=int, default=1) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) + parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--resume", action="store_true") - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] @@ -73,22 +74,22 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net, args.action_shape, unbounded=True, device=args.device).to(args.device) - critic = Critic( - Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), - device=args.device, + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProbabilistic( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) + critic = ContinuousCritic( + preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), ).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization @@ -96,7 +97,7 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = 
torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -104,37 +105,40 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( + policy = ProbabilisticActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + ) + algorithm: PPO = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, - discount_factor=args.gamma, + gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - advantage_normalization=args.norm_adv, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, dual_clip=args.dual_clip, value_clip=args.value_clip, gae_lambda=args.gae_lambda, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ppo") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -146,10 +150,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": policy.state_dict(), - "optim": optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) return ckpt_path @@ -160,29 +161,31 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) - optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_episodes=args.collection_step_num_episodes, + collection_step_num_env_steps=None, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, 
+ resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 24a7f420c..5a2538033 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -8,43 +8,46 @@ import torch.nn as nn from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import REDQ +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.redq import REDQPolicy +from tianshou.algorithm.modelfree.sac import AutoAlpha +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import REDQPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--ensemble-size", type=int, default=4) - parser.add_argument("--subset-size", type=int, default=2) - parser.add_argument("--actor-lr", type=float, default=1e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--ensemble_size", type=int, default=4) + parser.add_argument("--subset_size", type=int, default=2) + parser.add_argument("--actor_lr", type=float, default=1e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", action="store_true", default=False) - parser.add_argument("--alpha-lr", type=float, default=3e-4) - parser.add_argument("--start-timesteps", type=int, default=1000) + parser.add_argument("--auto_alpha", action="store_true", default=False) + parser.add_argument("--alpha_lr", type=float, default=3e-4) + parser.add_argument("--start_timesteps", type=int, default=1000) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=3) - parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--target-mode", type=str, choices=("min", "mean"), default="min") - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + 
parser.add_argument("--collection_step_num_env_steps", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--target_mode", type=str, choices=("min", "mean"), default="min") + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -61,7 +64,6 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape - args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} args.reward_threshold = default_reward_threshold.get( @@ -70,24 +72,23 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( - net, - args.action_shape, - device=args.device, + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProbabilistic( + preprocess_net=net, + action_shape=args.action_shape, unbounded=True, conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) def linear(x: int, y: int) -> nn.Module: return EnsembleLinear(args.ensemble_size, x, y) @@ -97,24 +98,27 @@ def linear(x: int, y: int) -> nn.Module: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, linear_layer=linear, ) - critic = Critic(net_c, device=args.device, linear_layer=linear, flatten_input=False).to( + critic = ContinuousCritic(preprocess_net=net_c, linear_layer=linear, flatten_input=False).to( args.device, ) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) - policy: REDQPolicy = REDQPolicy( + policy = REDQPolicy( actor=actor, - actor_optim=actor_optim, + 
action_space=env.action_space, + ) + algorithm: REDQ = REDQ( + policy=policy, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, ensemble_size=args.ensemble_size, @@ -122,19 +126,18 @@ def linear(x: int, y: int) -> nn.Module: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, actor_delay=args.update_per_step, target_mode=args.target_mode, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) # log @@ -142,27 +145,29 @@ def linear(x: int, y: int) -> nn.Module: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 1d4fc06fe..5f68630fc 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -7,14 +7,21 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import SAC, OffPolicyImitationLearning +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.imitation.imitation_base import ImitationPolicy +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import ImitationPolicy, SACPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, ActorProb, Critic +from tianshou.utils.net.continuous import ( + ContinuousActorDeterministic, + ContinuousActorProbabilistic, + ContinuousCritic, +) from tianshou.utils.space_info import SpaceInfo try: @@ -26,30 +33,30 @@ def get_args() -> argparse.Namespace: parser = 
argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) - parser.add_argument("--il-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) + parser.add_argument("--il_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", type=int, default=1) - parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--auto_alpha", type=int, default=1) + parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=24000) - parser.add_argument("--il-step-per-epoch", type=int, default=500) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--imitation-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=24000) + parser.add_argument("--il_step_per_epoch", type=int, default=500) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--imitation_hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, @@ -64,11 +71,11 @@ def test_sac_with_il( skip_il: bool = False, ) -> None: # if you want to use python vector env, please refer to other test scripts - # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) - # test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) + # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.num_train_envs, seed=args.seed) + # test_envs = envpool.make_gymnasium(args.task, num_envs=args.num_test_envs, seed=args.seed) env = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = 
DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -79,45 +86,48 @@ def test_sac_with_il( args.task, env.spec.reward_threshold if env.spec else None, ) - # you can also use tianshou.env.SubprocVectorEnv + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) - test_envs.seed(args.seed + args.training_num) + test_envs.seed(args.seed + args.num_train_envs) + # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProbabilistic( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) - - policy: SACPolicy = SACPolicy( + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) + policy = SACPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm: SAC = SAC( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, @@ -125,44 +135,45 @@ def test_sac_with_il( tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def 
stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) @@ -174,45 +185,48 @@ def stop_fn(mean_rewards: float) -> bool: if args.task.startswith("Pendulum"): args.reward_threshold -= 50 # lower the goal il_net = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.imitation_hidden_sizes, - device=args.device, ) - il_actor = Actor( - il_net, - args.action_shape, + il_actor = ContinuousActorDeterministic( + preprocess_net=il_net, + action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) - optim = torch.optim.Adam(il_actor.parameters(), lr=args.il_lr) - il_policy: ImitationPolicy = ImitationPolicy( + optim = AdamOptimizerFactory(lr=args.il_lr) + il_policy = ImitationPolicy( actor=il_actor, - optim=optim, action_space=env.action_space, action_scaling=True, action_bound_method="clip", ) + il_algorithm: OffPolicyImitationLearning = OffPolicyImitationLearning( + policy=il_policy, + optim=optim, + ) il_test_env = gym.make(args.task) - il_test_env.reset(seed=args.seed + args.training_num + args.test_num) + il_test_env.reset(seed=args.seed + args.num_train_envs + args.num_test_envs) il_test_collector = Collector[CollectStats]( - il_policy, - # envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), + il_algorithm, + # envpool.make_gymnasium(args.task, num_envs=args.num_test_envs, seed=args.seed), il_test_env, ) train_collector.reset() - result = OffpolicyTrainer( - policy=il_policy, - train_collector=train_collector, - test_collector=il_test_collector, - max_epoch=args.epoch, - step_per_epoch=args.il_step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = il_algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=il_test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 82fcce0fb..df532aed4 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -7,43 +7,45 @@ import torch from torch.utils.tensorboard import SummaryWriter +from 
tianshou.algorithm import TD3 +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import TD3Policy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--actor-lr", type=float, default=1e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--actor_lr", type=float, default=1e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--policy-noise", type=float, default=0.2) - parser.add_argument("--noise-clip", type=float, default=0.5) - parser.add_argument("--update-actor-freq", type=int, default=2) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--policy_noise", type=float, default=0.2) + parser.add_argument("--noise_clip", type=float, default=0.5) + parser.add_argument("--update_actor_freq", type=int, default=2) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=20000) - parser.add_argument("--step-per-collect", type=int, default=8) - parser.add_argument("--update-per-step", type=float, default=0.125) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=20000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=8) + parser.add_argument("--update_per_step", type=float, default=0.125) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, @@ -66,89 +68,94 @@ def test_td3(args: argparse.Namespace = get_args(), enable_assertions: 
bool = Tr ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorDeterministic( + preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action + ).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy: TD3Policy = TD3Policy( + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) + policy = ContinuousDeterministicPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + ) + algorithm: TD3 = TD3( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "td3") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # Iterator trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - 
step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 321e351f3..b5e24ad30 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -9,33 +9,37 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import TRPO +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import TRPOPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=50000) + parser.add_argument("--buffer_size", type=int, default=50000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=50000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=2) # theoretically it should be 1 - parser.add_argument("--batch-size", type=int, default=99999) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=50000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument( + "--update_step_num_repetitions", type=int, default=2 + ) # theoretically it should be 1 + parser.add_argument("--batch_size", type=int, default=99999) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=16) + 
parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -44,13 +48,13 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # trpo special - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) - parser.add_argument("--optim-critic-iters", type=int, default=5) - parser.add_argument("--max-kl", type=float, default=0.005) - parser.add_argument("--backtrack-coeff", type=float, default=0.8) - parser.add_argument("--max-backtracks", type=int, default=10) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--return_scaling", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) + parser.add_argument("--optim_critic_iters", type=int, default=5) + parser.add_argument("--max_kl", type=float, default=0.005) + parser.add_argument("--backtrack_coeff", type=float, default=0.8) + parser.add_argument("--max_backtracks", type=int, default=10) return parser.parse_known_args()[0] @@ -68,9 +72,9 @@ def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = T ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -78,27 +82,26 @@ def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = T test_envs.seed(args.seed) # model net = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - actor = ActorProb(net, args.action_shape, unbounded=True, device=args.device).to(args.device) - critic = Critic( - Net( - args.state_shape, + actor = ContinuousActorProbabilistic( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) + critic = ContinuousCritic( + preprocess_net=Net( + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device, activation=nn.Tanh, ), - device=args.device, ).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -106,16 +109,19 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: TRPOPolicy = TRPOPolicy( + policy = ProbabilisticActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + ) + algorithm: TRPO = TRPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, - discount_factor=args.gamma, - reward_normalization=args.rew_norm, - 
advantage_normalization=args.norm_adv, + gamma=args.gamma, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, gae_lambda=args.gae_lambda, - action_space=env.action_space, optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, backtrack_coeff=args.backtrack_coeff, @@ -123,37 +129,39 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "trpo") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/determinism_test.py b/test/determinism_test.py index 828825660..6a5deb566 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -71,7 +71,7 @@ def __init__( :param args: the arguments to be passed to the main function (some of which are overridden for the test) :param is_offline: whether the algorithm being tested is an offline algorithm and therefore - does not configure the number of training environments (`training_num`) + does not configure the number of training environments (`num_train_envs`) :param ignored_messages: message fragments to ignore in the trace log (if any) """ self.determinism_test = TraceDeterminismTest( @@ -89,11 +89,11 @@ def set(attr: str, value: Any) -> None: setattr(args, attr, value) set("epoch", 3) - set("step_per_epoch", 100) + set("epoch_num_steps", 100) set("device", "cpu") if not is_offline: - set("training_num", 1) - set("test_num", 1) + set("num_train_envs", 1) + set("num_test_envs", 1) self.args = args self.main_fn = main_fn diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index cf0f4ef7c..e865b0c7c 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -8,14 +8,16 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import A2C, Algorithm, OffPolicyImitationLearning +from tianshou.algorithm.imitation.imitation_base import ImitationPolicy +from tianshou.algorithm.modelfree.reinforce import 
ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import A2CPolicy, ImitationPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic try: import envpool @@ -26,24 +28,24 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--il-lr", type=float, default=1e-3) + parser.add_argument("--il_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=50000) - parser.add_argument("--il-step-per-epoch", type=int, default=1000) - parser.add_argument("--episode-per-collect", type=int, default=16) - parser.add_argument("--step-per-collect", type=int, default=16) - parser.add_argument("--update-per-step", type=float, default=1 / 16) - parser.add_argument("--repeat-per-collect", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--imitation-hidden-sizes", type=int, nargs="*", default=[128]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=50000) + parser.add_argument("--il_step_per_epoch", type=int, default=1000) + parser.add_argument("--collection_step_num_episodes", type=int, default=16) + parser.add_argument("--collection_step_num_env_steps", type=int, default=16) + parser.add_argument("--update_per_step", type=float, default=1 / 16) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--imitation_hidden_sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -52,11 +54,11 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # a2c special - parser.add_argument("--vf-coef", type=float, default=0.5) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--max-grad-norm", type=float, default=None) - parser.add_argument("--gae-lambda", type=float, default=1.0) - 
parser.add_argument("--rew-norm", action="store_true", default=False) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--max_grad_norm", type=float, default=None) + parser.add_argument("--gae_lambda", type=float, default=1.0) + parser.add_argument("--return_scaling", action="store_true", default=False) return parser.parse_known_args()[0] @@ -73,19 +75,21 @@ def test_a2c_with_il( train_envs = env = envpool.make( args.task, env_type="gymnasium", - num_envs=args.training_num, + num_envs=args.num_train_envs, seed=args.seed, ) test_envs = envpool.make( args.task, env_type="gymnasium", - num_envs=args.test_num, + num_envs=args.num_test_envs, seed=args.seed, ) else: env = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.num_train_envs)] + ) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) train_envs.seed(args.seed) test_envs.seed(args.seed) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -94,61 +98,66 @@ def test_a2c_with_il( default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, device=args.device).to(args.device) - critic = Critic(net, device=args.device).to(args.device) - optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) + critic = DiscreteCritic(preprocess_net=net).to(args.device) + optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical - policy: BasePolicy - policy = A2CPolicy( + policy = ProbabilisticActorPolicy( actor=actor, - critic=critic, - optim=optim, dist_fn=dist, action_scaling=isinstance(env.action_space, Box), - discount_factor=args.gamma, + action_space=env.action_space, + ) + algorithm: A2C = A2C( + policy=policy, + critic=critic, + optim=optim, + gamma=args.gamma, gae_lambda=args.gae_lambda, vf_coef=args.vf_coef, ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm, - reward_normalization=args.rew_norm, - action_space=env.action_space, + return_scaling=args.return_scaling, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) train_collector.reset() - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) test_collector.reset() # log log_path = os.path.join(args.logdir, args.task, "a2c") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - 
step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_episodes=args.collection_step_num_episodes, + collection_step_num_env_steps=None, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) @@ -159,45 +168,50 @@ def stop_fn(mean_rewards: float) -> bool: # here we define an imitation collector with a trivial policy # if args.task == 'CartPole-v1': # env.spec.reward_threshold = 190 # lower the goal - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, device=args.device).to(args.device) - optim = torch.optim.Adam(actor.parameters(), lr=args.il_lr) - il_policy: ImitationPolicy = ImitationPolicy( + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) + optim = AdamOptimizerFactory(lr=args.il_lr) + il_policy = ImitationPolicy( actor=actor, - optim=optim, action_space=env.action_space, ) + il_algorithm: OffPolicyImitationLearning = OffPolicyImitationLearning( + policy=il_policy, + optim=optim, + ) if envpool is not None: il_env = envpool.make( args.task, env_type="gymnasium", - num_envs=args.test_num, + num_envs=args.num_test_envs, seed=args.seed, ) else: il_env = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)], + [lambda: gym.make(args.task) for _ in range(args.num_test_envs)], ) il_env.seed(args.seed) il_test_collector = Collector[CollectStats]( - il_policy, + il_algorithm, il_env, ) train_collector.reset() - result = OffpolicyTrainer( - policy=il_policy, - train_collector=train_collector, - test_collector=il_test_collector, - max_epoch=args.epoch, - step_per_epoch=args.il_step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = il_algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=il_test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdqn.py similarity index 51% rename from test/discrete/test_bdq.py rename to test/discrete/test_bdqn.py index 719e10d86..ef766707c 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdqn.py @@ -5,39 +5,42 @@ import numpy as np import torch +from tianshou.algorithm import BDQN +from tianshou.algorithm.modelfree.bdqn import BDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from 
tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, DummyVectorEnv -from tianshou.policy import BranchingDQNPolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import BranchingNet +from tianshou.utils.torch_utils import policy_within_training_step def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() # task parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) # network architecture - parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[64]) - parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[64]) - parser.add_argument("--action-per-branch", type=int, default=40) + parser.add_argument("--common_hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--action_hidden_sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--value_hidden_sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--action_per_branch", type=int, default=40) # training hyperparameters parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.01) - parser.add_argument("--eps-train", type=float, default=0.76) - parser.add_argument("--eps-decay", type=float, default=1e-4) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.01) + parser.add_argument("--eps_train", type=float, default=0.76) + parser.add_argument("--eps_decay", type=float, default=1e-4) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--target-update-freq", type=int, default=200) + parser.add_argument("--target_update_freq", type=int, default=200) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=80000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=80000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -48,7 +51,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_bdq(args: argparse.Namespace = get_args()) -> None: +def test_bdq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) @@ -73,13 +76,13 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: 
train_envs = DummyVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) - for _ in range(args.training_num) + for _ in range(args.num_train_envs) ], ) test_envs = DummyVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) - for _ in range(args.test_num) + for _ in range(args.num_test_envs) ], ) @@ -90,63 +93,68 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = BranchingNet( - args.state_shape, - args.num_branches, - args.action_per_branch, - args.common_hidden_sizes, - args.value_hidden_sizes, - args.action_hidden_sizes, - device=args.device, + state_shape=args.state_shape, + num_branches=args.num_branches, + action_per_branch=args.action_per_branch, + common_hidden_sizes=args.common_hidden_sizes, + value_hidden_sizes=args.value_hidden_sizes, + action_hidden_sizes=args.action_hidden_sizes, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BranchingDQNPolicy = BranchingDQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = BDQNPolicy( model=net, - optim=optim, - discount_factor=args.gamma, action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: BDQN = BDQN( + policy=policy, + optim=optim, + gamma=args.gamma, target_update_freq=args.target_update_freq, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, - VectorReplayBuffer(args.buffer_size, args.training_num), + VectorReplayBuffer(args.buffer_size, args.num_train_envs), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=False) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False) # initial data collection - policy.set_eps(args.eps_train) - train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + with policy_within_training_step(policy): + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) - policy.set_eps(eps) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(eps) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + train_fn=train_fn, + stop_fn=stop_fn, + test_in_train=True, + ) + ) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_bdq_determinism() -> None: - 
main_fn = lambda args: test_bdq(args) + main_fn = lambda args: test_bdq(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_bdq", main_fn, get_args()).run() diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 2876c4406..5e0977e6d 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -8,6 +8,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import C51 +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -16,40 +20,39 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import C51Policy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo +from tianshou.utils.torch_utils import policy_within_training_step def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", type=float, default=10.0) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=8000) - parser.add_argument("--step-per-collect", type=int, default=8) - parser.add_argument("--update-per-step", type=float, default=0.125) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=8000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=8) + parser.add_argument("--update_per_step", type=float, default=0.125) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--num_train_envs", type=int, default=8) + 
parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument("--resume", action="store_true") @@ -58,7 +61,7 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] @@ -74,35 +77,39 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=True, num_atoms=args.num_atoms, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: C51Policy = C51Policy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = C51Policy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, + observation_space=env.observation_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: C51 = C51( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) @@ -118,22 +125,22 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # initial data collection - policy.set_eps(args.eps_train) - train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + with policy_within_training_step(policy): + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) - # log + # logger log_path = os.path.join(args.logdir, args.task, "c51") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: - torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + def 
save_best_fn(algorithm: Algorithm) -> None: + torch.save(algorithm.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold @@ -141,15 +148,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.1 * args.eps_train) def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html @@ -157,10 +161,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": policy.state_dict(), - "optim": optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) buffer_path = os.path.join(log_path, "train_buffer.pkl") @@ -174,8 +175,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) - policy.optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") @@ -187,25 +187,26 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: else: print("Fail to restore buffer.") - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/discrete/test_sac.py b/test/discrete/test_discrete_sac.py similarity index 51% rename from test/discrete/test_sac.py rename to test/discrete/test_discrete_sac.py index 9e6a08dc9..ffd59afa5 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -7,42 +7,46 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DiscreteSAC +from 
tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.discrete_sac import ( + DiscreteSACPolicy, +) +from tianshou.algorithm.modelfree.sac import AutoAlpha +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DiscreteSACPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.discrete_sac import DiscreteSACTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--actor-lr", type=float, default=1e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) - parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--actor_lr", type=float, default=1e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) + parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.05) - parser.add_argument("--auto-alpha", action="store_true", default=False) + parser.add_argument("--auto_alpha", action="store_true", default=False) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, @@ -69,8 +73,8 @@ def test_discrete_sac( env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in 
range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -79,71 +83,77 @@ def test_discrete_sac( # model obs_dim = space_info.observation_info.obs_dim action_dim = space_info.action_info.action_dim - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, softmax_output=False, device=args.device).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic1 = Critic(net_c1, last_size=action_dim, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(obs_dim, hidden_sizes=args.hidden_sizes, device=args.device) - critic2 = Critic(net_c2, last_size=action_dim, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = DiscreteActor( + preprocess_net=net, action_shape=args.action_shape, softmax_output=False + ).to(args.device) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) + net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + critic1 = DiscreteCritic(preprocess_net=net_c1, last_size=action_dim).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) + net_c2 = Net(state_shape=obs_dim, hidden_sizes=args.hidden_sizes) + critic2 = DiscreteCritic(preprocess_net=net_c2, last_size=action_dim).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # better not to use auto alpha in CartPole if args.auto_alpha: target_entropy = 0.98 * np.log(action_dim) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) - policy: DiscreteSACPolicy[DiscreteSACTrainingStats] = DiscreteSACPolicy( + policy = DiscreteSACPolicy( actor=actor, - actor_optim=actor_optim, - critic=critic1, action_space=env.action_space, + ) + algorithm = DiscreteSAC( + policy=policy, + policy_optim=actor_optim, + critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "discrete_sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OffpolicyTrainer( - 
policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=False, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) @@ -151,4 +161,9 @@ def stop_fn(mean_rewards: float) -> bool: def test_discrete_sac_determinism() -> None: main_fn = lambda args: test_discrete_sac(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_sac", main_fn, get_args()).run() + ignored_messages = [ + "Params[actor_old]", # actor_old only present in v1 (due to flawed inheritance) + ] + AlgorithmDeterminismTest( + "discrete_sac", main_fn, get_args(), ignored_messages=ignored_messages + ).run() diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 5c706e884..092763e5b 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -15,37 +19,36 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo +from tianshou.utils.torch_utils import policy_within_training_step def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=20) - 
parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( @@ -68,34 +71,39 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # Q_param = V_param = {"hidden_sizes": [128]} # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, # dueling=(Q_param, V_param), ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy = DQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = DiscreteQLearningPolicy( model=net, + action_space=env.action_space, + observation_space=env.observation_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: DQN = DQN( + policy=policy, optim=optim, - discount_factor=args.gamma, - estimation_step=args.n_step, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, - action_space=env.action_space, ) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -108,21 +116,21 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, 
test_envs, exploration_noise=True) # initial data collection - policy.set_eps(args.eps_train) - train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + with policy_within_training_step(policy): + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) - # log + # logger log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -131,33 +139,31 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + policy.set_eps_training(0.1 * args.eps_train) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index d3fcdbd89..89b6185f8 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -7,37 +7,40 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Recurrent from tianshou.utils.space_info import SpaceInfo +from tianshou.utils.torch_utils import policy_within_training_step def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) 
parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--stack-num", type=int, default=4) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--stack_num", type=int, default=4) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=20000) - parser.add_argument("--update-per-step", type=float, default=1 / 16) - parser.add_argument("--step-per-collect", type=int, default=16) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--layer-num", type=int, default=2) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=20000) + parser.add_argument("--update_per_step", type=float, default=1 / 16) + parser.add_argument("--collection_step_num_env_steps", type=int, default=16) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--layer_num", type=int, default=2) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -62,25 +65,32 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Recurrent(args.layer_num, args.state_shape, args.action_shape, args.device).to( + net = Recurrent( + layer_num=args.layer_num, state_shape=args.state_shape, action_shape=args.action_shape + ).to( args.device, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy = DQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = DiscreteQLearningPolicy( model=net, - optim=optim, - discount_factor=args.gamma, - estimation_step=args.n_step, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: DQN = DQN( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) @@ -91,50 +101,43 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T 
stack_num=args.stack_num, ignore_obs_next=True, ) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) # the stack_num is for RNN training: sample framestack obs - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # initial data collection - policy.set_eps(args.eps_train) - train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + with policy_within_training_step(policy): + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, args.task, "drqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch: int, env_step: int) -> None: - policy.set_eps(args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 2e56b4e38..19fb5768b 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import FQF +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.fqf import FQFPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -15,42 +19,41 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import FQFPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction from tianshou.utils.space_info import SpaceInfo +from tianshou.utils.torch_utils import policy_within_training_step def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, 
default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=3e-3) - parser.add_argument("--fraction-lr", type=float, default=2.5e-9) + parser.add_argument("--fraction_lr", type=float, default=2.5e-9) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--num-fractions", type=int, default=32) - parser.add_argument("--num-cosines", type=int, default=64) - parser.add_argument("--ent-coef", type=float, default=10.0) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_fractions", type=int, default=32) + parser.add_argument("--num_cosines", type=int, default=64) + parser.add_argument("--ent_coef", type=float, default=10.0) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64, 64]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64, 64]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( @@ -73,46 +76,49 @@ def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed 
np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model feature_net = Net( - args.state_shape, - args.hidden_sizes[-1], + state_shape=args.state_shape, + action_shape=args.hidden_sizes[-1], hidden_sizes=args.hidden_sizes[:-1], - device=args.device, softmax=False, ) net = FullQuantileFunction( - feature_net, - args.action_shape, - args.hidden_sizes, + preprocess_net=feature_net, + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, num_cosines=args.num_cosines, - device=args.device, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) - fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) - policy: FQFPolicy = FQFPolicy( + fraction_optim = RMSpropOptimizerFactory(lr=args.fraction_lr) + policy = FQFPolicy( model=net, - optim=optim, fraction_model=fraction_net, - fraction_optim=fraction_optim, action_space=env.action_space, - discount_factor=args.gamma, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: FQF = FQF( + policy=policy, + optim=optim, + fraction_optim=fraction_optim, + gamma=args.gamma, num_fractions=args.num_fractions, ent_coef=args.ent_coef, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -125,21 +131,21 @@ def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # initial data collection - policy.set_eps(args.eps_train) - train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + with policy_within_training_step(policy): + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) - # log + # logger log_path = os.path.join(args.logdir, args.task, "fqf") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -148,33 +154,31 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - 
batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - ).run() + policy.set_eps_training(0.1 * args.eps_train) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 57bf28e73..0dadeb6bd 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import IQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.iqn import IQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -15,42 +19,41 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import IQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import ImplicitQuantileNetwork from tianshou.utils.space_info import SpaceInfo +from tianshou.utils.torch_utils import policy_within_training_step def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--sample-size", type=int, default=32) - parser.add_argument("--online-sample-size", type=int, default=8) - parser.add_argument("--target-sample-size", type=int, default=8) - parser.add_argument("--num-cosines", type=int, default=64) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--sample_size", type=int, default=32) + parser.add_argument("--online_sample_size", type=int, default=8) + parser.add_argument("--target_sample_size", type=int, default=8) + parser.add_argument("--num_cosines", type=int, default=64) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=10000) - 
parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64, 64]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64, 64]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( @@ -73,40 +76,42 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model feature_net = Net( - args.state_shape, - args.hidden_sizes[-1], + state_shape=args.state_shape, + action_shape=args.hidden_sizes[-1], hidden_sizes=args.hidden_sizes[:-1], - device=args.device, softmax=False, ) net = ImplicitQuantileNetwork( - feature_net, - args.action_shape, + preprocess_net=feature_net, + action_shape=args.action_shape, num_cosines=args.num_cosines, - device=args.device, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: IQNPolicy = IQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = IQNPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, sample_size=args.sample_size, online_sample_size=args.online_sample_size, target_sample_size=args.target_sample_size, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: IQN = IQN( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) @@ -122,21 +127,21 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + # collectors + train_collector = Collector[CollectStats](algorithm, 
train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # initial data collection - policy.set_eps(args.eps_train) - train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + with policy_within_training_step(policy): + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) - # log + # logger log_path = os.path.join(args.logdir, args.task, "iqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -145,33 +150,31 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - ).run() + policy.set_eps_training(0.1 * args.eps_train) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo_discrete.py similarity index 52% rename from test/discrete/test_ppo.py rename to test/discrete/test_ppo_discrete.py index 4226caf8f..cb1e31c9f 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo_discrete.py @@ -5,38 +5,43 @@ import gymnasium as gym import numpy as np import torch -import torch.nn as nn -from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import PPO +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PPOPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net 
-from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.common import ( + ActionReprNet, + ActionReprNetDataParallelWrapper, + ActorCritic, + DataParallelNet, + Net, +) +from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=50000) - parser.add_argument("--step-per-collect", type=int, default=2000) - parser.add_argument("--repeat-per-collect", type=int, default=10) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=20) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=50000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2000) + parser.add_argument("--update_step_num_repetitions", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=20) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -45,16 +50,16 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.5) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=0) - parser.add_argument("--recompute-adv", type=int, default=0) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--return_scaling", type=int, default=0) + parser.add_argument("--advantage_normalization", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=0) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) return parser.parse_known_args()[0] @@ -71,85 +76,91 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) # train_envs = 
gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor: nn.Module - critic: nn.Module + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + critic: DiscreteCritic | DataParallelNet + actor: ActionReprNet if torch.cuda.is_available(): - actor = DataParallelNet(Actor(net, args.action_shape, device=args.device).to(args.device)) - critic = DataParallelNet(Critic(net, device=args.device).to(args.device)) + actor = ActionReprNetDataParallelWrapper( + DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) + ) + critic = DataParallelNet(DiscreteCritic(preprocess_net=net).to(args.device)) else: - actor = Actor(net, args.action_shape, device=args.device).to(args.device) - critic = Critic(net, device=args.device).to(args.device) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) + critic = DiscreteCritic(preprocess_net=net).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization for m in actor_critic.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical - policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( + policy = DiscreteActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + deterministic_eval=True, + ) + algorithm: PPO = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, - action_scaling=isinstance(env.action_space, Box), - discount_factor=args.gamma, + gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, + return_scaling=args.return_scaling, dual_clip=args.dual_clip, value_clip=args.value_clip, - action_space=env.action_space, - deterministic_eval=True, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ppo") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - 
max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index afa2592c4..f44562b2b 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import QRDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -14,39 +18,37 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import QRDQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo +from tianshou.utils.torch_utils import policy_within_training_step def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=10) - 
parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( @@ -75,9 +77,9 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -88,20 +90,26 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=False, num_atoms=args.num_quantiles, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: QRDQNPolicy[QRDQNTrainingStats] = QRDQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = QRDQNPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, + observation_space=env.observation_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: QRDQN = QRDQN( + policy=policy, + optim=optim, + gamma=args.gamma, num_quantiles=args.num_quantiles, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: @@ -114,22 +122,22 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # initial data collection - policy.set_eps(args.eps_train) - train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + with policy_within_training_step(policy): + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) - # log + # logger log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = 
SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: - torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + def save_best_fn(algo: Algorithm) -> None: + torch.save(algo.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold @@ -137,33 +145,31 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - ).run() + policy.set_eps_training(0.1 * args.eps_train) + + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 92d10b06a..4666d2299 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -8,6 +8,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import RainbowDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -15,53 +19,51 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import RainbowPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.rainbow import RainbowTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear from tianshou.utils.space_info import SpaceInfo +from tianshou.utils.torch_utils import policy_within_training_step def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", 
type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--noisy-std", type=float, default=0.1) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", type=float, default=10.0) + parser.add_argument("--noisy_std", type=float, default=0.1) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=8000) - parser.add_argument("--step-per-collect", type=int, default=8) - parser.add_argument("--update-per-step", type=float, default=0.125) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=8000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=8) + parser.add_argument("--update_per_step", type=float, default=0.125) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) - parser.add_argument("--beta-final", type=float, default=1.0) + parser.add_argument("--beta_final", type=float, default=1.0) parser.add_argument("--resume", action="store_true") parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] @@ -81,41 +83,45 @@ def test_rainbow(args: argparse.Namespace = get_args(), enable_assertions: bool ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in 
range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) - # model - def noisy_linear(x: int, y: int) -> NoisyLinear: return NoisyLinear(x, y, args.noisy_std) + # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=True, num_atoms=args.num_atoms, dueling_param=({"linear_layer": noisy_linear}, {"linear_layer": noisy_linear}), ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: RainbowPolicy[RainbowTrainingStats] = RainbowPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = C51Policy( model=net, - optim=optim, - discount_factor=args.gamma, action_space=env.action_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: RainbowDQN = RainbowDQN( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: @@ -129,21 +135,21 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # initial data collection - policy.set_eps(args.eps_train) - train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + with policy_within_training_step(policy): + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.num_train_envs) - # log + # logger log_path = os.path.join(args.logdir, args.task, "rainbow") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -152,12 +158,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) + policy.set_eps_training(0.1 * args.eps_train) # beta annealing, just a demo if args.prioritized_replay: if env_step <= 10000: @@ -168,19 +174,13 @@ def train_fn(epoch: int, env_step: int) -> None: beta = args.beta_final buf.set_beta(beta) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": 
policy.state_dict(), - "optim": optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) buffer_path = os.path.join(log_path, "train_buffer.pkl") @@ -194,8 +194,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) - policy.optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") @@ -207,25 +206,26 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: else: print("Fail to restore buffer.") - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/discrete/test_pg.py b/test/discrete/test_reinforce.py similarity index 59% rename from test/discrete/test_pg.py rename to test/discrete/test_reinforce.py index a4fb28300..df0e1cf53 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_reinforce.py @@ -8,11 +8,13 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import Reinforce +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PGPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -21,22 +23,22 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", 
type=float, default=0.95) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=40000) - parser.add_argument("--episode-per-collect", type=int, default=8) - parser.add_argument("--repeat-per-collect", type=int, default=2) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=40000) + parser.add_argument("--collection_step_num_episodes", type=int, default=8) + parser.add_argument("--update_step_num_repetitions", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--rew-norm", type=int, default=1) + parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument( "--device", type=str, @@ -45,7 +47,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: +def test_reinforce(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -56,78 +58,83 @@ def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tru args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=True, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) dist_fn = torch.distributions.Categorical - policy: PGPolicy = PGPolicy( + policy = ProbabilisticActorPolicy( actor=net, - optim=optim, dist_fn=dist_fn, - discount_factor=args.gamma, action_space=env.action_space, action_scaling=isinstance(env.action_space, Box), - reward_normalization=args.rew_norm, + ) + algorithm: Reinforce = Reinforce( + policy=policy, + optim=optim, + gamma=args.gamma, + return_standardization=args.return_scaling, ) for m in net.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) + # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = 
Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) + # log log_path = os.path.join(args.logdir, args.task, "pg") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: - torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + def save_best_fn(algorithm: Algorithm) -> None: + torch.save(algorithm.policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OnpolicyTrainer( - policy=policy, + # train + training_config = OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, + collection_step_num_episodes=args.collection_step_num_episodes, + collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - ).run() + test_in_train=True, + ) + result = algorithm.run_training(training_config) if enable_assertions: assert stop_fn(result.best_reward) -def test_pg_determinism() -> None: - main_fn = lambda args: test_pg(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_pg", main_fn, get_args()).run() +def test_reinforce_determinism() -> None: + main_fn = lambda args: test_reinforce(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_reinforce", main_fn, get_args()).run() diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 5e61ac832..b02438787 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -2,7 +2,10 @@ import pytest -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import ( + OffPolicyTrainingConfig, + OnPolicyTrainingConfig, +) from tianshou.highlevel.experiment import ( A2CExperimentBuilder, DDPGExperimentBuilder, @@ -11,15 +14,42 @@ ExperimentBuilder, ExperimentConfig, IQNExperimentBuilder, - PGExperimentBuilder, + OffPolicyExperimentBuilder, + OnPolicyExperimentBuilder, PPOExperimentBuilder, REDQExperimentBuilder, + ReinforceExperimentBuilder, SACExperimentBuilder, TD3ExperimentBuilder, TRPOExperimentBuilder, ) +def create_training_config( + builder_cls: type[ExperimentBuilder], + num_epochs: int = 1, + epoch_num_steps: int = 100, + num_train_envs: int = 2, + num_test_envs: int = 2, +) -> OffPolicyTrainingConfig | OnPolicyTrainingConfig: + if issubclass(builder_cls, OffPolicyExperimentBuilder): + return OffPolicyTrainingConfig( + max_epochs=num_epochs, + epoch_num_steps=epoch_num_steps, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, + ) + elif issubclass(builder_cls, OnPolicyExperimentBuilder): + return OnPolicyTrainingConfig( + max_epochs=num_epochs, + epoch_num_steps=epoch_num_steps, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, + ) + else: + raise ValueError + + @pytest.mark.parametrize( "builder_cls", [ @@ -31,14 +61,15 @@ # NPGExperimentBuilder, # TODO test fails non-deterministically REDQExperimentBuilder, TRPOExperimentBuilder, - PGExperimentBuilder, + ReinforceExperimentBuilder, ], ) def 
test_experiment_builder_continuous_default_params(builder_cls: type[ExperimentBuilder]) -> None: env_factory = ContinuousTestEnvFactory() - sampling_config = SamplingConfig( + training_config = create_training_config( + builder_cls, num_epochs=1, - step_per_epoch=100, + epoch_num_steps=100, num_train_envs=2, num_test_envs=2, ) @@ -46,7 +77,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime builder = builder_cls( experiment_config=experiment_config, env_factory=env_factory, - sampling_config=sampling_config, + training_config=training_config, ) experiment = builder.build() experiment.run(run_name="test") @@ -56,7 +87,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime @pytest.mark.parametrize( "builder_cls", [ - PGExperimentBuilder, + ReinforceExperimentBuilder, PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder, @@ -66,16 +97,17 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime ) def test_experiment_builder_discrete_default_params(builder_cls: type[ExperimentBuilder]) -> None: env_factory = DiscreteTestEnvFactory() - sampling_config = SamplingConfig( + training_config = create_training_config( + builder_cls, num_epochs=1, - step_per_epoch=100, + epoch_num_steps=100, num_train_envs=2, num_test_envs=2, ) builder = builder_cls( experiment_config=ExperimentConfig(persistence_enabled=False), env_factory=env_factory, - sampling_config=sampling_config, + training_config=training_config, ) experiment = builder.build() experiment.run(run_name="test") diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index c108e7c0f..e343fd586 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -6,6 +6,9 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DQN, Algorithm, ICMOffPolicyWrapper +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -13,10 +16,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import DQNPolicy, ICMPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.dqn import DQNTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -26,26 +26,26 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) 
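The high-level API hunk above replaces `SamplingConfig` with flavour-specific training configs passed as `training_config`. Below is a minimal sketch of the v2 builder usage, assuming an `env_factory` such as the `DiscreteTestEnvFactory` used in the test; the class and parameter names are taken from this diff, while the concrete values are illustrative.

```python
from tianshou.highlevel.config import OffPolicyTrainingConfig
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig

# v1 (removed): SamplingConfig(num_epochs=..., step_per_epoch=...) passed as sampling_config=
# v2: a flavour-specific training config passed as training_config=
training_config = OffPolicyTrainingConfig(
    max_epochs=1,
    epoch_num_steps=100,
    num_train_envs=2,
    num_test_envs=2,
)
experiment = DQNExperimentBuilder(
    experiment_config=ExperimentConfig(persistence_enabled=False),
    env_factory=env_factory,  # assumed, e.g. DiscreteTestEnvFactory() as in the test above
    training_config=training_config,
).build()
experiment.run(run_name="example")
```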
+ parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=20) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( @@ -54,19 +54,19 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument( - "--lr-scale", + "--lr_scale", type=float, default=1.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--reward-scale", + "--reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--forward-loss-weight", + "--forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", @@ -88,60 +88,63 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # Q_param = V_param = {"hidden_sizes": [128]} # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, # dueling=(Q_param, V_param), ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy[DQNTrainingStats] = DQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = DiscreteQLearningPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: DQN = DQN( + policy=policy, + optim=optim, + gamma=args.gamma, + 
n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) + + # ICM wrapper feature_dim = args.hidden_sizes[-1] obs_dim = space_info.observation_info.obs_dim feature_net = MLP( - obs_dim, + input_dim=obs_dim, output_dim=feature_dim, hidden_sizes=args.hidden_sizes[:-1], - device=args.device, ) action_dim = space_info.action_info.action_dim icm_net = IntrinsicCuriosityModule( - feature_net, - feature_dim, - action_dim, + feature_net=feature_net, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=args.hidden_sizes[-1:], - device=args.device, ).to(args.device) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy: ICMPolicy = ICMPolicy( - policy=policy, + icm_optim = AdamOptimizerFactory(lr=args.lr) + icm_algorithm = ICMOffPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.lr_scale, reward_scale=args.reward_scale, forward_loss_weight=args.forward_loss_weight, ) + # buffer buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: @@ -153,18 +156,21 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + train_collector = Collector[CollectStats]( + icm_algorithm, train_envs, buf, exploration_noise=True + ) + test_collector = Collector[CollectStats](icm_algorithm, test_envs, exploration_noise=True) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) + # log - log_path = os.path.join(args.logdir, args.task, "dqn_icm") + log_path = str(os.path.join(args.logdir, args.task, "dqn_icm")) writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -173,31 +179,29 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) + policy.set_eps_training(0.1 * args.eps_train) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = icm_algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + 
test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) assert stop_fn(result.best_reward) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 7d3780960..1d99bf05c 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -7,34 +7,40 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import PPO +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import ICMPolicy, PPOPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net -from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule +from tianshou.utils.net.discrete import ( + DiscreteActor, + DiscreteCritic, + IntrinsicCuriosityModule, +) from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=50000) - parser.add_argument("--step-per-collect", type=int, default=2000) - parser.add_argument("--repeat-per-collect", type=int, default=10) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=20) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=50000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2000) + parser.add_argument("--update_step_num_repetitions", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=20) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -43,30 +49,30 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.5) - parser.add_argument("--ent-coef", type=float, default=0.0) - 
parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=0) - parser.add_argument("--recompute-adv", type=int, default=0) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--return_scaling", type=int, default=0) + parser.add_argument("--advantage_normalization", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=0) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) parser.add_argument( - "--lr-scale", + "--lr_scale", type=float, default=1.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--reward-scale", + "--reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--forward-loss-weight", + "--forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", @@ -87,104 +93,113 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, device=args.device).to(args.device) - critic = Critic(net, device=args.device).to(args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) + critic = DiscreteCritic(preprocess_net=net).to(args.device) actor_critic = ActorCritic(actor, critic) + # orthogonal initialization for m in actor_critic.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + + # base algorithm: PPO + optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical - policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( + policy = ProbabilisticActorPolicy( actor=actor, - critic=critic, - optim=optim, dist_fn=dist, action_scaling=isinstance(env.action_space, Box), - discount_factor=args.gamma, + action_space=env.action_space, + deterministic_eval=True, + ) + algorithm: PPO = PPO( + policy=policy, + critic=critic, + optim=optim, + gamma=args.gamma, 
max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, + return_scaling=args.return_scaling, dual_clip=args.dual_clip, value_clip=args.value_clip, - action_space=env.action_space, - deterministic_eval=True, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) + + # ICM wrapper feature_dim = args.hidden_sizes[-1] feature_net = MLP( - space_info.observation_info.obs_dim, + input_dim=space_info.observation_info.obs_dim, output_dim=feature_dim, hidden_sizes=args.hidden_sizes[:-1], - device=args.device, ) action_dim = space_info.action_info.action_dim icm_net = IntrinsicCuriosityModule( - feature_net, - feature_dim, - action_dim, + feature_net=feature_net, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=args.hidden_sizes[-1:], - device=args.device, ).to(args.device) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMPolicy( - policy=policy, + icm_optim = AdamOptimizerFactory(lr=args.lr) + icm_algorithm = ICMOnPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.lr_scale, reward_scale=args.reward_scale, forward_loss_weight=args.forward_loss_weight, ) + # collector train_collector = Collector[CollectStats]( - policy, + icm_algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](icm_algorithm, test_envs) + # log log_path = os.path.join(args.logdir, args.task, "ppo_icm") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: - torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + def save_best_fn(alg: Algorithm) -> None: + torch.save(alg.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = icm_algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_env_steps=args.collection_step_num_env_steps, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=True, + ) + ) assert stop_fn(result.best_reward) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index e79977381..0a3efe8fc 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -6,9 +6,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import PSRL +from tianshou.algorithm.modelbased.psrl import PSRLPolicy from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.policy import PSRLPolicy -from 
tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger try: @@ -20,21 +21,21 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="NChain-v0") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=50000) + parser.add_argument("--buffer_size", type=int, default=50000) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) - parser.add_argument("--episode-per-collect", type=int, default=1) - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=1000) + parser.add_argument("--collection_step_num_episodes", type=int, default=1) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--rew-mean-prior", type=float, default=0.0) - parser.add_argument("--rew-std-prior", type=float, default=1.0) + parser.add_argument("--rew_mean_prior", type=float, default=0.0) + parser.add_argument("--rew_std_prior", type=float, default=1.0) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--eps", type=float, default=0.01) - parser.add_argument("--add-done-loop", action="store_true", default=False) + parser.add_argument("--add_done_loop", action="store_true", default=False) parser.add_argument( "--logger", type=str, @@ -49,43 +50,51 @@ def get_args() -> argparse.Namespace: reason="EnvPool is not installed. If on linux, please install it (e.g. 
as poetry extra)", ) def test_psrl(args: argparse.Namespace = get_args()) -> None: - # if you want to use python vector env, please refer to other test scripts - train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) - test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) + train_envs = env = envpool.make_gymnasium( + args.task, num_envs=args.num_train_envs, seed=args.seed + ) + test_envs = envpool.make_gymnasium(args.task, num_envs=args.num_test_envs, seed=args.seed) if args.reward_threshold is None: default_reward_threshold = {"NChain-v0": 3400} args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) print("reward threshold:", args.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # model n_action = args.action_shape n_state = args.state_shape trans_count_prior = np.ones((n_state, n_action, n_state)) rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) - policy: PSRLPolicy = PSRLPolicy( + policy = PSRLPolicy( trans_count_prior=trans_count_prior, rew_mean_prior=rew_mean_prior, rew_std_prior=rew_std_prior, action_space=env.action_space, discount_factor=args.gamma, epsilon=args.eps, + ) + algorithm: PSRL = PSRL( + policy=policy, add_done_loop=args.add_done_loop, ) + # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) train_collector.reset() - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) test_collector.reset() + # Logger log_path = os.path.join(args.logdir, args.task, "psrl") writer = SummaryWriter(log_path) @@ -103,19 +112,22 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold train_collector.collect(n_step=args.buffer_size, random=True) - # trainer, test it without logger - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=1, - episode_per_test=args.test_num, - batch_size=0, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - logger=logger, - test_in_train=False, - ).run() + + # train + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=1, + test_step_num_episodes=args.num_test_envs, + batch_size=0, + collection_step_num_episodes=args.collection_step_num_episodes, + collection_step_num_env_steps=None, + stop_fn=stop_fn, + logger=logger, + test_in_train=False, + ) + ) assert result.best_reward >= args.reward_threshold diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index bee9063ea..aed2b9428 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import QRDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from 
tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -14,10 +18,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import QRDQNPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -30,30 +31,30 @@ def expert_file_name() -> str: def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) - parser.add_argument("--save-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--save_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, @@ -79,9 +80,9 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: ) # train_envs = 
gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -92,18 +93,22 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=False, num_atoms=args.num_quantiles, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: QRDQNPolicy[QRDQNTrainingStats] = QRDQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = QRDQNPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + algorithm: QRDQN = QRDQN( + policy=policy, + optim=optim, + gamma=args.gamma, num_quantiles=args.num_quantiles, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # buffer @@ -118,18 +123,17 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) train_collector.reset() - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) test_collector.reset() - # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -138,39 +142,37 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - ).run() + policy.set_eps_training(0.1 * args.eps_train) + + # train + result = 
algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + train_fn=train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_step_num_gradient_steps_per_sample=args.update_per_step, + test_in_train=True, + ) + ) assert stop_fn(result.best_reward) # save buffer in pickle format, for imitation learning unittest buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) - policy.set_eps(0.2) - collector = Collector[CollectStats](policy, test_envs, buf, exploration_noise=True) + policy.set_eps_inference(0.2) + collector = Collector[CollectStats](algorithm, test_envs, buf, exploration_noise=True) collector.reset() collector_stats = collector.collect(n_step=args.buffer_size) if args.save_buffer_name.endswith(".hdf5"): diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 614ee388f..8e7b56b4f 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -7,15 +7,16 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import SAC +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import SACPolicy -from tianshou.policy.base import BasePolicy -from tianshou.policy.modelfree.sac import SACTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -26,19 +27,19 @@ def expert_file_name() -> str: def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=7) - parser.add_argument("--step-per-epoch", type=int, default=8000) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.125) + 
parser.add_argument("--epoch_num_steps", type=int, default=8000) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--gamma", default=0.99) @@ -48,7 +49,7 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, @@ -57,10 +58,10 @@ def get_args() -> argparse.Namespace: ) # sac: parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", type=int, default=1) - parser.add_argument("--alpha-lr", type=float, default=3e-4) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--save-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--auto_alpha", type=int, default=1) + parser.add_argument("--alpha_lr", type=float, default=3e-4) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--save_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] @@ -72,7 +73,6 @@ def gather_data() -> VectorReplayBuffer: space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape - args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} @@ -82,82 +82,85 @@ def gather_data() -> VectorReplayBuffer: ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( - net, - args.action_shape, - device=args.device, + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProbabilistic( + preprocess_net=net, + action_shape=args.action_shape, unbounded=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: 
target_entropy = -action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) - policy: SACPolicy[SACTrainingStats] = SACPolicy( + policy = SACPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm: SAC[SACTrainingStats] = SAC( + policy=policy, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # collector buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - save_best_fn=save_best_fn, - stop_fn=stop_fn, - logger=logger, - ).run() + algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + update_step_num_gradient_steps_per_sample=args.update_per_step, + save_best_fn=save_best_fn, + stop_fn=stop_fn, + logger=logger, + test_in_train=True, + ) + ) train_collector.reset() collector_stats = train_collector.collect(n_step=args.buffer_size) print(collector_stats) diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index c409fdb3f..a33442a49 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -10,33 +10,34 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import BCQ, Algorithm +from tianshou.algorithm.imitation.bcq import BCQPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, BCQPolicy -from tianshou.policy.imitation.bcq import BCQTrainingStats -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net -from tianshou.utils.net.continuous import VAE, Critic, Perturbation +from tianshou.utils.net.continuous 
import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=500) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=500) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) - parser.add_argument("--vae-hidden-sizes", type=int, nargs="*", default=[32, 32]) + parser.add_argument("--vae_hidden_sizes", type=int, nargs="*", default=[32, 32]) # default to 2 * action_dim parser.add_argument("--latent_dim", type=int, default=None) parser.add_argument("--gamma", default=0.99) @@ -50,15 +51,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) - parser.add_argument("--show-progress", action="store_true") + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) + parser.add_argument("--show_progress", action="store_true") return parser.parse_known_args()[0] @@ -90,41 +91,40 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) - # model # perturbation network net_a = MLP( input_dim=args.state_dim + args.action_dim, output_dim=args.action_dim, hidden_sizes=args.hidden_sizes, - device=args.device, ) - actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to( + actor_perturbation = Perturbation( + preprocess_net=net_a, max_action=args.max_action, phi=args.phi + ).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), 
lr=args.critic_lr) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae # output_dim = 0, so the last Module in the encoder is ReLU vae_encoder = MLP( input_dim=args.state_dim + args.action_dim, hidden_sizes=args.vae_hidden_sizes, - device=args.device, ) if not args.latent_dim: args.latent_dim = args.action_dim * 2 @@ -132,42 +132,43 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr input_dim=args.state_dim + args.latent_dim, output_dim=args.action_dim, hidden_sizes=args.vae_hidden_sizes, - device=args.device, ) vae = VAE( - vae_encoder, - vae_decoder, + encoder=vae_encoder, + decoder=vae_decoder, hidden_dim=args.vae_hidden_sizes[-1], latent_dim=args.latent_dim, max_action=args.max_action, - device=args.device, ).to(args.device) - vae_optim = torch.optim.Adam(vae.parameters()) + vae_optim = AdamOptimizerFactory() - policy: BCQPolicy[BCQTrainingStats] = BCQPolicy( - actor_perturbation=actor, - actor_perturbation_optim=actor_optim, + policy = BCQPolicy( + actor_perturbation=actor_perturbation, critic=critic, - critic_optim=critic_optim, vae=vae, - vae_optim=vae_optim, action_space=env.action_space, - device=args.device, + ) + algorithm = BCQ( + policy=policy, + actor_perturbation_optim=actor_optim, + critic_optim=critic_optim, + vae_optim=vae_optim, gamma=args.gamma, tau=args.tau, lmbda=args.lmbda, - ) + ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector # buffer has been gathered # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) - # log + test_collector = Collector[CollectStats](algorithm, test_envs) + + # logger t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' log_path = os.path.join(args.logdir, args.task, "bcq", log_file) @@ -175,33 +176,34 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold def watch() -> None: - policy.load_state_dict( + algorithm.load_state_dict( torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) - collector = Collector[CollectStats](policy, env) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) - # trainer - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - stop_fn=stop_fn, - logger=logger, - show_progress=args.show_progress, - ).run() + # train + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + 
stop_fn=stop_fn, + logger=logger, + show_progress=args.show_progress, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index ea8b6ac11..d320b2bfb 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -10,44 +10,45 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import CQL, Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, CQLPolicy -from tianshou.policy.imitation.cql import CQLTrainingStats -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", default=True, action="store_true") - parser.add_argument("--alpha-lr", type=float, default=1e-3) - parser.add_argument("--cql-alpha-lr", type=float, default=1e-3) - parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument("--auto_alpha", default=True, action="store_true") + parser.add_argument("--alpha_lr", type=float, default=1e-3) + parser.add_argument("--cql_alpha_lr", type=float, default=1e-3) + parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=500) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epoch_num_steps", type=int, default=500) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--cql-weight", type=float, default=1.0) - parser.add_argument("--with-lagrange", type=bool, default=True) - parser.add_argument("--lagrange-threshold", type=float, default=10.0) + parser.add_argument("--cql_weight", type=float, default=1.0) + parser.add_argument("--with_lagrange", type=bool, default=True) + parser.add_argument("--lagrange_threshold", type=float, default=10.0) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--eval-freq", type=int, default=1) - 
parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--eval_freq", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument( @@ -55,14 +56,14 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] @@ -96,7 +97,7 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -108,16 +109,14 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) - actor = ActorProb( - net_a, + actor = ContinuousActorProbabilistic( + preprocess_net=net_a, action_shape=args.action_shape, - device=args.device, unbounded=True, conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network net_c = Net( @@ -125,26 +124,28 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: - target_entropy = -np.prod(args.action_shape) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + target_entropy = float(-np.prod(args.action_shape)) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) - policy: CQLPolicy[CQLTrainingStats] = CQLPolicy( + policy = SACPolicy( actor=actor, - actor_optim=actor_optim, - critic=critic, - critic_optim=critic_optim, # CQL seems to perform better without action scaling # TODO: investigate why action_scaling=False, action_space=env.action_space, + ) + algorithm = CQL( + policy=policy, + policy_optim=actor_optim, + critic=critic, + critic_optim=critic_optim, cql_alpha_lr=args.cql_alpha_lr, cql_weight=args.cql_weight, tau=args.tau, @@ -155,18 +156,17 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr lagrange_threshold=args.lagrange_threshold, min_action=args.min_action, max_action=args.max_action, - device=args.device, ) # load a previous policy if args.resume_path: - 
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector # buffer has been gathered # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql' @@ -175,29 +175,29 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - trainer = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - stop_fn=stop_fn, - logger=logger, + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + stop_fn=stop_fn, + logger=logger, + ) ) - stats = trainer.run() if enable_assertions: - assert stop_fn(stats.best_reward) + assert stop_fn(result.best_reward) def test_cql_determinism() -> None: diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index f7d34ba50..5c48ea017 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -9,6 +9,9 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import Algorithm, DiscreteBCQ +from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -16,41 +19,40 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, DiscreteBCQPolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.common import Net +from tianshou.utils.net.discrete import DiscreteActor from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.001) + parser.add_argument("--eps_test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) - 
parser.add_argument("--unlikely-action-threshold", type=float, default=0.6) - parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) + parser.add_argument("--unlikely_action_threshold", type=float, default=0.6) + parser.add_argument("--imitation_logits_penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=2000) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=2000) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume", action="store_true") - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] @@ -70,40 +72,42 @@ def test_discrete_bcq( args.task, env.spec.reward_threshold if env.spec else None, ) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) + # model - net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) - policy_net = Actor( - net, - args.action_shape, + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) + policy_net = DiscreteActor( + preprocess_net=net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ).to(args.device) - imitation_net = Actor( - net, - args.action_shape, + imitation_net = DiscreteActor( + preprocess_net=net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ).to(args.device) - actor_critic = ActorCritic(policy_net, imitation_net) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - - policy: DiscreteBCQPolicy = DiscreteBCQPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = DiscreteBCQPolicy( model=policy_net, imitator=imitation_net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, - estimation_step=args.n_step, - target_update_freq=args.target_update_freq, - eval_eps=args.eps_test, unlikely_action_threshold=args.unlikely_action_threshold, + eps_inference=args.eps_test, + ) + algorithm: DiscreteBCQ = DiscreteBCQ( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, + target_update_freq=args.target_update_freq, imitation_logits_penalty=args.imitation_logits_penalty, ) + # buffer buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): @@ 
-116,13 +120,14 @@ def test_discrete_bcq( buffer = gather_data() # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) + # logger log_path = os.path.join(args.logdir, args.task, "discrete_bcq") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -134,10 +139,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": policy.state_dict(), - "optim": optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) return ckpt_path @@ -148,26 +150,27 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) - optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + # train + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 373b9a074..12c0af017 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -9,6 +9,9 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import Algorithm, DiscreteCQL +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -16,8 +19,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, DiscreteCQLPolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -26,23 +28,23 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, 
default=0.001) + parser.add_argument("--eps_test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) - parser.add_argument("--min-q-weight", type=float, default=10.0) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) + parser.add_argument("--min_q_weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64]) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=1000) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, @@ -67,32 +69,37 @@ def test_discrete_cql( args.task, env.spec.reward_threshold if env.spec else None, ) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) + # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=False, num_atoms=args.num_quantiles, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) - policy: DiscreteCQLPolicy = DiscreteCQLPolicy( + policy = QRDQNPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, + ) + algorithm: DiscreteCQL = DiscreteCQL( + policy=policy, + optim=optim, + gamma=args.gamma, num_quantiles=args.num_quantiles, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, min_q_weight=args.min_q_weight, ).to(args.device) + # buffer buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): @@ -105,30 +112,32 @@ def test_discrete_cql( buffer = gather_data() # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_cql") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - result = OfflineTrainer( - policy=policy, - 
buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 3593cf206..b547cc3d5 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -9,6 +9,9 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import Algorithm, DiscreteCRR +from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -16,31 +19,30 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, DiscreteCRRPolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=7e-4) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=1000) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, @@ -65,38 +67,40 @@ def test_discrete_crr( args.task, env.spec.reward_threshold if env.spec else None, ) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: 
gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) - # model - net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) - actor = Actor( + + # model and algorithm + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) + actor = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax_output=False, ) action_dim = space_info.action_info.action_dim - critic = Critic( - net, + critic = DiscreteCritic( + preprocess_net=net, hidden_sizes=args.hidden_sizes, last_size=action_dim, - device=args.device, ) - actor_critic = ActorCritic(actor, critic) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - - policy: DiscreteCRRPolicy = DiscreteCRRPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = DiscreteActorPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: DiscreteCRR = DiscreteCRR( + policy=policy, critic=critic, optim=optim, - action_space=env.action_space, - discount_factor=args.gamma, + gamma=args.gamma, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): @@ -109,30 +113,32 @@ def test_discrete_crr( buffer = gather_data() # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_crr") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 98a6b6c48..a54fb5d3d 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -10,34 +10,36 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import GAIL, Algorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, GAILPolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import 
ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--disc-lr", type=float, default=5e-4) + parser.add_argument("--disc_lr", type=float, default=5e-4) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=150000) - parser.add_argument("--episode-per-collect", type=int, default=16) - parser.add_argument("--repeat-per-collect", type=int, default=2) - parser.add_argument("--disc-update-num", type=int, default=2) - parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=150000) + parser.add_argument("--collection_step_num_episodes", type=int, default=16) + parser.add_argument("--update_step_num_repetitions", type=int, default=2) + parser.add_argument("--disc_update_num", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -46,19 +48,19 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--return_scaling", type=int, default=1) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) + parser.add_argument("--recompute_adv", type=int, default=0) 
parser.add_argument("--resume", action="store_true") - parser.add_argument("--save-interval", type=int, default=4) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--save_interval", type=int, default=4) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] @@ -82,29 +84,24 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to( args.device, ) - critic = Critic( - Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), - device=args.device, + critic = ContinuousCritic( + preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), ).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization @@ -112,25 +109,23 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # discriminator - disc_net = Critic( - Net( + disc_net = ContinuousCritic( + preprocess_net=Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, activation=torch.nn.Tanh, - device=args.device, concat=True, ), - device=args.device, ).to(args.device) for m in disc_net.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) - disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr) + disc_optim = AdamOptimizerFactory(lr=args.disc_lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -138,41 +133,44 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: GAILPolicy = GAILPolicy( + policy = ProbabilisticActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + ) + algorithm: GAIL = GAIL( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, expert_buffer=buffer, disc_net=disc_net, disc_optim=disc_optim, disc_update_num=args.disc_update_num, - discount_factor=args.gamma, + gamma=args.gamma, max_grad_norm=args.max_grad_norm, 
eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - advantage_normalization=args.norm_adv, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, dual_clip=args.dual_clip, value_clip=args.value_clip, gae_lambda=args.gae_lambda, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "gail") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -184,10 +182,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": policy.state_dict(), - "optim": optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) return ckpt_path @@ -198,29 +193,31 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) - optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + result = algorithm.run_training( + OnPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_episodes=args.collection_step_num_episodes, + collection_step_num_env_steps=None, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + test_in_train=True, + ) + ) if enable_assertions: assert stop_fn(result.best_reward) diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 40b529c68..dfe4d6b70 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -10,40 +10,42 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import TD3BC +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, 
VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import TD3BCPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OfflineTrainer +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=500) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epoch_num_steps", type=int, default=500) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--alpha", type=float, default=2.5) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--policy-noise", type=float, default=0.2) - parser.add_argument("--noise-clip", type=float, default=0.5) - parser.add_argument("--update-actor-freq", type=int, default=2) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--policy_noise", type=float, default=0.2) + parser.add_argument("--noise_clip", type=float, default=0.5) + parser.add_argument("--update_actor_freq", type=int, default=2) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--eval-freq", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--eval_freq", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument( @@ -51,14 +53,14 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] @@ -86,76 +88,76 @@ def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = args.state_dim = space_info.action_info.action_dim args.action_dim = 
space_info.observation_info.obs_dim - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) - # model # actor network net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) - actor = Actor( - net_a, + actor = ContinuousActorDeterministic( + preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - # critic network + # critic networks net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy: TD3BCPolicy = TD3BCPolicy( + # policy and algorithm + policy = ContinuousDeterministicPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + ) + algorithm: TD3BC = TD3BC( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, alpha=args.alpha, - estimation_step=args.n_step, - action_space=env.action_space, + n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector # buffer has been gathered # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) - # log + test_collector = Collector[CollectStats](algorithm, test_envs) + + # logger t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3_bc' log_path = os.path.join(args.logdir, args.task, "td3_bc", log_file) @@ -163,29 +165,29 @@ def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 
args.reward_threshold - # trainer - trainer = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - stop_fn=stop_fn, - logger=logger, + # train + result = algorithm.run_training( + OfflineTrainerParams( + buffer=buffer, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + stop_fn=stop_fn, + logger=logger, + ) ) - stats = trainer.run() if enable_assertions: - assert stop_fn(stats.best_reward) + assert stop_fn(result.best_reward) def test_td3_bc_determinism() -> None: diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 0cf269d4f..6b6f6edcd 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -8,11 +8,14 @@ from pettingzoo.butterfly import pistonball_v6 from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DQN, Algorithm, MultiAgentOffPolicyAlgorithm +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -20,9 +23,9 @@ def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=2000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=2000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument( "--gamma", @@ -31,21 +34,21 @@ def get_parser() -> argparse.ArgumentParser: help="a smaller gamma favors earlier win", ) parser.add_argument( - "--n-pistons", + "--n_pistons", type=int, default=3, help="Number of pistons(agents) in the env", ) - parser.add_argument("--n-step", type=int, default=100) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=100) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=3) - parser.add_argument("--step-per-epoch", type=int, default=500) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=100) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=500) + parser.add_argument("--collection_step_num_env_steps", type=int, 
default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=100) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) @@ -74,9 +77,9 @@ def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: def get_agents( args: argparse.Namespace = get_args(), - agents: list[BasePolicy] | None = None, + agents: list[OffPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[BasePolicy, list[torch.optim.Optimizer] | None, list]: +) -> tuple[Algorithm, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] @@ -85,8 +88,11 @@ def get_agents( ) args.state_shape = observation_space.shape or int(observation_space.n) args.action_shape = env.action_space.shape or int(env.action_space.n) - if agents is None: - agents = [] + + if agents is not None: + algorithms = agents + else: + algorithms = [] optims = [] for _ in range(args.n_pistons): # model @@ -94,101 +100,96 @@ def get_agents( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent: DQNPolicy = DQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + policy = DiscreteQLearningPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, - estimation_step=args.n_step, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + agent: DQN = DQN( + policy=policy, + optim=optim, + gamma=args.gamma, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) - agents.append(agent) + algorithms.append(agent) optims.append(optim) - policy = MultiAgentPolicyManager(policies=agents, env=env) - return policy, optims, env.agents + ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=algorithms, env=env) + return ma_algorithm, optims, env.agents def train_agent( args: argparse.Namespace = get_args(), - agents: list[BasePolicy] | None = None, + agents: list[OffPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[InfoStats, BasePolicy]: - train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) - test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) +) -> tuple[InfoStats, Algorithm]: + train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) - policy, optim, agents = get_agents(args, agents=agents, optims=optims) + marl_algorithm, optim, agents = get_agents(args, agents=agents, optims=optims) # collector train_collector = Collector[CollectStats]( - policy, + marl_algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + 
train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: pass def stop_fn(mean_rewards: float) -> bool: return False - def train_fn(epoch: int, env_step: int) -> None: - [agent.set_eps(args.eps_train) for agent in policy.policies.values()] - - def test_fn(epoch: int, env_step: int | None) -> None: - [agent.set_eps(args.eps_test) for agent in policy.policies.values()] - def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - update_per_step=args.update_per_step, - logger=logger, - test_in_train=False, - reward_metric=reward_metric, - ).run() - - return result, policy - - -def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None) -> None: + result = marl_algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + update_step_num_gradient_steps_per_sample=args.update_per_step, + logger=logger, + test_in_train=False, + multi_agent_return_reduction=reward_metric, + ) + ) + return result, marl_algorithm + + +def watch(args: argparse.Namespace = get_args(), policy: Algorithm | None = None) -> None: env = DummyVectorEnv([get_env]) if not policy: warnings.warn( "watching random agents, as loading pre-trained policies is currently not supported", ) policy, _, _ = get_agents(args) - [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector[CollectStats](policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) result.pprint_asdict() diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 7beb92fde..5f0fb5839 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -11,17 +11,22 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import PPO, Algorithm +from tianshou.algorithm.algorithm_base import OnPolicyAlgorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.multiagent.marl import MultiAgentOnPolicyAlgorithm +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from 
tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.common import ModuleWithVectorOutput +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic -class DQN(nn.Module): +class DQNet(ModuleWithVectorOutput): """Reference: Human-level control through deep reinforcement learning. For advanced usage (how to customize the network), please refer to @@ -35,12 +40,7 @@ def __init__( w: int, device: str | int | torch.device = "cpu", ) -> None: - super().__init__() - self.device = device - self.c = c - self.h = h - self.w = w - self.net = nn.Sequential( + net = nn.Sequential( nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=4, stride=2), @@ -50,7 +50,13 @@ def __init__( nn.Flatten(), ) with torch.no_grad(): - self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]) + output_dim = np.prod(net(torch.zeros(1, c, h, w)).shape[1:]) + super().__init__(int(output_dim)) + self.device = device + self.c = c + self.h = h + self.w = w + self.net = net def forward( self, @@ -68,9 +74,9 @@ def forward( def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=2000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=2000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument( "--gamma", @@ -79,23 +85,23 @@ def get_parser() -> argparse.ArgumentParser: help="a smaller gamma favors earlier win", ) parser.add_argument( - "--n-pistons", + "--n_pistons", type=int, default=3, help="Number of pistons(agents) in the env", ) - parser.add_argument("--n-step", type=int, default=100) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=100) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=500) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--episode-per-collect", type=int, default=16) - parser.add_argument("--repeat-per-collect", type=int, default=2) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=500) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--collection_step_num_episodes", type=int, default=16) + parser.add_argument("--update_step_num_repetitions", type=int, default=2) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument( @@ 
-110,18 +116,18 @@ def get_parser() -> argparse.ArgumentParser: default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--return_scaling", type=int, default=1) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) + parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--resume", action="store_true") - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, default=4) parser.add_argument("--render", type=float, default=0.0) return parser @@ -138,9 +144,9 @@ def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: def get_agents( args: argparse.Namespace = get_args(), - agents: list[BasePolicy] | None = None, + agents: list[OnPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[BasePolicy, list[torch.optim.Optimizer] | None, list]: +) -> tuple[Algorithm, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] @@ -151,104 +157,108 @@ def get_agents( args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] - if agents is None: - agents = [] + if agents is not None: + algorithms = agents + else: + algorithms = [] optims = [] for _ in range(args.n_pistons): # model - net = DQN( + net = DQNet( observation_space.shape[2], observation_space.shape[1], observation_space.shape[0], device=args.device, ).to(args.device) - actor = ActorProb( - net, - args.action_shape, + actor = ContinuousActorProbabilistic( + preprocess_net=net, + action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) - net2 = DQN( + net2 = DQNet( observation_space.shape[2], observation_space.shape[1], observation_space.shape[0], device=args.device, ).to(args.device) - critic = Critic(net2, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net2).to(args.device) for m in set(actor.modules()).union(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - agent: PPOPolicy = PPOPolicy( + policy = 
ProbabilisticActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + action_scaling=True, + action_bound_method="clip", + ) + algorithm: PPO = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, - discount_factor=args.gamma, + gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - advantage_normalization=args.norm_adv, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, # dual_clip=args.dual_clip, # dual clip cause monotonically increasing log_std :) value_clip=args.value_clip, gae_lambda=args.gae_lambda, - action_space=env.action_space, ) - agents.append(agent) + algorithms.append(algorithm) optims.append(optim) - policy = MultiAgentPolicyManager( - policies=agents, + ma_algorithm = MultiAgentOnPolicyAlgorithm( + algorithms=algorithms, env=env, - action_scaling=True, - action_bound_method="clip", ) - return policy, optims, env.agents + return ma_algorithm, optims, env.agents def train_agent( args: argparse.Namespace = get_args(), - agents: list[BasePolicy] | None = None, + agents: list[OnPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[InfoStats, BasePolicy]: - train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) - test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) +) -> tuple[InfoStats, Algorithm]: + train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) - policy, optim, agents = get_agents(args, agents=agents, optims=optims) + marl_algorithm, optim, agents = get_agents(args, agents=agents, optims=optims) # collector train_collector = Collector[CollectStats]( - policy, + marl_algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=False, # True ) - test_collector = Collector[CollectStats](policy, test_envs) - # train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) + test_collector = Collector[CollectStats](marl_algorithm, test_envs) + # train_collector.collect(n_step=args.batch_size * args.num_train_envs, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: pass def stop_fn(mean_rewards: float) -> bool: @@ -257,27 +267,30 @@ def stop_fn(mean_rewards: float) -> bool: def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - ).run() - - return result, policy - - -def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None) -> None: + # train + result = marl_algorithm.run_training( + OnPolicyTrainerParams( 
+ train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + collection_step_num_episodes=args.collection_step_num_episodes, + collection_step_num_env_steps=None, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + test_in_train=True, + ) + ) + + return result, marl_algorithm + + +def watch(args: argparse.Namespace = get_args(), policy: Algorithm | None = None) -> None: env = DummyVectorEnv([get_env]) if not policy: warnings.warn( diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 9e74c003e..de9cdeccc 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -9,17 +9,20 @@ from pettingzoo.classic import tictactoe_v3 from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import ( + DQN, + Algorithm, + MARLRandomDiscreteMaskedOffPolicyAlgorithm, + MultiAgentOffPolicyAlgorithm, +) +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import ( - BasePolicy, - DQNPolicy, - MARLRandomPolicy, - MultiAgentPolicyManager, -) -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -31,9 +34,9 @@ def get_env(render_mode: str | None = None) -> PettingZooEnv: def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument( "--gamma", @@ -41,20 +44,20 @@ def get_parser() -> argparse.ArgumentParser: default=0.9, help="a smaller gamma favors earlier win", ) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=50) - parser.add_argument("--step-per-epoch", type=int, default=1000) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=1000) + parser.add_argument("--collection_step_num_env_steps", type=int, 
default=10) + parser.add_argument("--update_per_step", type=float, default=0.1) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.1) parser.add_argument( - "--win-rate", + "--win_rate", type=float, default=0.6, help="the expected winning rate: Optimal policy can get 0.7", @@ -66,19 +69,19 @@ def get_parser() -> argparse.ArgumentParser: help="no training, watch the play of pre-trained models", ) parser.add_argument( - "--agent-id", + "--agent_id", type=int, default=2, help="the learned agent plays as the agent_id-th player. Choices are 1 and 2.", ) parser.add_argument( - "--resume-path", + "--resume_path", type=str, default="", help="the path of agent pth file for resuming from a pre-trained agent", ) parser.add_argument( - "--opponent-path", + "--opponent_path", type=str, default="", help="the path of opponent agent pth file for resuming from a pre-trained agent", @@ -98,10 +101,10 @@ def get_args() -> argparse.Namespace: def get_agents( args: argparse.Namespace = get_args(), - agent_learn: BasePolicy | None = None, - agent_opponent: BasePolicy | None = None, - optim: torch.optim.Optimizer | None = None, -) -> tuple[BasePolicy, torch.optim.Optimizer | None, list]: + agent_learn: OffPolicyAlgorithm | None = None, + agent_opponent: OffPolicyAlgorithm | None = None, + optim: OptimizerFactory | None = None, +) -> tuple[MultiAgentOffPolicyAlgorithm, torch.optim.Optimizer | None, list]: env = get_env() observation_space = ( env.observation_space.spaces["observation"] @@ -116,16 +119,20 @@ def get_agents( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ).to(args.device) if optim is None: - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent_learn = DQNPolicy( + optim = AdamOptimizerFactory(lr=args.lr) + algorithm = DiscreteQLearningPolicy( model=net, - optim=optim, action_space=env.action_space, - estimation_step=args.n_step, - discount_factor=args.gamma, + eps_training=args.eps_train, + eps_inference=args.eps_test, + ) + agent_learn = DQN( + policy=algorithm, + optim=optim, + n_step_return_horizon=args.n_step, + gamma=args.gamma, target_update_freq=args.target_update_freq, ) if args.resume_path: @@ -136,31 +143,33 @@ def get_agents( agent_opponent = deepcopy(agent_learn) agent_opponent.load_state_dict(torch.load(args.opponent_path)) else: - agent_opponent = MARLRandomPolicy(action_space=env.action_space) + agent_opponent = MARLRandomDiscreteMaskedOffPolicyAlgorithm( + action_space=env.action_space + ) if args.agent_id == 1: agents = [agent_learn, agent_opponent] else: agents = [agent_opponent, agent_learn] - policy = MultiAgentPolicyManager(policies=agents, env=env) - return policy, optim, env.agents + ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env) + return ma_algorithm, optim, env.agents def train_agent( args: argparse.Namespace = get_args(), - agent_learn: BasePolicy | None = None, - agent_opponent: BasePolicy | None = None, - optim: torch.optim.Optimizer | None = None, -) -> tuple[InfoStats, BasePolicy]: - train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) - test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + 
agent_learn: OffPolicyAlgorithm | None = None, + agent_opponent: OffPolicyAlgorithm | None = None, + optim: OptimizerFactory | None = None, +) -> tuple[InfoStats, OffPolicyAlgorithm]: + train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) - policy, optim, agents = get_agents( + marl_algorithm, optim, agents = get_agents( args, agent_learn=agent_learn, agent_opponent=agent_opponent, @@ -169,71 +178,64 @@ def train_agent( # collector train_collector = Collector[CollectStats]( - policy, + marl_algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + test_collector = Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn") writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + player_agent_id = agents[args.agent_id - 1] + + def save_best_fn(policy: Algorithm) -> None: if hasattr(args, "model_save_path"): model_save_path = args.model_save_path else: model_save_path = os.path.join(args.logdir, "tic_tac_toe", "dqn", "policy.pth") - torch.save(policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path) + torch.save(policy.get_algorithm(player_agent_id).state_dict(), model_save_path) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.win_rate - def train_fn(epoch: int, env_step: int) -> None: - policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) - def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, args.agent_id - 1] # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - update_per_step=args.update_per_step, - logger=logger, - test_in_train=False, - reward_metric=reward_metric, - ).run() - - return result, policy.policies[agents[args.agent_id - 1]] + result = marl_algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=args.epoch, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, + test_step_num_episodes=args.num_test_envs, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + update_step_num_gradient_steps_per_sample=args.update_per_step, + logger=logger, + test_in_train=False, + multi_agent_return_reduction=reward_metric, + ) + ) + + return result, marl_algorithm.get_algorithm(player_agent_id) def watch( args: argparse.Namespace = get_args(), - agent_learn: BasePolicy | None = None, - agent_opponent: BasePolicy | None = None, + 
agent_learn: OffPolicyAlgorithm | None = None, + agent_opponent: OffPolicyAlgorithm | None = None, ) -> None: env = DummyVectorEnv([partial(get_env, render_mode="human")]) policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) - policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) collector = Collector[CollectStats](policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render, reset_before_collect=True) result.pprint_asdict() diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 9a3d557f9..3c719b1eb 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,12 +1,27 @@ -from tianshou import data, env, exploration, policy, trainer, utils +# isort: skip_file +# NOTE: Import order is important to avoid circular import errors! +from tianshou import data, env, exploration, algorithm, trainer, utils + +__version__ = "2.0.0b1" + + +def _register_log_config_callback() -> None: + from sensai.util import logging + + def configure() -> None: + logging.getLogger("numba").setLevel(logging.INFO) + + logging.set_configure_callback(configure) + + +_register_log_config_callback() -__version__ = "1.2.0" __all__ = [ "env", "data", "utils", - "policy", + "algorithm", "trainer", "exploration", ] diff --git a/tianshou/algorithm/__init__.py b/tianshou/algorithm/__init__.py new file mode 100644 index 000000000..6ce54f60b --- /dev/null +++ b/tianshou/algorithm/__init__.py @@ -0,0 +1,35 @@ +"""Algorithm package.""" +# isort:skip_file + +from tianshou.algorithm.algorithm_base import Algorithm, TrainingStats +from tianshou.algorithm.modelfree.reinforce import Reinforce +from tianshou.algorithm.modelfree.dqn import DQN +from tianshou.algorithm.modelfree.ddpg import DDPG + +from tianshou.algorithm.random import MARLRandomDiscreteMaskedOffPolicyAlgorithm +from tianshou.algorithm.modelfree.bdqn import BDQN +from tianshou.algorithm.modelfree.c51 import C51 +from tianshou.algorithm.modelfree.rainbow import RainbowDQN +from tianshou.algorithm.modelfree.qrdqn import QRDQN +from tianshou.algorithm.modelfree.iqn import IQN +from tianshou.algorithm.modelfree.fqf import FQF +from tianshou.algorithm.modelfree.a2c import A2C +from tianshou.algorithm.modelfree.npg import NPG +from tianshou.algorithm.modelfree.ppo import PPO +from tianshou.algorithm.modelfree.trpo import TRPO +from tianshou.algorithm.modelfree.td3 import TD3 +from tianshou.algorithm.modelfree.sac import SAC +from tianshou.algorithm.modelfree.redq import REDQ +from tianshou.algorithm.modelfree.discrete_sac import DiscreteSAC +from tianshou.algorithm.imitation.imitation_base import OffPolicyImitationLearning +from tianshou.algorithm.imitation.bcq import BCQ +from tianshou.algorithm.imitation.cql import CQL +from tianshou.algorithm.imitation.td3_bc import TD3BC +from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQ +from tianshou.algorithm.imitation.discrete_cql import DiscreteCQL +from tianshou.algorithm.imitation.discrete_crr import DiscreteCRR +from tianshou.algorithm.imitation.gail import GAIL +from tianshou.algorithm.modelbased.psrl import PSRL +from tianshou.algorithm.modelbased.icm import ICMOffPolicyWrapper +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm.multiagent.marl import MultiAgentOffPolicyAlgorithm diff --git a/tianshou/policy/base.py b/tianshou/algorithm/algorithm_base.py similarity index 52% rename from tianshou/policy/base.py rename to tianshou/algorithm/algorithm_base.py index 
ced0043d1..50884d646 100644 --- a/tianshou/policy/base.py +++ b/tianshou/algorithm/algorithm_base.py @@ -1,9 +1,9 @@ import logging import time from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Mapping from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -13,11 +13,17 @@ from numpy.typing import ArrayLike from overrides import override from sensai.util.hash import pickle_hash +from sensai.util.helper import mark_used from torch import nn +from torch.nn.modules.module import ( + _IncompatibleKeys, # we have to do this since we override load_state_dict +) +from torch.optim.lr_scheduler import LRScheduler +from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as from tianshou.data.batch import Batch, BatchProtocol, TArr -from tianshou.data.buffer.base import TBuffer +from tianshou.data.buffer.buffer_base import TBuffer from tianshou.data.types import ( ActBatchProtocol, ActStateBatchProtocol, @@ -25,15 +31,33 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.utils import MultipleLRSchedulers from tianshou.utils.determinism import TraceLogger +from tianshou.utils.lagged_network import ( + EvalModeModuleWrapper, + LaggedNetworkCollection, +) from tianshou.utils.net.common import RandomActor from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode +if TYPE_CHECKING: + from tianshou.data.stats import InfoStats + from tianshou.trainer import ( + OfflineTrainer, + OfflineTrainerParams, + OffPolicyTrainer, + OffPolicyTrainerParams, + OnPolicyTrainer, + OnPolicyTrainerParams, + Trainer, + TrainerParams, + ) + + mark_used(TrainerParams) + logger = logging.getLogger(__name__) -TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers +TArrOrActBatch = TypeVar("TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") @dataclass(kw_only=True) @@ -132,74 +156,45 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) -TTrainingStats = TypeVar("TTrainingStats", bound=TrainingStats) - - -class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): - """The base class for any RL policy. - - Tianshou aims to modularize RL algorithms. It comes into several classes of - policies in Tianshou. All policy classes must inherit from - :class:`~tianshou.policy.BasePolicy`. - - A policy class typically has the following parts: - - * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \ - coping the target network and so on; - * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \ - observation; - * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \ - replay buffer (this function can interact with replay buffer); - * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \ - data. - * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \ - from the learning process (e.g., prioritized replay buffer needs to update \ - the weight); - * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \ - i.e., `process_fn -> learn -> post_process_fn`. 
- - Most of the policy needs a neural network to predict the action and an - optimizer to optimize the policy. The rules of self-defined networks are: - - 1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \ - dict or any others), hidden state "state" (for RNN usage), and other information \ - "info" provided by the environment. - 2. Output: some "logits", the next hidden state "state", and the intermediate \ - result during policy forwarding procedure "policy". The "logits" could be a tuple \ - instead of a ``torch.Tensor``. It depends on how the policy process the network \ - output. For example, in PPO, the return of the network might be \ - ``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \ - torch.Tensor or other things, which will be stored in the replay buffer, and can \ - be accessed in the policy update process (e.g. in "policy.learn()", the \ - "batch.policy" is what you need). - - Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can - use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``, - for instance, loading and saving the model: - :: - - torch.save(policy.state_dict(), "policy.pth") - policy.load_state_dict(torch.load("policy.pth")) - - :param action_space: Env's action_space. - :param observation_space: Env's observation space. TODO: appears unused... - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - """ +class Policy(nn.Module, ABC): + """Represents a policy, which provides the fundamental mapping from observations to actions.""" def __init__( self, - *, action_space: gym.Space, - # TODO: does the policy actually need the observation space? observation_space: gym.Space | None = None, action_scaling: bool = False, action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: + ): + """ + :param action_space: the environment's action_space. + :param observation_space: the environment's observation space. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. 
When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. + """ allowed_action_bound_methods = ("clip", "tanh") if ( action_bound_method is not None @@ -214,7 +209,6 @@ def __init__( f"action_scaling can only be True when action_space is Box but " f"got: {action_space}", ) - super().__init__() self.observation_space = observation_space self.action_space = action_space @@ -226,10 +220,8 @@ def __init__( raise ValueError(f"Unsupported action space: {action_space}.") self._action_type = cast(Literal["discrete", "continuous"], action_type) self.agent_id = 0 - self.updating = False self.action_scaling = action_scaling self.action_bound_method = action_bound_method - self.lr_scheduler = lr_scheduler self.is_within_training_step = False """ flag indicating whether we are currently within a training step, @@ -243,120 +235,14 @@ def __init__( This flag should normally remain False and should be set to True only by the algorithm which performs training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer, - the user should ensure that this flag is set correctly before calling update or learn. + the user should ensure that this flag is set correctly. """ self._compile() - def __setstate__(self, state: dict[str, Any]) -> None: - # TODO Use setstate function once merged - if "is_within_training_step" not in state: - state["is_within_training_step"] = False - self.__dict__ = state - @property def action_type(self) -> Literal["discrete", "continuous"]: return self._action_type - def set_agent_id(self, agent_id: int) -> None: - """Set self.agent_id = agent_id, for MARL.""" - self.agent_id = agent_id - - # TODO: needed, since for most of offline algorithm, the algorithm itself doesn't - # have a method to add noise to action. - # So we add the default behavior here. It's a little messy, maybe one can - # find a better way to do this. - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - """Modify the action from policy.forward with exploration noise. - - NOTE: currently does not add any noise! Needs to be overridden by subclasses - to actually do something. - - :param act: a data batch or numpy.ndarray which is the action taken by - policy.forward. - :param batch: the input batch for policy.forward, kept for advanced usage. - :return: action in the same form of input "act" but with added exploration - noise. 
- """ - return act - - def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: - """Softly update the parameters of target module towards the parameters of source module.""" - for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): - tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) - - def compute_action( - self, - obs: ArrayLike, - info: dict[str, Any] | None = None, - state: dict | BatchProtocol | np.ndarray | None = None, - ) -> np.ndarray | int: - """Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info. - - :param obs: observation from the gym's env. - :param info: information given by the gym's env. - :param state: the hidden state of RNN policy, used for recurrent policy. - :return: action as int (for discrete env's) or array (for continuous ones). - """ - obs = np.array(obs) # convert array-like to array (e.g. LazyFrames) - obs = obs[None, :] # add batch dimension - obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info)) - act = self.forward(obs_batch, state=state).act.squeeze() - if isinstance(act, torch.Tensor): - act = act.detach().cpu().numpy() - act = self.map_action(act) - if isinstance(self.action_space, Discrete): - # could be an array of shape (), easier to just convert to int - act = int(act) # type: ignore - return act - - @abstractmethod - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> ActBatchProtocol | ActStateBatchProtocol: # TODO: make consistent typing - """Compute action over the given batch data. - - :return: A :class:`~tianshou.data.Batch` which MUST have the following keys: - - * ``act`` a numpy.ndarray or a torch.Tensor, the action over \ - given batch data. - * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \ - internal state of the policy, ``None`` as default. - - Other keys are user-defined. It depends on the algorithm. For example, - :: - - # some code - return Batch(logits=..., act=..., state=None, dist=...) - - The keyword ``policy`` is reserved and the corresponding data will be - stored into the replay buffer. For instance, - :: - - # some code - return Batch(..., policy=Batch(log_prob=dist.log_prob(act))) - # and in the sampled data batch, you can directly use - # batch.policy.log_prob to get your data. - - .. note:: - - In continuous action space, you should do another step "map_action" to get - the real action: - :: - - act = policy(batch).act # doesn't map to the target action range - act = policy.map_action(act, batch) - """ - @staticmethod def _action_to_numpy(act: TArr) -> np.ndarray: act = to_numpy(act) # NOTE: to_numpy could confusingly also return a Batch @@ -405,7 +291,7 @@ def map_action_inverse( self, act: TArr, ) -> np.ndarray: - """Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`. + """Inverse operation to :meth:`map_action`. This function is called in :meth:`~tianshou.data.Collector.collect` for random initial steps. It scales [action_space.low, action_space.high] to @@ -429,19 +315,232 @@ def map_action_inverse( return act - def process_buffer(self, buffer: TBuffer) -> TBuffer: - """Pre-process the replay buffer, e.g., to add new keys. 
+ def compute_action( + self, + obs: ArrayLike, + info: dict[str, Any] | None = None, + state: dict | BatchProtocol | np.ndarray | None = None, + ) -> np.ndarray | int: + """Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info. + + :param obs: observation from the gym's env. + :param info: information given by the gym's env. + :param state: the hidden state of RNN policy, used for recurrent policy. + :return: action as int (for discrete env's) or array (for continuous ones). + """ + obs = np.array(obs) # convert array-like to array (e.g. LazyFrames) + obs = obs[None, :] # add batch dimension + obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info)) + act = self.forward(obs_batch, state=state).act.squeeze() + if isinstance(act, torch.Tensor): + act = act.detach().cpu().numpy() + act = self.map_action(act) + if isinstance(self.action_space, Discrete): + # could be an array of shape (), easier to just convert to int + act = int(act) # type: ignore + return act + + @staticmethod + def _compile() -> None: + f64 = np.array([0, 1], dtype=np.float64) + f32 = np.array([0, 1], dtype=np.float32) + b = np.array([False, True], dtype=np.bool_) + i64 = np.array([[0, 1]], dtype=np.int64) + _gae(f64, f64, f64, b, 0.1, 0.1) + _gae(f32, f32, f64, b, 0.1, 0.1) + _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) - Used in BaseTrainer initialization method, usually used by offline trainers. + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - Note: this will only be called once, when the trainer is initialized! - If the buffer is empty by then, there will be nothing to process. - This method is meant to be overridden by policies which will be trained - offline at some stage, e.g., in a pre-training step. + def add_exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + """(Optionally) adds noise to an actions computed by the policy's forward method for + exploration purposes. + + NOTE: The base implementation does not add any noise, but subclasses can override + this method to add appropriate mechanisms for adding noise. + + :param act: a data batch or numpy.ndarray containing actions computed by the policy's + forward method. + :param batch: the corresponding input batch that was passed to forward; provided for + advanced usage. + :return: actions in the same format as the input `act` but with added exploration + noise (if implemented - otherwise returns `act` unchanged). """ - return buffer + return act + + +class LaggedNetworkAlgorithmMixin(ABC): + """ + Base class for an algorithm mixin which adds support for lagged networks (target networks) whose weights + are updated periodically. + """ + + def __init__(self) -> None: + self._lagged_networks = LaggedNetworkCollection() + + def _add_lagged_network(self, src: torch.nn.Module) -> EvalModeModuleWrapper: + """ + Adds a lagged network to the collection, returning the target network, which + is forced to eval mode. The target network is a copy of the source network, + which, however, supports only the forward method (hence the type torch.nn.Module); + attribute access is not supported. 
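# A minimal usage sketch for Policy.compute_action (defined above) in a plain environment
# loop; `policy` and `env` are assumed to already exist and are purely illustrative here.
obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    action = policy.compute_action(obs, info)
    obs, reward, terminated, truncated, info = env.step(action)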
+ + :param src: the source network whose parameters are to be copied to the target network + :return: the target network, which supports only the forward method and is forced to eval mode + """ + return self._lagged_networks.add_lagged_network(src) + + @abstractmethod + def _update_lagged_network_weights(self) -> None: + pass + + +class LaggedNetworkFullUpdateAlgorithmMixin(LaggedNetworkAlgorithmMixin): + """ + Algorithm mixin which adds support for lagged networks (target networks) where weights + are updated by fully copying the weights of the source network to the target network. + """ + + def _update_lagged_network_weights(self) -> None: + self._lagged_networks.full_parameter_update() + + +class LaggedNetworkPolyakUpdateAlgorithmMixin(LaggedNetworkAlgorithmMixin): + """ + Algorithm mixin which adds support for lagged networks (target networks) where weights + are updated via Polyak averaging (soft update using a convex combination of the parameters + of the source and target networks with weight `tau` and `1-tau` respectively). + """ + + def __init__(self, tau: float) -> None: + """ + :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being + the fraction with which to retain the target network's parameters. + """ + super().__init__() + self.tau = tau + + def _update_lagged_network_weights(self) -> None: + self._lagged_networks.polyak_parameter_update(self.tau) + + +TPolicy = TypeVar("TPolicy", bound=Policy) +TTrainerParams = TypeVar("TTrainerParams", bound="TrainerParams") - def process_fn( + +class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainerParams], ABC): + """ + The base class for reinforcement learning algorithms in Tianshou. + + An algorithm critically defines how to update the parameters of neural networks + based on a batch data, optionally applying pre-processing and post-processing to the data. + The actual update step is highly algorithm-specific and thus is defined in subclasses. + """ + + _STATE_DICT_KEY_OPTIMIZERS = "_optimizers" + + def __init__( + self, + *, + policy: TPolicy, + ) -> None: + """:param policy: the policy""" + super().__init__() + self.policy: TPolicy = policy + self.lr_schedulers: list[LRScheduler] = [] + self._optimizers: list["Algorithm.Optimizer"] = [] + """ + list of optimizers associated with the algorithm (created via `_create_optimizer`), + whose states will be returned when calling `state_dict` and which will be restored + when calling `load_state_dict` accordingly + """ + + class Optimizer: + """Wrapper for a torch optimizer that optionally performs gradient clipping.""" + + def __init__( + self, + optim: torch.optim.Optimizer, + module: torch.nn.Module, + max_grad_norm: float | None = None, + ) -> None: + """ + :param optim: the optimizer + :param module: the module whose parameters are being affected by `optim` + :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. This prevents exploding gradients and stabilizes training by + limiting the magnitude of parameter updates. + Set to None to disable gradient clipping. + """ + super().__init__() + self._optim = optim + self._module = module + self._max_grad_norm = max_grad_norm + + def step( + self, loss: torch.Tensor, retain_graph: bool | None = None, create_graph: bool = False + ) -> None: + """Performs an optimizer step, optionally applying gradient clipping (if configured at construction). 
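# A plain-PyTorch sketch of the Polyak (soft) update rule applied by
# LaggedNetworkPolyakUpdateAlgorithmMixin above, i.e. target <- tau * source + (1 - tau) * target.
# Illustrative only; the actual update is delegated to LaggedNetworkCollection.
import torch

def polyak_update(target: torch.nn.Module, source: torch.nn.Module, tau: float) -> None:
    with torch.no_grad():
        for tgt_param, src_param in zip(target.parameters(), source.parameters(), strict=True):
            tgt_param.copy_(tau * src_param + (1.0 - tau) * tgt_param)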
+ + :param loss: the loss to backpropagate + :param retain_graph: passed on to `backward` + :param create_graph: passed on to `backward` + """ + self._optim.zero_grad() + loss.backward(retain_graph=retain_graph, create_graph=create_graph) + if self._max_grad_norm is not None: + nn.utils.clip_grad_norm_(self._module.parameters(), max_norm=self._max_grad_norm) + self._optim.step() + + def state_dict(self) -> dict: + """Returns the `state_dict` of the wrapped optimizer.""" + return self._optim.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Loads the given `state_dict` into the wrapped optimizer.""" + self._optim.load_state_dict(state_dict) + + def _create_optimizer( + self, + module: torch.nn.Module, + factory: OptimizerFactory, + max_grad_norm: float | None = None, + ) -> Optimizer: + optimizer, lr_scheduler = factory.create_instances(module) + if lr_scheduler is not None: + self.lr_schedulers.append(lr_scheduler) + optim = self.Optimizer(optimizer, module, max_grad_norm=max_grad_norm) + self._optimizers.append(optim) + return optim + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): # type: ignore + d = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + + # add optimizer states + opt_key = prefix + self._STATE_DICT_KEY_OPTIMIZERS + assert opt_key not in d + d[opt_key] = [o.state_dict() for o in self._optimizers] + + return d + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ) -> _IncompatibleKeys: + # don't override type in annotation since it's is declared as Mapping in nn.Module + state_dict = cast(dict[str, Any], state_dict) + # restore optimizer states + optimizers_state_dict = state_dict.pop(self._STATE_DICT_KEY_OPTIMIZERS) + for optim, optim_state in zip(self._optimizers, optimizers_state_dict, strict=True): + optim.load_state_dict(optim_state) + + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -458,32 +557,9 @@ def process_fn( """ return batch - @abstractmethod - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTrainingStats: - """Update policy with a given batch of data. - - :return: A dataclass object, including the data needed to be logged (e.g., loss). - - .. note:: - - In order to distinguish the collecting state, updating state and - testing state, you can check the policy state by ``self.training`` - and ``self.updating``. Please refer to :ref:`policy_state` for more - detailed explanation. - - .. warning:: - - If you use ``torch.distributions.Normal`` and - ``torch.distributions.Categorical`` to calculate the log_prob, - please be careful about the shape: Categorical distribution gives - "[batch_size]" shape while Normal distribution gives "[batch_size, - 1]" shape. The auto-broadcasting of numerical operation with torch - tensors will amplify this error. - """ - - def post_process_fn( + def _postprocess_batch( self, - batch: BatchProtocol, + batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: @@ -505,53 +581,50 @@ def post_process_fn( "Prioritized replay is disabled for this batch.", ) - def update( + def _update( self, sample_size: int | None, buffer: ReplayBuffer | None, - **kwargs: Any, - ) -> TTrainingStats: - """Update the policy network and replay buffer. 
+ update_with_batch_fn: Callable[[RolloutBatchProtocol], TrainingStats], + ) -> TrainingStats: + """Orchestrates an update step. - It includes 3 function steps: process_fn, learn, and post_process_fn. In - addition, this function will change the value of ``self.updating``: it will be - False before this function and will be True when executing :meth:`update`. - Please refer to :ref:`policy_state` for more detailed explanation. The return - value of learn is augmented with the training time within update, while smoothed - loss values are computed in the trainer. + An update involves three algorithm-specific sub-steps: + * pre-processing of the batch, + * performing the actual network update with the batch, and + * post-processing of the batch. + + The return value is that of the network update call, augmented with the + training time within update. :param sample_size: 0 means it will extract all the data from the buffer, otherwise it will sample a batch with given sample_size. None also means it will extract all the data from the buffer, but it will be shuffled - first. TODO: remove the option for 0? + first. :param buffer: the corresponding replay buffer. + :param update_with_batch_fn: the function to call for the actual update step, + which is algorithm-specific and thus provided by the subclass. - :return: A dataclass object containing the data needed to be logged (e.g., loss) from - ``policy.learn()``. + :return: A dataclass object containing data to be logged (e.g., loss) """ - # TODO: when does this happen? - # -> this happens never in practice as update is either called with a collector buffer or an assert before - - if not self.is_within_training_step: + if not self.policy.is_within_training_step: raise RuntimeError( - f"update() was called outside of a training step as signalled by {self.is_within_training_step=} " + f"update() was called outside of a training step as signalled by {self.policy.is_within_training_step=} " f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned " f"flag yourself. You can to this e.g., by using the contextmanager {policy_within_training_step.__name__}.", ) if buffer is None: - return TrainingStats() # type: ignore[return-value] + return TrainingStats() start_time = time.time() batch, indices = buffer.sample(sample_size) TraceLogger.log(logger, lambda: f"Updating with batch: indices={pickle_hash(indices)}") - self.updating = True - batch = self.process_fn(batch, buffer, indices) + batch = self._preprocess_batch(batch, buffer, indices) with torch_train_mode(self): - training_stat = self.learn(batch, **kwargs) - self.post_process_fn(batch, buffer, indices) - if self.lr_scheduler is not None: - self.lr_scheduler.step() - self.updating = False + training_stat = update_with_batch_fn(batch) + self._postprocess_batch(batch, buffer, indices) + for lr_scheduler in self.lr_schedulers: + lr_scheduler.step() training_stat.train_time = time.time() - start_time return training_stat @@ -606,9 +679,24 @@ def compute_episodic_return( If None, it will be set to an array of 0. :param v_s: the value function of all current states :math:`V(s)`. If None, it is set based upon `v_s_` rolled by 1. - :param gamma: the discount factor, should be in [0, 1]. - :param gae_lambda: the parameter for Generalized Advantage Estimation, - should be in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. 
+ Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :return: two numpy arrays (returns, advantage) with each shape (bsz, ). """ @@ -618,12 +706,12 @@ def compute_episodic_return( v_s_ = np.zeros_like(rew) else: v_s_ = to_numpy(v_s_.flatten()) - v_s_ = v_s_ * BasePolicy.value_mask(buffer, indices) + v_s_ = v_s_ * Algorithm.value_mask(buffer, indices) v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten()) end_flag = np.logical_or(batch.terminated, batch.truncated) end_flag[np.isin(indices, buffer.unfinished_index())] = True - advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + advantage = _gae(v_s, v_s_, rew, end_flag, gamma, gae_lambda) returns = advantage + v_s # normalization varies from each policy, so we don't do it here return returns, advantage @@ -636,9 +724,9 @@ def compute_nstep_return( target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor], gamma: float = 0.99, n_step: int = 1, - rew_norm: bool = False, ) -> BatchWithReturnsProtocol: - r"""Compute n-step return for Q-learning targets. + r""" + Computes the n-step return for Q-learning targets, adds it to the batch and returns the resulting batch. .. math:: G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + @@ -650,17 +738,20 @@ def compute_nstep_return( :param batch: a data batch, which is equal to buffer[indices]. :param buffer: the data buffer. :param indices: tell batch's location in buffer - :param function target_q_fn: a function which compute target Q value - of "obs_next" given data buffer and wanted indices. - :param gamma: the discount factor, should be in [0, 1]. + :param target_q_fn: a function which computes the target Q value + of "obs_next" given data buffer and wanted indices (`n_step` steps ahead). + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. 
+ Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step: the number of estimation step, should be an int greater than 0. - :param rew_norm: normalize the reward to Normal(0, 1). - TODO: passing True is not supported and will cause an error! - :return: a Batch. The result will be stored in batch.returns as a + :return: a Batch. The result will be stored in `batch.returns` as a torch.Tensor with the same shape as target_q_fn's return tensor. """ - assert not rew_norm, "Reward normalization in computing n-step returns is unsupported now." if len(indices) != len(batch): raise ValueError(f"Batch size {len(batch)} and indices size {len(indices)} mismatch.") @@ -702,7 +793,7 @@ def compute_nstep_return( target_q_IA = to_numpy(target_q_torch_IA.reshape(I, -1)) """Represents the Q-values (one for each action) of the transition after N steps.""" - target_q_IA *= BasePolicy.value_mask(buffer, indices_after_n_steps_I).reshape(-1, 1) + target_q_IA *= Algorithm.value_mask(buffer, indices_after_n_steps_I).reshape(-1, 1) end_flag_B = buffer.done.copy() end_flag_B[buffer.unfinished_index()] = True n_step_return_IA = _nstep_return( @@ -717,24 +808,253 @@ def compute_nstep_return( batch.returns = to_torch_as(n_step_return_IA, target_q_torch_IA) - # TODO: this is simply casting to a certain type. Why is this necessary, and why is it happening here? + # TODO: this is simply converting to a certain type. Why is this necessary, and why is it happening here? if hasattr(batch, "weight"): batch.weight = to_torch_as(batch.weight, target_q_torch_IA) return cast(BatchWithReturnsProtocol, batch) - @staticmethod - def _compile() -> None: - f64 = np.array([0, 1], dtype=np.float64) - f32 = np.array([0, 1], dtype=np.float32) - b = np.array([False, True], dtype=np.bool_) - i64 = np.array([[0, 1]], dtype=np.int64) - _gae_return(f64, f64, f64, b, 0.1, 0.1) - _gae_return(f32, f32, f64, b, 0.1, 0.1) - _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) + @abstractmethod + def create_trainer(self, params: TTrainerParams) -> "Trainer": + pass + + def run_training(self, params: TTrainerParams) -> "InfoStats": + trainer = self.create_trainer(params) + return trainer.run() + + +class OnPolicyAlgorithm( + Algorithm[TPolicy, "OnPolicyTrainerParams"], + Generic[TPolicy], + ABC, +): + """Base class for on-policy RL algorithms.""" + + def create_trainer(self, params: "OnPolicyTrainerParams") -> "OnPolicyTrainer": + from tianshou.trainer import OnPolicyTrainer + + return OnPolicyTrainer(self, params) + + @abstractmethod + def _update_with_batch( + self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int + ) -> TrainingStats: + """Performs an update step based on the given batch of data, updating the network + parameters. + + :param batch: the batch of data + :param batch_size: the minibatch size for gradient updates + :param repeat: the number of times to repeat the update over the whole batch + :return: a dataclas object containing statistics on the learning process, including + the data needed to be logged (e.g. loss values). 
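# A hedged sketch of running an update manually, outside of a Trainer, as referenced by the
# error raised in Algorithm._update above (shown here for an on-policy algorithm). The
# policy's is_within_training_step flag must be set, e.g. via the policy_within_training_step
# context manager; the import path, `algorithm` and `buffer` are assumptions made for
# illustration.
from tianshou.utils.torch_utils import policy_within_training_step

with policy_within_training_step(algorithm.policy):
    stats = algorithm.update(buffer=buffer, batch_size=64, repeat=1)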
+ """ + + def update( + self, + buffer: ReplayBuffer, + batch_size: int | None, + repeat: int, + ) -> TrainingStats: + update_with_batch_fn = lambda batch: self._update_with_batch( + batch=batch, batch_size=batch_size, repeat=repeat + ) + return super()._update( + sample_size=0, buffer=buffer, update_with_batch_fn=update_with_batch_fn + ) + + +class OffPolicyAlgorithm( + Algorithm[TPolicy, "OffPolicyTrainerParams"], + Generic[TPolicy], + ABC, +): + """Base class for off-policy RL algorithms.""" + + def create_trainer(self, params: "OffPolicyTrainerParams") -> "OffPolicyTrainer": + from tianshou.trainer import OffPolicyTrainer + + return OffPolicyTrainer(self, params) + + @abstractmethod + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> TrainingStats: + """Performs an update step based on the given batch of data, updating the network + parameters. + + :param batch: the batch of data + :return: a dataclas object containing statistics on the learning process, including + the data needed to be logged (e.g. loss values). + """ + + def update( + self, + buffer: ReplayBuffer, + sample_size: int | None, + ) -> TrainingStats: + update_with_batch_fn = lambda batch: self._update_with_batch(batch) + return super()._update( + sample_size=sample_size, buffer=buffer, update_with_batch_fn=update_with_batch_fn + ) + + +class OfflineAlgorithm( + Algorithm[TPolicy, "OfflineTrainerParams"], + Generic[TPolicy], + ABC, +): + """Base class for offline RL algorithms.""" + def process_buffer(self, buffer: TBuffer) -> TBuffer: + """Pre-process the replay buffer to prepare for offline learning, e.g. to add new keys.""" + return buffer + + def run_training(self, params: "OfflineTrainerParams") -> "InfoStats": + # NOTE: This override is required for correct typing when converting + # an algorithm to an offline algorithm using diamond inheritance + # (e.g. DiscreteCQL) in order to make it match first in the MRO + return super().run_training(params) + + def create_trainer(self, params: "OfflineTrainerParams") -> "OfflineTrainer": + from tianshou.trainer import OfflineTrainer + + return OfflineTrainer(self, params) + + @abstractmethod + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> TrainingStats: + """Performs an update step based on the given batch of data, updating the network + parameters. + + :param batch: the batch of data + :return: a dataclas object containing statistics on the learning process, including + the data needed to be logged (e.g. loss values). + """ + + def update( + self, + buffer: ReplayBuffer, + sample_size: int | None, + ) -> TrainingStats: + update_with_batch_fn = lambda batch: self._update_with_batch(batch) + return super()._update( + sample_size=sample_size, buffer=buffer, update_with_batch_fn=update_with_batch_fn + ) + + +class OnPolicyWrapperAlgorithm( + OnPolicyAlgorithm[TPolicy], + Generic[TPolicy], + ABC, +): + """ + Base class for an on-policy algorithm that is a wrapper around another algorithm. + + It applies the wrapped algorithm's pre-processing and post-processing methods + and chains the update method of the wrapped algorithm with the wrapper's own update method. 
+ """ + + def __init__( + self, + wrapped_algorithm: OnPolicyAlgorithm[TPolicy], + ): + super().__init__(policy=wrapped_algorithm.policy) + self.wrapped_algorithm = wrapped_algorithm + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + """Performs the pre-processing as defined by the wrapped algorithm.""" + return self.wrapped_algorithm._preprocess_batch(batch, buffer, indices) + + def _postprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> None: + """Performs the batch post-processing as defined by the wrapped algorithm.""" + self.wrapped_algorithm._postprocess_batch(batch, buffer, indices) + + def _update_with_batch( + self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int + ) -> TrainingStats: + """Performs the update as defined by the wrapped algorithm, followed by the wrapper's update.""" + original_stats = self.wrapped_algorithm._update_with_batch( + batch, batch_size=batch_size, repeat=repeat + ) + return self._wrapper_update_with_batch(batch, batch_size, repeat, original_stats) + + @abstractmethod + def _wrapper_update_with_batch( + self, + batch: RolloutBatchProtocol, + batch_size: int | None, + repeat: int, + original_stats: TrainingStats, + ) -> TrainingStats: + pass + + +class OffPolicyWrapperAlgorithm( + OffPolicyAlgorithm[TPolicy], + Generic[TPolicy], + ABC, +): + """ + Base class for an off-policy algorithm that is a wrapper around another algorithm. + + It applies the wrapped algorithm's pre-processing and post-processing methods + and chains the update method of the wrapped algorithm with the wrapper's own update method. + """ + + def __init__( + self, + wrapped_algorithm: OffPolicyAlgorithm[TPolicy], + ): + super().__init__(policy=wrapped_algorithm.policy) + self.wrapped_algorithm = wrapped_algorithm + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + """Performs the pre-processing as defined by the wrapped algorithm.""" + return self.wrapped_algorithm._preprocess_batch(batch, buffer, indices) + + def _postprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> None: + """Performs the batch post-processing as defined by the wrapped algorithm.""" + self.wrapped_algorithm._postprocess_batch(batch, buffer, indices) -class RandomActionPolicy(BasePolicy): + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> TrainingStats: + """Performs the update as defined by the wrapped algorithm, followed by the wrapper's update .""" + original_stats = self.wrapped_algorithm._update_with_batch(batch) + return self._wrapper_update_with_batch(batch, original_stats) + + @abstractmethod + def _wrapper_update_with_batch( + self, batch: RolloutBatchProtocol, original_stats: TrainingStats + ) -> TrainingStats: + pass + + +class RandomActionPolicy(Policy): def __init__( self, action_space: gym.Space, @@ -755,13 +1075,9 @@ def forward( act, next_state = self.actor.compute_action_batch(batch.obs), state return cast(ActStateBatchProtocol, Batch(act=act, state=next_state)) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: - return TrainingStats() - -# TODO: rename? 
See docstring @njit -def _gae_return( +def _gae( v_s: np.ndarray, v_s_: np.ndarray, rew: np.ndarray, @@ -771,8 +1087,7 @@ def _gae_return( ) -> np.ndarray: r"""Computes advantages with GAE. - Note: doesn't compute returns but rather advantages. The return - is given by the output of this + v_s. Note that the advantages plus v_s + The return is given by the output of this + v_s. Note that the advantages plus v_s is exactly the same as the TD-lambda target, which is computed by the recursive formula: @@ -795,8 +1110,18 @@ def _gae_return( $V_{t+1}$ :param rew: rewards in an episode, i.e. $r_t$ :param end_flag: boolean array indicating whether the episode is done - :param gamma: discount factor - :param gae_lambda: lambda parameter for GAE, controlling the bias-variance tradeoff + :param gamma: the discount factor in [0, 1] for future rewards. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :return: """ returns = np.zeros(rew.shape) diff --git a/tianshou/policy/imitation/__init__.py b/tianshou/algorithm/imitation/__init__.py similarity index 100% rename from tianshou/policy/imitation/__init__.py rename to tianshou/algorithm/imitation/__init__.py diff --git a/tianshou/algorithm/imitation/bcq.py b/tianshou/algorithm/imitation/bcq.py new file mode 100644 index 000000000..609621ae4 --- /dev/null +++ b/tianshou/algorithm/imitation/bcq.py @@ -0,0 +1,263 @@ +import copy +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.algorithm.algorithm_base import ( + LaggedNetworkPolyakUpdateAlgorithmMixin, + OfflineAlgorithm, + Policy, + TrainingStats, +) +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.utils.net.continuous import VAE + + +@dataclass(kw_only=True) +class BCQTrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + vae_loss: float + + +TBCQTrainingStats = TypeVar("TBCQTrainingStats", bound=BCQTrainingStats) + + +class BCQPolicy(Policy): + def __init__( + self, + *, + actor_perturbation: torch.nn.Module, + action_space: gym.Space, + critic: torch.nn.Module, + vae: VAE, + forward_sampled_times: int = 100, + observation_space: gym.Space | None = None, + action_scaling: bool = False, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + ) -> None: + """ + :param actor_perturbation: the actor perturbation. `(s, a -> perturbed a)` + :param critic: the first critic network. 
+ :param vae: the VAE network, generating actions similar to those in batch. + :param forward_sampled_times: the number of sampled actions in forward function. + The policy samples many actions and takes the action with the max value. + :param observation_space: the environment's observation space + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. + """ + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + ) + self.actor_perturbation = actor_perturbation + self.critic = critic + self.vae = vae + self.forward_sampled_times = forward_sampled_times + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> ActBatchProtocol: + """Compute action over the given batch data.""" + # There is "obs" in the Batch + # obs_group: several groups. Each group has a state. + device = next(self.parameters()).device + obs_group: torch.Tensor = to_torch(batch.obs, device=device) + act_group = [] + for obs_orig in obs_group: + # now obs is (state_dim) + obs = (obs_orig.reshape(1, -1)).repeat(self.forward_sampled_times, 1) + # now obs is (forward_sampled_times, state_dim) + + # decode(obs) generates action and actor perturbs it + act = self.actor_perturbation(obs, self.vae.decode(obs)) + # now action is (forward_sampled_times, action_dim) + q1 = self.critic(obs, act) + # q1 is (forward_sampled_times, 1) + max_indice = q1.argmax(0) + act_group.append(act[max_indice].cpu().data.numpy().flatten()) + act_group = np.array(act_group) + return cast(ActBatchProtocol, Batch(act=act_group)) + + +class BCQ( + OfflineAlgorithm[BCQPolicy], + LaggedNetworkPolyakUpdateAlgorithmMixin, +): + """Implementation of Batch-Constrained Deep Q-learning (BCQ) algorithm. 
arXiv:1812.02900.""" + + def __init__( + self, + *, + policy: BCQPolicy, + actor_perturbation_optim: OptimizerFactory, + critic_optim: OptimizerFactory, + vae_optim: OptimizerFactory, + critic2: torch.nn.Module | None = None, + critic2_optim: OptimizerFactory | None = None, + gamma: float = 0.99, + tau: float = 0.005, + lmbda: float = 0.75, + num_sampled_action: int = 10, + ) -> None: + """ + :param policy: the policy + :param actor_perturbation_optim: the optimizer factory for the policy's actor perturbation network. + :param critic_optim: the optimizer factory for the policy's critic network. + :param critic2: the second critic network; if None, clone the critic from the policy + :param critic2_optim: the optimizer factory for the second critic network; if None, use optimizer factory of first critic + :param vae_optim: the optimizer factory for the VAE network. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param lmbda: param for Clipped Double Q-learning. + :param num_sampled_action: the number of sampled actions in calculating target Q. + The algorithm samples several actions using VAE, and perturbs each action to get the target Q. + """ + # actor is Perturbation! + super().__init__( + policy=policy, + ) + LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) + self.actor_perturbation_target = self._add_lagged_network(self.policy.actor_perturbation) + self.actor_perturbation_optim = self._create_optimizer( + self.policy.actor_perturbation, actor_perturbation_optim + ) + + self.critic_target = self._add_lagged_network(self.policy.critic) + self.critic_optim = self._create_optimizer(self.policy.critic, critic_optim) + + self.critic2 = critic2 or copy.deepcopy(self.policy.critic) + self.critic2_target = self._add_lagged_network(self.critic2) + self.critic2_optim = self._create_optimizer(self.critic2, critic2_optim or critic_optim) + + self.vae_optim = self._create_optimizer(self.policy.vae, vae_optim) + + self.gamma = gamma + self.lmbda = lmbda + self.num_sampled_action = num_sampled_action + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> BCQTrainingStats: + # batch: obs, act, rew, done, obs_next. 
(numpy array) + # (batch_size, state_dim) + # TODO: This does not use policy.forward but computes things directly, which seems odd + + device = next(self.parameters()).device + batch: Batch = to_torch(batch, dtype=torch.float, device=device) + obs, act = batch.obs, batch.act + batch_size = obs.shape[0] + + # mean, std: (state.shape[0], latent_dim) + recon, mean, std = self.policy.vae(obs, act) + recon_loss = F.mse_loss(act, recon) + # (....) is D_KL( N(mu, sigma) || N(0,1) ) + KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean() + vae_loss = recon_loss + KL_loss / 2 + + self.vae_optim.step(vae_loss) + + # critic training: + with torch.no_grad(): + # repeat num_sampled_action times + obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0) + # now obs_next: (num_sampled_action * batch_size, state_dim) + + # perturbed action generated by VAE + act_next = self.policy.vae.decode(obs_next) + # now obs_next: (num_sampled_action * batch_size, action_dim) + target_Q1 = self.critic_target(obs_next, act_next) + target_Q2 = self.critic2_target(obs_next, act_next) + + # Clipped Double Q-learning + target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1 - self.lmbda) * torch.max( + target_Q1, + target_Q2, + ) + # now target_Q: (num_sampled_action * batch_size, 1) + + # the max value of Q + target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1) + # now target_Q: (batch_size, 1) + + target_Q = ( + batch.rew.reshape(-1, 1) + + torch.logical_not(batch.done).reshape(-1, 1) * self.gamma * target_Q + ) + target_Q = target_Q.float() + + current_Q1 = self.policy.critic(obs, act) + current_Q2 = self.critic2(obs, act) + + critic1_loss = F.mse_loss(current_Q1, target_Q) + critic2_loss = F.mse_loss(current_Q2, target_Q) + self.critic_optim.step(critic1_loss) + self.critic2_optim.step(critic2_loss) + + sampled_act = self.policy.vae.decode(obs) + perturbed_act = self.policy.actor_perturbation(obs, sampled_act) + + # max + actor_loss = -self.policy.critic(obs, perturbed_act).mean() + + self.actor_perturbation_optim.step(actor_loss) + + # update target networks + self._update_lagged_network_weights() + + return BCQTrainingStats( + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + vae_loss=vae_loss.item(), + ) diff --git a/tianshou/algorithm/imitation/cql.py b/tianshou/algorithm/imitation/cql.py new file mode 100644 index 000000000..f37b03c2d --- /dev/null +++ b/tianshou/algorithm/imitation/cql.py @@ -0,0 +1,400 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import cast + +import numpy as np +import torch +import torch.nn.functional as F +from overrides import override + +from tianshou.algorithm.algorithm_base import ( + LaggedNetworkPolyakUpdateAlgorithmMixin, + OfflineAlgorithm, +) +from tianshou.algorithm.modelfree.sac import Alpha, SACPolicy, SACTrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.data.buffer.buffer_base import TBuffer +from tianshou.data.types import RolloutBatchProtocol +from tianshou.utils.conversion import to_optional_float +from tianshou.utils.torch_utils import torch_device + + +@dataclass(kw_only=True) +class CQLTrainingStats(SACTrainingStats): + """A data structure for storing loss statistics of the CQL learn step.""" + + cql_alpha: float | None = None + cql_alpha_loss: float | None = None + + +# TODO: Perhaps SACPolicy should get a more generic name +class 
CQL(OfflineAlgorithm[SACPolicy], LaggedNetworkPolyakUpdateAlgorithmMixin): + """Implementation of the conservative Q-learning (CQL) algorithm. arXiv:2006.04779.""" + + def __init__( + self, + *, + policy: SACPolicy, + policy_optim: OptimizerFactory, + critic: torch.nn.Module, + critic_optim: OptimizerFactory, + critic2: torch.nn.Module | None = None, + critic2_optim: OptimizerFactory | None = None, + cql_alpha_lr: float = 1e-4, + cql_weight: float = 1.0, + tau: float = 0.005, + gamma: float = 0.99, + alpha: float | Alpha = 0.2, + temperature: float = 1.0, + with_lagrange: bool = True, + lagrange_threshold: float = 10.0, + min_action: float = -1.0, + max_action: float = 1.0, + num_repeat_actions: int = 10, + alpha_min: float = 0.0, + alpha_max: float = 1e6, + max_grad_norm: float = 1.0, + calibrated: bool = True, + ) -> None: + """ + :param actor: the actor network following the rules (s -> a) + :param policy_optim: the optimizer factory for the policy/its actor network. + :param critic: the first critic network. + :param critic_optim: the optimizer factory for the first critic network. + :param action_space: the environment's action space. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, clone the first critic's optimizer factory. + :param cql_alpha_lr: the learning rate for the Lagrange multiplier optimization. + Controls how quickly the CQL regularization coefficient (alpha) adapts during training. + Higher values allow faster adaptation but may cause instability in the training process. + Lower values provide more stable but slower adaptation of the regularization strength. + Only relevant when with_lagrange=True. + :param cql_weight: the coefficient that scales the conservative regularization term in the Q-function loss. + Controls the strength of the conservative Q-learning component relative to standard TD learning. + Higher values enforce more conservative value estimates by penalizing overestimation more strongly. + Lower values allow the algorithm to behave more like standard Q-learning. + Increasing this weight typically improves performance in purely offline settings where + overestimation bias can lead to poor policy extraction. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. 
+ Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param alpha: the entropy regularization coefficient alpha or an object + which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). + :param temperature: the temperature parameter used in the LogSumExp calculation of the CQL loss. + Controls the sharpness of the softmax distribution when computing the expected Q-values. + Lower values make the LogSumExp operation more selective, focusing on the highest Q-values. + Higher values make the operation closer to an average, giving more weight to all Q-values. + The temperature affects how conservatively the algorithm penalizes out-of-distribution actions. + :param with_lagrange: a flag indicating whether to automatically tune the CQL regularization strength. + If True, uses Lagrangian dual gradient descent to dynamically adjust the CQL alpha parameter. + This formulation maintains the CQL regularization loss near the lagrange_threshold value. + Adaptive tuning helps balance conservative learning against excessive pessimism. + If False, the conservative loss is scaled by a fixed cql_weight throughout training. + The original CQL paper recommends setting this to True for most offline RL tasks. + :param lagrange_threshold: the target value for the CQL regularization loss when using Lagrangian optimization. + When with_lagrange=True, the algorithm dynamically adjusts the CQL alpha parameter to maintain + the regularization loss close to this threshold. + Lower values result in more conservative behavior by enforcing stronger penalties on + out-of-distribution actions. + Higher values allow more optimistic Q-value estimates similar to standard Q-learning. + This threshold effectively controls the level of conservatism in CQL's value estimation. + :param min_action: the lower bound for each dimension of the action space. + Used when sampling random actions for the CQL regularization term. + Should match the environment's action space minimum values. + These random actions help penalize Q-values for out-of-distribution actions. + Typically set to -1.0 for normalized continuous action spaces. + :param max_action: the upper bound for each dimension of the action space. + Used when sampling random actions for the CQL regularization term. + Should match the environment's action space maximum values. + These random actions help penalize Q-values for out-of-distribution actions. + Typically set to 1.0 for normalized continuous action spaces. + :param num_repeat_actions: the number of action samples generated per state when computing + the CQL regularization term. + Controls how many random and policy actions are sampled for each state in the batch when + estimating expected Q-values. + Higher values provide more accurate approximation of the expected Q-values but increase + computational cost. + Lower values reduce computation but may provide less stable or less accurate regularization. + The original CQL paper typically uses values around 10. + :param alpha_min: the minimum value allowed for the adaptive CQL regularization coefficient. + When using Lagrangian optimization (with_lagrange=True), constrains the automatically tuned + cql_alpha parameter to be at least this value. + Prevents the regularization strength from becoming too small during training. + Setting a positive value ensures the algorithm maintains at least some degree of conservatism. + Only relevant when with_lagrange=True. 
+ :param alpha_max: the maximum value allowed for the adaptive CQL regularization coefficient. + When using Lagrangian optimization (with_lagrange=True), constrains the automatically tuned + cql_alpha parameter to be at most this value. + Prevents the regularization strength from becoming too large during training. + Setting an appropriate upper limit helps avoid overly conservative behavior that might hinder + learning useful value functions. + Only relevant when with_lagrange=True. + :param max_grad_norm: the maximum L2 norm threshold for gradient clipping when updating critic networks. + Gradients with norm exceeding this value will be rescaled to have norm equal to this value. + Helps stabilize training by preventing excessively large parameter updates from outlier samples. + Higher values allow larger updates but may lead to training instability. + Lower values enforce more conservative updates but may slow down learning. + Setting to a large value effectively disables gradient clipping. + :param calibrated: a flag indicating whether to use the calibrated version of CQL (CalQL). + If True, calibrates Q-values by taking the maximum of computed Q-values and Monte Carlo returns. + This modification helps address the excessive pessimism problem in standard CQL. + Particularly useful for offline pre-training followed by online fine-tuning scenarios. + Experimental results suggest this approach often achieves better performance than vanilla CQL. + Based on techniques from the CalQL paper (arXiv:2303.05479). + """ + super().__init__( + policy=policy, + ) + LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) + + device = torch_device(policy) + + self.policy_optim = self._create_optimizer(self.policy, policy_optim) + self.critic = critic + self.critic_optim = self._create_optimizer( + self.critic, critic_optim, max_grad_norm=max_grad_norm + ) + self.critic2 = critic2 or deepcopy(critic) + self.critic2_optim = self._create_optimizer( + self.critic2, critic2_optim or critic_optim, max_grad_norm=max_grad_norm + ) + self.critic_old = self._add_lagged_network(self.critic) + self.critic2_old = self._add_lagged_network(self.critic2) + + self.gamma = gamma + self.alpha = Alpha.from_float_or_instance(alpha) + + self.temperature = temperature + self.with_lagrange = with_lagrange + self.lagrange_threshold = lagrange_threshold + + self.cql_weight = cql_weight + + self.cql_log_alpha = torch.tensor([0.0], requires_grad=True) + # TODO: Use an OptimizerFactory? 
+ self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr) + self.cql_log_alpha = self.cql_log_alpha.to(device) + + self.min_action = min_action + self.max_action = max_action + + self.num_repeat_actions = num_repeat_actions + + self.alpha_min = alpha_min + self.alpha_max = alpha_max + + self.calibrated = calibrated + + def _policy_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch = Batch(obs=obs, info=[None] * len(obs)) + obs_result = self.policy(batch) + return obs_result.act, obs_result.log_prob + + def _calc_policy_loss(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + act_pred, log_pi = self._policy_pred(obs) + q1 = self.critic(obs, act_pred) + q2 = self.critic2(obs, act_pred) + min_Q = torch.min(q1, q2) + # self.alpha: float | torch.Tensor + actor_loss = (self.alpha.value * log_pi - min_Q).mean() + # actor_loss.shape: (), log_pi.shape: (batch_size, 1) + return actor_loss, log_pi + + def _calc_pi_values( + self, + obs_pi: torch.Tensor, + obs_to_pred: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + act_pred, log_pi = self._policy_pred(obs_pi) + + q1 = self.critic(obs_to_pred, act_pred) + q2 = self.critic2(obs_to_pred, act_pred) + + return q1 - log_pi.detach(), q2 - log_pi.detach() + + def _calc_random_values( + self, + obs: torch.Tensor, + act: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + random_value1 = self.critic(obs, act) + random_log_prob1 = np.log(0.5 ** act.shape[-1]) + + random_value2 = self.critic2(obs, act) + random_log_prob2 = np.log(0.5 ** act.shape[-1]) + + return random_value1 - random_log_prob1, random_value2 - random_log_prob2 + + @override + def process_buffer(self, buffer: TBuffer) -> TBuffer: + """If `self.calibrated = True`, adds `calibration_returns` to buffer._meta. 
+ + :param buffer: + :return: + """ + if self.calibrated: + # otherwise _meta hack cannot work + assert isinstance(buffer, ReplayBuffer) + batch, indices = buffer.sample(0) + returns, _ = self.compute_episodic_return( + batch=batch, + buffer=buffer, + indices=indices, + gamma=self.gamma, + gae_lambda=1.0, + ) + # TODO: don't access _meta directly + buffer._meta = cast( + RolloutBatchProtocol, + Batch(**buffer._meta.__dict__, calibration_returns=returns), + ) + return buffer + + def _update_with_batch(self, batch: RolloutBatchProtocol) -> CQLTrainingStats: + device = torch_device(self.policy) + batch: Batch = to_torch(batch, dtype=torch.float, device=device) + obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next + batch_size = obs.shape[0] + + # compute actor loss and update actor + actor_loss, log_pi = self._calc_policy_loss(obs) + self.policy_optim.step(actor_loss) + + entropy = -log_pi.detach() + alpha_loss = self.alpha.update(entropy) + + # compute target_Q + with torch.no_grad(): + act_next, new_log_pi = self._policy_pred(obs_next) + + target_Q1 = self.critic_old(obs_next, act_next) + target_Q2 = self.critic2_old(obs_next, act_next) + + target_Q = torch.min(target_Q1, target_Q2) - self.alpha.value * new_log_pi + + target_Q = rew + torch.logical_not(batch.done) * self.gamma * target_Q.flatten() + target_Q = target_Q.float() + # shape: (batch_size) + + # compute critic loss + current_Q1 = self.critic(obs, act).flatten() + current_Q2 = self.critic2(obs, act).flatten() + # shape: (batch_size) + + critic1_loss = F.mse_loss(current_Q1, target_Q) + critic2_loss = F.mse_loss(current_Q2, target_Q) + + # CQL + random_actions = ( + torch.FloatTensor(batch_size * self.num_repeat_actions, act.shape[-1]) + .uniform_(-self.min_action, self.max_action) + .to(device) + ) + + obs_len = len(obs.shape) + repeat_size = [1, self.num_repeat_actions] + [1] * (obs_len - 1) + view_size = [batch_size * self.num_repeat_actions, *list(obs.shape[1:])] + tmp_obs = obs.unsqueeze(1).repeat(*repeat_size).view(*view_size) + tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size) + # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim) + + current_pi_value1, current_pi_value2 = self._calc_pi_values(tmp_obs, tmp_obs) + next_pi_value1, next_pi_value2 = self._calc_pi_values(tmp_obs_next, tmp_obs) + + random_value1, random_value2 = self._calc_random_values(tmp_obs, random_actions) + + for value in [ + current_pi_value1, + current_pi_value2, + next_pi_value1, + next_pi_value2, + random_value1, + random_value2, + ]: + value.reshape(batch_size, self.num_repeat_actions, 1) + + if self.calibrated: + returns = ( + batch.calibration_returns.unsqueeze(1) + .repeat( + (1, self.num_repeat_actions), + ) + .view(-1, 1) + ) + random_value1 = torch.max(random_value1, returns) + random_value2 = torch.max(random_value2, returns) + + current_pi_value1 = torch.max(current_pi_value1, returns) + current_pi_value2 = torch.max(current_pi_value2, returns) + + next_pi_value1 = torch.max(next_pi_value1, returns) + next_pi_value2 = torch.max(next_pi_value2, returns) + + # cat q values + cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1) + cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1) + # shape: (batch_size, 3 * num_repeat, 1) + + cql1_scaled_loss = ( + torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() + * self.cql_weight + * self.temperature + - current_Q1.mean() * self.cql_weight + ) + cql2_scaled_loss = ( + torch.logsumexp(cat_q2 / 
self.temperature, dim=1).mean() + * self.cql_weight + * self.temperature + - current_Q2.mean() * self.cql_weight + ) + # shape: (1) + + cql_alpha_loss = None + cql_alpha = None + if self.with_lagrange: + cql_alpha = torch.clamp( + self.cql_log_alpha.exp(), + self.alpha_min, + self.alpha_max, + ) + cql1_scaled_loss = cql_alpha * (cql1_scaled_loss - self.lagrange_threshold) + cql2_scaled_loss = cql_alpha * (cql2_scaled_loss - self.lagrange_threshold) + + self.cql_alpha_optim.zero_grad() + cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5 + cql_alpha_loss.backward(retain_graph=True) + self.cql_alpha_optim.step() + + critic1_loss = critic1_loss + cql1_scaled_loss + critic2_loss = critic2_loss + cql2_scaled_loss + + # update critics + self.critic_optim.step(critic1_loss, retain_graph=True) + self.critic2_optim.step(critic2_loss) + + self._update_lagged_network_weights() + + return CQLTrainingStats( + actor_loss=to_optional_float(actor_loss), + critic1_loss=to_optional_float(critic1_loss), + critic2_loss=to_optional_float(critic2_loss), + alpha=to_optional_float(self.alpha.value), + alpha_loss=to_optional_float(alpha_loss), + cql_alpha_loss=to_optional_float(cql_alpha_loss), + cql_alpha=to_optional_float(cql_alpha), + ) diff --git a/tianshou/algorithm/imitation/discrete_bcq.py b/tianshou/algorithm/imitation/discrete_bcq.py new file mode 100644 index 000000000..4a8824e81 --- /dev/null +++ b/tianshou/algorithm/imitation/discrete_bcq.py @@ -0,0 +1,261 @@ +import math +from dataclasses import dataclass +from typing import Any, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from tianshou.algorithm.algorithm_base import ( + LaggedNetworkFullUpdateAlgorithmMixin, + OfflineAlgorithm, +) +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.data.types import ( + BatchWithReturnsProtocol, + ImitationBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) + +float_info = torch.finfo(torch.float32) +INF = float_info.max + + +@dataclass(kw_only=True) +class DiscreteBCQTrainingStats(SimpleLossTrainingStats): + q_loss: float + i_loss: float + reg_loss: float + + +class DiscreteBCQPolicy(DiscreteQLearningPolicy): + def __init__( + self, + *, + model: torch.nn.Module, + imitator: torch.nn.Module, + target_update_freq: int = 8000, + unlikely_action_threshold: float = 0.3, + action_space: gym.spaces.Discrete, + observation_space: gym.Space | None = None, + eps_inference: float = 0.0, + ) -> None: + """ + :param model: a model following the rules (s_B -> action_values_BA) + :param imitator: a model following the rules (s -> imitation_logits) + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. 
+            Typically set between 100-10000 for DQN variants, with exact values depending on environment
+            complexity.
+        :param unlikely_action_threshold: the threshold (tau) for unlikely
+            actions, as shown in Eq. (17) in the paper.
+        :param action_space: the environment's action space.
+        :param observation_space: the environment's observation space.
+        :param eps_inference: the epsilon value for epsilon-greedy exploration during inference,
+            i.e. non-training cases (such as evaluation during test steps).
+            The epsilon value is the probability of choosing a random action instead of the action
+            chosen by the policy.
+            A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full
+            exploration (fully random).
+        """
+        super().__init__(
+            model=model,
+            action_space=action_space,
+            observation_space=observation_space,
+            eps_training=0.0,  # no training data collection (offline)
+            eps_inference=eps_inference,
+        )
+        self.imitator = imitator
+        assert (
+            target_update_freq > 0
+        ), f"BCQ needs target_update_freq>0 but got: {target_update_freq}."
+        assert (
+            0.0 <= unlikely_action_threshold < 1.0
+        ), f"unlikely_action_threshold should be in [0, 1) but got: {unlikely_action_threshold}"
+        if unlikely_action_threshold > 0:
+            self._log_tau = math.log(unlikely_action_threshold)
+        else:
+            self._log_tau = -np.inf
+
+    def forward(
+        self,
+        batch: ObsBatchProtocol,
+        state: Any | None = None,
+        model: nn.Module | None = None,
+    ) -> ImitationBatchProtocol:
+        if model is None:
+            model = self.model
+        q_value, state = model(batch.obs, state=state, info=batch.info)
+        imitation_logits, _ = self.imitator(batch.obs, state=state, info=batch.info)
+
+        # mask actions for argmax
+        ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
+        mask = (ratio < self._log_tau).float()
+        act = (q_value - INF * mask).argmax(dim=-1)
+
+        result = Batch(
+            act=act,
+            state=state,
+            q_value=q_value,
+            imitation_logits=imitation_logits,
+            logits=imitation_logits,
+        )
+        return cast(ImitationBatchProtocol, result)
+
+
+class DiscreteBCQ(
+    OfflineAlgorithm[DiscreteBCQPolicy],
+    LaggedNetworkFullUpdateAlgorithmMixin,
+):
+    """Implementation of the discrete batch-constrained deep Q-learning (BCQ) algorithm. arXiv:1910.01708."""
+
+    def __init__(
+        self,
+        *,
+        policy: DiscreteBCQPolicy,
+        optim: OptimizerFactory,
+        gamma: float = 0.99,
+        n_step_return_horizon: int = 1,
+        target_update_freq: int = 8000,
+        imitation_logits_penalty: float = 1e-2,
+    ) -> None:
+        """
+        :param policy: the policy
+        :param optim: the optimizer factory for the policy's model.
+        :param gamma: the discount factor in [0, 1] for future rewards.
+            This determines how much future rewards are valued compared to immediate ones.
+            Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic"
+            behavior. Higher values (closer to 1) make the agent value long-term rewards more,
+            potentially improving performance in tasks where delayed rewards are important but
+            increasing training variance by incorporating more environmental stochasticity.
+            Typically set between 0.9 and 0.99 for most reinforcement learning tasks
+        :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal
+            difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods:
+            higher values reduce bias (by relying less on potentially inaccurate value estimates)
+            but increase variance (by incorporating more environmental stochasticity and reducing
+            the averaging effect). A value of 1 corresponds to standard TD learning with immediate
+            bootstrapping, while very large values approach Monte Carlo-like estimation that uses
+            complete episode returns.
+        :param target_update_freq: the number of training iterations between each complete update of
+            the target network.
+            Controls how frequently the target Q-network parameters are updated with the current
+            Q-network values.
+            A value of 0 disables the target network entirely, using only a single network for both
+            action selection and bootstrap targets.
+            Higher values provide more stable learning targets but slow down the propagation of new
+            value estimates. Lower positive values allow faster learning but may lead to instability
+            due to rapidly changing targets.
+            Typically set between 100-10000 for DQN variants, with exact values depending on environment
+            complexity.
+        :param imitation_logits_penalty: regularization weight for imitation
+            logits.
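+
+        A minimal construction sketch (illustrative only): `q_net`, `imitation_net` and `env`
+        are placeholders, and the optimizer factory name `AdamOptimizerFactory` is an
+        assumption that may differ in your setup::
+
+            policy = DiscreteBCQPolicy(
+                model=q_net,  # s -> Q(s, *)
+                imitator=imitation_net,  # s -> imitation logits
+                action_space=env.action_space,  # must be gym.spaces.Discrete
+                unlikely_action_threshold=0.3,
+            )
+            algorithm = DiscreteBCQ(
+                policy=policy,
+                optim=AdamOptimizerFactory(lr=1e-3),
+                gamma=0.99,
+                target_update_freq=8000,
+            )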
+ """ + super().__init__( + policy=policy, + ) + LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) + self.optim = self._create_optimizer(self.policy, optim) + assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" + self.gamma = gamma + assert ( + n_step_return_horizon > 0 + ), f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" + self.n_step = n_step_return_horizon + self._target = target_update_freq > 0 + self.freq = target_update_freq + self._iter = 0 + if self._target: + self.model_old = self._add_lagged_network(self.policy.model) + self._weight_reg = imitation_logits_penalty + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + return self.compute_nstep_return( + batch=batch, + buffer=buffer, + indices=indices, + target_q_fn=self._target_q, + gamma=self.gamma, + n_step=self.n_step, + ) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + batch = buffer[indices] # batch.obs_next: s_{t+n} + next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + act = self.policy(next_obs_batch).act + target_q, _ = self.model_old(batch.obs_next) + return target_q[np.arange(len(act)), act] + + def _update_with_batch( # type: ignore[override] + self, + batch: BatchWithReturnsProtocol, + ) -> DiscreteBCQTrainingStats: + if self._iter % self.freq == 0: + self._update_lagged_network_weights() + self._iter += 1 + + target_q = batch.returns.flatten() + result = self.policy(batch) + imitation_logits = result.imitation_logits + current_q = result.q_value[np.arange(len(target_q)), batch.act] + act = to_torch(batch.act, dtype=torch.long, device=target_q.device) + q_loss = F.smooth_l1_loss(current_q, target_q) + i_loss = F.nll_loss(F.log_softmax(imitation_logits, dim=-1), act) + reg_loss = imitation_logits.pow(2).mean() + loss = q_loss + i_loss + self._weight_reg * reg_loss + + self.optim.step(loss) + + return DiscreteBCQTrainingStats( + loss=loss.item(), + q_loss=q_loss.item(), + i_loss=i_loss.item(), + reg_loss=reg_loss.item(), + ) diff --git a/tianshou/algorithm/imitation/discrete_cql.py b/tianshou/algorithm/imitation/discrete_cql.py new file mode 100644 index 000000000..8ca39440b --- /dev/null +++ b/tianshou/algorithm/imitation/discrete_cql.py @@ -0,0 +1,113 @@ +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.algorithm import QRDQN +from tianshou.algorithm.algorithm_base import OfflineAlgorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import to_torch +from tianshou.data.types import RolloutBatchProtocol + + +@dataclass(kw_only=True) +class DiscreteCQLTrainingStats(SimpleLossTrainingStats): + cql_loss: float + qr_loss: float + + +# NOTE: This uses diamond inheritance to convert from off-policy to offline +class DiscreteCQL(OfflineAlgorithm[QRDQNPolicy], QRDQN[QRDQNPolicy]): # type: ignore[misc] + """Implementation of discrete Conservative Q-Learning algorithm. 
arXiv:2006.04779.""" + + def __init__( + self, + *, + policy: QRDQNPolicy, + optim: OptimizerFactory, + min_q_weight: float = 10.0, + gamma: float = 0.99, + num_quantiles: int = 200, + n_step_return_horizon: int = 1, + target_update_freq: int = 0, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory for the policy's model. + :param min_q_weight: the weight for the cql loss. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value. + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. 
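+
+        The objective minimized in each update step (mirroring `_update_with_batch`) combines
+        the quantile-regression TD loss with the conservative penalty::
+
+            loss = qr_loss + min_q_weight * (logsumexp_a Q(s, a) - Q(s, a_data))
+
+        where the penalty is averaged over the batch and `a_data` denotes the action stored
+        in the offline dataset.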
+ """ + QRDQN.__init__( + self, + policy=policy, + optim=optim, + gamma=gamma, + num_quantiles=num_quantiles, + n_step_return_horizon=n_step_return_horizon, + target_update_freq=target_update_freq, + ) + self.min_q_weight = min_q_weight + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> DiscreteCQLTrainingStats: + self._periodically_update_lagged_network_weights() + weight = batch.pop("weight", 1.0) + all_dist = self.policy(batch).logits + act = to_torch(batch.act, dtype=torch.long, device=all_dist.device) + curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = ( + (dist_diff * (self.tau_hat - (target_dist - curr_dist).detach().le(0.0).float()).abs()) + .sum(-1) + .mean(1) + ) + qr_loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer + # add CQL loss + q = self.policy.compute_q_value(all_dist, None) + dataset_expec = q.gather(1, act.unsqueeze(1)).mean() + negative_sampling = q.logsumexp(1).mean() + min_q_loss = negative_sampling - dataset_expec + loss = qr_loss + min_q_loss * self.min_q_weight + self.optim.step(loss) + + return DiscreteCQLTrainingStats( + loss=loss.item(), + qr_loss=qr_loss.item(), + cql_loss=min_q_loss.item(), + ) diff --git a/tianshou/algorithm/imitation/discrete_crr.py b/tianshou/algorithm/imitation/discrete_crr.py new file mode 100644 index 000000000..1a344a65b --- /dev/null +++ b/tianshou/algorithm/imitation/discrete_crr.py @@ -0,0 +1,167 @@ +from dataclasses import dataclass +from typing import Literal + +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions import Categorical +from torch.nn import ModuleList + +from tianshou.algorithm.algorithm_base import ( + LaggedNetworkFullUpdateAlgorithmMixin, + OfflineAlgorithm, +) +from tianshou.algorithm.modelfree.reinforce import ( + DiscountedReturnComputation, + DiscreteActorPolicy, + SimpleLossTrainingStats, +) +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ReplayBuffer, to_torch, to_torch_as +from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol +from tianshou.utils.lagged_network import EvalModeModuleWrapper +from tianshou.utils.net.discrete import DiscreteCritic + + +@dataclass +class DiscreteCRRTrainingStats(SimpleLossTrainingStats): + actor_loss: float + critic_loss: float + cql_loss: float + + +class DiscreteCRR( + OfflineAlgorithm[DiscreteActorPolicy], + LaggedNetworkFullUpdateAlgorithmMixin, +): + r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.""" + + def __init__( + self, + *, + policy: DiscreteActorPolicy, + critic: torch.nn.Module | DiscreteCritic, + optim: OptimizerFactory, + gamma: float = 0.99, + policy_improvement_mode: Literal["exp", "binary", "all"] = "exp", + ratio_upper_bound: float = 20.0, + beta: float = 1.0, + min_q_weight: float = 10.0, + target_update_freq: int = 0, + return_standardization: bool = False, + ) -> None: + r""" + :param policy: the policy + :param critic: the action-value critic (i.e., Q function) + network. (s -> Q(s, \*)) + :param optim: the optimizer factory for the policy's actor network and the critic networks. 
+        :param gamma: the discount factor in [0, 1] for future rewards.
+            This determines how much future rewards are valued compared to immediate ones.
+            Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic"
+            behavior. Higher values (closer to 1) make the agent value long-term rewards more,
+            potentially improving performance in tasks where delayed rewards are important but
+            increasing training variance by incorporating more environmental stochasticity.
+            Typically set between 0.9 and 0.99 for most reinforcement learning tasks
+        :param policy_improvement_mode: type of the weight function f. Possible
+            values: "binary"/"exp"/"all" (see the sketch of f below).
+        :param ratio_upper_bound: when policy_improvement_mode is "exp", the value
+            of the exp function is upper-bounded by this parameter.
+        :param beta: when policy_improvement_mode is "exp", this is the denominator
+            of the exp function.
+        :param min_q_weight: weight for CQL loss/regularizer. Defaults to 10.
+        :param target_update_freq: the number of training iterations between each complete update of
+            the target network.
+            Controls how frequently the target Q-network parameters are updated with the current
+            Q-network values.
+            A value of 0 disables the target network entirely, using only a single network for both
+            action selection and bootstrap targets.
+            Higher values provide more stable learning targets but slow down the propagation of new
+            value estimates. Lower positive values allow faster learning but may lead to instability
+            due to rapidly changing targets.
+            Typically set between 100-10000 for DQN variants, with exact values depending on environment
+            complexity.
+        :param return_standardization: whether to standardize episode returns
+            by subtracting the running mean and dividing by the running standard deviation.
+            Note that this is known to be detrimental to performance in many cases!
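+
+        The weight function f applied to the actor loss (mirroring the update code) is::
+
+            f(A) = 1[A > 0]                                   # "binary"
+            f(A) = clip(exp(A / beta), 0, ratio_upper_bound)  # "exp"
+            f(A) = 1                                          # "all" (behavior cloning)
+
+        where A = Q(s, a) - E_{a' ~ pi}[Q(s, a')] is the advantage of the dataset action.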
+ """ + super().__init__( + policy=policy, + ) + LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) + self.discounted_return_computation = DiscountedReturnComputation( + gamma=gamma, + return_standardization=return_standardization, + ) + self.critic = critic + self.optim = self._create_optimizer(ModuleList([self.policy, self.critic]), optim) + self._target = target_update_freq > 0 + self._freq = target_update_freq + self._iter = 0 + self.actor_old: torch.nn.Module | EvalModeModuleWrapper + self.critic_old: torch.nn.Module | EvalModeModuleWrapper + if self._target: + self.actor_old = self._add_lagged_network(self.policy.actor) + self.critic_old = self._add_lagged_network(self.critic) + else: + self.actor_old = self.policy.actor + self.critic_old = self.critic + self._policy_improvement_mode = policy_improvement_mode + self._ratio_upper_bound = ratio_upper_bound + self._beta = beta + self._min_q_weight = min_q_weight + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + return self.discounted_return_computation.add_discounted_returns( + batch, + buffer, + indices, + ) + + def _update_with_batch( # type: ignore[override] + self, + batch: BatchWithReturnsProtocol, + ) -> DiscreteCRRTrainingStats: + if self._target and self._iter % self._freq == 0: + self._update_lagged_network_weights() + q_t = self.critic(batch.obs) + act = to_torch(batch.act, dtype=torch.long, device=q_t.device) + qa_t = q_t.gather(1, act.unsqueeze(1)) + # Critic loss + with torch.no_grad(): + target_a_t, _ = self.actor_old(batch.obs_next) + target_m = Categorical(logits=target_a_t) + q_t_target = self.critic_old(batch.obs_next) + rew = to_torch_as(batch.rew, q_t_target) + expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) + expected_target_q[batch.done > 0] = 0.0 + target = rew.unsqueeze(1) + self.discounted_return_computation.gamma * expected_target_q + critic_loss = 0.5 * F.mse_loss(qa_t, target) + # Actor loss + act_target, _ = self.policy.actor(batch.obs) + dist = Categorical(logits=act_target) + expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True) + advantage = qa_t - expected_policy_q + if self._policy_improvement_mode == "binary": + actor_loss_coef = (advantage > 0).float() + elif self._policy_improvement_mode == "exp": + actor_loss_coef = (advantage / self._beta).exp().clamp(0, self._ratio_upper_bound) + else: + actor_loss_coef = 1.0 # effectively behavior cloning + actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean() + # CQL loss/regularizer + min_q_loss = (q_t.logsumexp(1) - qa_t).mean() + loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss + self.optim.step(loss) + self._iter += 1 + + return DiscreteCRRTrainingStats( + loss=loss.item(), + actor_loss=actor_loss.item(), + critic_loss=critic_loss.item(), + cql_loss=min_q_loss.item(), + ) diff --git a/tianshou/algorithm/imitation/gail.py b/tianshou/algorithm/imitation/gail.py new file mode 100644 index 000000000..648358d2c --- /dev/null +++ b/tianshou/algorithm/imitation/gail.py @@ -0,0 +1,248 @@ +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.algorithm.modelfree.a2c import A2CTrainingStats +from tianshou.algorithm.modelfree.ppo import PPO +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ( + ReplayBuffer, + SequenceSummaryStats, + to_numpy, + 
to_torch, +) +from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol +from tianshou.utils.net.common import ModuleWithVectorOutput +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic +from tianshou.utils.torch_utils import torch_device + + +@dataclass(kw_only=True) +class GailTrainingStats(A2CTrainingStats): + disc_loss: SequenceSummaryStats + acc_pi: SequenceSummaryStats + acc_exp: SequenceSummaryStats + + +class GAIL(PPO): + """Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.""" + + def __init__( + self, + *, + policy: ProbabilisticActorPolicy, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, + optim: OptimizerFactory, + expert_buffer: ReplayBuffer, + disc_net: torch.nn.Module, + disc_optim: OptimizerFactory, + disc_update_num: int = 4, + eps_clip: float = 0.2, + dual_clip: float | None = None, + value_clip: bool = False, + advantage_normalization: bool = True, + recompute_advantage: bool = False, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + max_grad_norm: float | None = None, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + gamma: float = 0.99, + return_scaling: bool = False, + ) -> None: + """ + :param policy: the policy (which must use an actor with known output dimension, i.e. + any Tianshou `Actor` implementation or other subclass of `ModuleWithVectorOutput`). + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer factory for the actor and critic networks. + :param expert_buffer: the replay buffer containing expert experience. + :param disc_net: the discriminator neural network that distinguishes between expert and policy behaviors. + Takes concatenated state-action pairs [obs, act] as input and outputs an unbounded logit value. + The raw output is transformed in the algorithm using sigmoid functions: o(output) for expert + probability and -log(1-o(-output)) for policy rewards. + Positive output values indicate the discriminator believes the behavior is from an expert. + Negative output values indicate the discriminator believes the behavior is from the policy. + The network architecture should end with a linear layer of output size 1 without any + activation function, as sigmoid operations are applied separately. + :param disc_optim: the optimizer factory for the discriminator network. + :param disc_update_num: the number of discriminator update steps performed for each policy update step. + Controls the learning dynamics between the policy and the discriminator. + Higher values strengthen the discriminator relative to the policy, potentially improving + the quality of the reward signal but slowing down training. + Lower values allow faster policy updates but may result in a weaker discriminator that fails + to properly distinguish between expert and policy behaviors. + Typical values range from 1 to 10, with the original GAIL paper using multiple discriminator + updates per policy update. + :param eps_clip: determines the range of allowed change in the policy during a policy update: + The ratio of action probabilities indicated by the new and old policy is + constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. + Small values thus force the new policy to stay close to the old policy. + Typical values range between 0.1 and 0.3, the value of 0.2 is recommended + in the original PPO paper. + The optimal value depends on the environment; more stochastic environments may + need larger values. 
+ :param dual_clip: a clipping parameter (denoted as c in the literature) that prevents + excessive pessimism in policy updates for negative-advantage actions. + Excessive pessimism occurs when the policy update too strongly reduces the probability + of selecting actions that led to negative advantages, potentially eliminating useful + actions based on limited negative experiences. + When enabled (c > 1), the objective for negative advantages becomes: + max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) + is the original single-clipping objective determined by `eps_clip`. + This creates a floor on negative policy gradients, maintaining some probability + of exploring actions despite initial negative outcomes. + Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer + to 1.0 provide less protection against pessimistic updates. + Set to None to disable dual clipping. + :param value_clip: flag indicating whether to enable clipping for value function updates. + When enabled, restricts how much the value function estimate can change from its + previous prediction, using the same clipping range as the policy updates (eps_clip). + This stabilizes training by preventing large fluctuations in value estimates, + particularly useful in environments with high reward variance. + The clipped value loss uses a pessimistic approach, taking the maximum of the + original and clipped value errors: + max((returns - value)², (returns - v_clipped)²) + Setting to True often improves training stability but may slow convergence. + Implementation follows the approach mentioned in arXiv:1811.02553v3 Sec. 4.1. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param recompute_advantage: whether to recompute advantage every update + repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. + :param vf_coef: coefficient that weights the value loss relative to the actor loss in + the overall loss function. + Higher values prioritize accurate value function estimation over policy improvement. + Controls the trade-off between policy optimization and value function fitting. + Typically set between 0.5 and 1.0 for most actor-critic implementations. + :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. + Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. + Typically set between 0.01 and 0.05 for most actor-critic implementations. + :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. This prevents exploding gradients and stabilizes training by + limiting the size of parameter updates. + Set to None to disable gradient clipping. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. 
At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. + :param max_batchsize: the maximum number of samples to process at once when computing + generalized advantage estimation (GAE) and value function predictions. + Controls memory usage by breaking large batches into smaller chunks processed sequentially. + Higher values may increase speed but require more GPU/CPU memory; lower values + reduce memory requirements but may increase computation time. Should be adjusted + based on available hardware resources and total batch size of your training data. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. + """ + super().__init__( + policy=policy, + critic=critic, + optim=optim, + eps_clip=eps_clip, + dual_clip=dual_clip, + value_clip=value_clip, + advantage_normalization=advantage_normalization, + recompute_advantage=recompute_advantage, + vf_coef=vf_coef, + ent_coef=ent_coef, + max_grad_norm=max_grad_norm, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + gamma=gamma, + return_scaling=return_scaling, + ) + self.disc_net = disc_net + self.disc_optim = self._create_optimizer(self.disc_net, disc_optim) + self.disc_update_num = disc_update_num + self.expert_buffer = expert_buffer + actor = self.policy.actor + if not isinstance(actor, ModuleWithVectorOutput): + raise TypeError("GAIL requires the policy to use an actor with known output dimension.") + self.action_dim = actor.get_output_dim() + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> LogpOldProtocol: + """Pre-process the data from the provided replay buffer. + + Used in :meth:`update`. Check out :ref:`process_fn` for more information. 
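+
+        The environment reward is replaced by the discriminator-based surrogate reward
+        (matching the body of this method)::
+
+            r(s, a) = -log(1 - sigmoid(D(s, a))) = -logsigmoid(-D(s, a))
+
+        so that behavior the discriminator deems expert-like receives a higher reward.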
+ """ + # update reward + with torch.no_grad(): + batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten()) + return super()._preprocess_batch(batch, buffer, indices) + + def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: + device = torch_device(self.disc_net) + obs = to_torch(batch.obs, device=device) + act = to_torch(batch.act, device=device) + return self.disc_net(torch.cat([obs, act], dim=1)) + + def _update_with_batch( # type: ignore[override] + self, + batch: LogpOldProtocol, + batch_size: int | None, + repeat: int, + ) -> GailTrainingStats: + # update discriminator + losses = [] + acc_pis = [] + acc_exps = [] + bsz = len(batch) // self.disc_update_num + for b in batch.split(bsz, merge_last=True): + logits_pi = self.disc(b) + exp_b = self.expert_buffer.sample(bsz)[0] + logits_exp = self.disc(exp_b) + loss_pi = -F.logsigmoid(-logits_pi).mean() + loss_exp = -F.logsigmoid(logits_exp).mean() + loss_disc = loss_pi + loss_exp + self.disc_optim.step(loss_disc) + losses.append(loss_disc.item()) + acc_pis.append((logits_pi < 0).float().mean().item()) + acc_exps.append((logits_exp > 0).float().mean().item()) + # update policy + ppo_loss_stat = super()._update_with_batch(batch, batch_size, repeat) + + disc_losses_summary = SequenceSummaryStats.from_sequence(losses) + acc_pi_summary = SequenceSummaryStats.from_sequence(acc_pis) + acc_exps_summary = SequenceSummaryStats.from_sequence(acc_exps) + + return GailTrainingStats( + **ppo_loss_stat.__dict__, + disc_loss=disc_losses_summary, + acc_pi=acc_pi_summary, + acc_exp=acc_exps_summary, + ) diff --git a/tianshou/algorithm/imitation/imitation_base.py b/tianshou/algorithm/imitation/imitation_base.py new file mode 100644 index 000000000..b21bd3132 --- /dev/null +++ b/tianshou/algorithm/imitation/imitation_base.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +from typing import Any, Literal, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.algorithm import Algorithm +from tianshou.algorithm.algorithm_base import ( + OfflineAlgorithm, + OffPolicyAlgorithm, + Policy, + TrainingStats, +) +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ModelOutputBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) + +# Dimension Naming Convention +# B - Batch Size +# A - Action +# D - Dist input (usually 2, loc and scale) +# H - Dimension of hidden, can be None + + +@dataclass(kw_only=True) +class ImitationTrainingStats(TrainingStats): + loss: float = 0.0 + + +class ImitationPolicy(Policy): + def __init__( + self, + *, + actor: torch.nn.Module, + action_space: gym.Space, + observation_space: gym.Space | None = None, + action_scaling: bool = False, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + ): + """ + :param actor: a model following the rules (s -> a) + :param action_space: the environment's action_space. + :param observation_space: the environment's observation space + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. 
+            When enabled, policy outputs are expected to be in the normalized range [-1, 1]
+            (after bounding), and are then linearly transformed to the actual required range.
+            This improves neural network training stability, allows the same algorithm to work
+            across environments with different action ranges, and standardizes exploration
+            strategies.
+            Should be disabled if the actor model already produces outputs in the correct range.
+        :param action_bound_method: the method used for bounding actions in continuous action spaces
+            to the range [-1, 1] before scaling them to the environment's action space (provided
+            that `action_scaling` is enabled).
+            This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None
+            for discrete spaces.
+            When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this
+            range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly
+            constrains outputs to [-1, 1] while preserving gradients.
+            The choice of bounding method affects both training dynamics and exploration behavior.
+            Clipping provides hard boundaries but may create plateau regions in the gradient
+            landscape, while tanh provides smoother transitions but can compress sensitivity
+            near the boundaries.
+            Should be set to None if the actor model inherently produces bounded outputs.
+            Typically used together with `action_scaling=True`.
+        """
+        super().__init__(
+            action_space=action_space,
+            observation_space=observation_space,
+            action_scaling=action_scaling,
+            action_bound_method=action_bound_method,
+        )
+        self.actor = actor
+
+    def forward(
+        self,
+        batch: ObsBatchProtocol,
+        state: dict | BatchProtocol | np.ndarray | None = None,
+        **kwargs: Any,
+    ) -> ModelOutputBatchProtocol:
+        # TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced
+        if self.action_type == "discrete":
+            # If it's discrete, the "actor" is usually a critic that maps obs to action_values
+            # which then could be turned into logits or a Categorical
+            action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
+            act_B = action_values_BA.argmax(dim=1)
+            result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
+        elif self.action_type == "continuous":
+            # If it's continuous, the actor would usually deliver something like loc, scale determining a
+            # Gaussian dist
+            dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
+            result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH)
+        else:
+            raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!")
+        return cast(ModelOutputBatchProtocol, result)
+
+
+class ImitationLearningAlgorithmMixin:
+    def _imitation_update(
+        self,
+        batch: RolloutBatchProtocol,
+        policy: ImitationPolicy,
+        optim: Algorithm.Optimizer,
+    ) -> ImitationTrainingStats:
+        if policy.action_type == "continuous":  # regression
+            act = policy(batch).act
+            act_target = to_torch(batch.act, dtype=torch.float32, device=act.device)
+            loss = F.mse_loss(act, act_target)
+        elif policy.action_type == "discrete":  # classification
+            act = F.log_softmax(policy(batch).logits, dim=-1)
+            act_target = to_torch(batch.act, dtype=torch.long, device=act.device)
+            loss = F.nll_loss(act, act_target)
+        else:
+            raise ValueError(policy.action_type)
+        optim.step(loss)
+
+        return ImitationTrainingStats(loss=loss.item())
+
+
+class OffPolicyImitationLearning(
+    OffPolicyAlgorithm[ImitationPolicy],
+    ImitationLearningAlgorithmMixin,
+):
+    """Implementation of
off-policy vanilla imitation learning.""" + + def __init__( + self, + *, + policy: ImitationPolicy, + optim: OptimizerFactory, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory + """ + super().__init__( + policy=policy, + ) + self.optim = self._create_optimizer(self.policy, optim) + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> ImitationTrainingStats: + return self._imitation_update(batch, self.policy, self.optim) + + +class OfflineImitationLearning( + OfflineAlgorithm[ImitationPolicy], + ImitationLearningAlgorithmMixin, +): + """Implementation of offline vanilla imitation learning.""" + + def __init__( + self, + *, + policy: ImitationPolicy, + optim: OptimizerFactory, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory + """ + super().__init__( + policy=policy, + ) + self.optim = self._create_optimizer(self.policy, optim) + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> ImitationTrainingStats: + return self._imitation_update(batch, self.policy, self.optim) diff --git a/tianshou/algorithm/imitation/td3_bc.py b/tianshou/algorithm/imitation/td3_bc.py new file mode 100644 index 000000000..5ccbbe0fb --- /dev/null +++ b/tianshou/algorithm/imitation/td3_bc.py @@ -0,0 +1,127 @@ +import torch +import torch.nn.functional as F + +from tianshou.algorithm import TD3 +from tianshou.algorithm.algorithm_base import OfflineAlgorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.modelfree.td3 import TD3TrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import to_torch_as +from tianshou.data.types import RolloutBatchProtocol + + +# NOTE: This uses diamond inheritance to convert from off-policy to offline +class TD3BC(OfflineAlgorithm[ContinuousDeterministicPolicy], TD3): # type: ignore + """Implementation of TD3+BC. arXiv:2106.06860.""" + + def __init__( + self, + *, + policy: ContinuousDeterministicPolicy, + policy_optim: OptimizerFactory, + critic: torch.nn.Module, + critic_optim: OptimizerFactory, + critic2: torch.nn.Module | None = None, + critic2_optim: OptimizerFactory | None = None, + tau: float = 0.005, + gamma: float = 0.99, + policy_noise: float = 0.2, + update_actor_freq: int = 2, + noise_clip: float = 0.5, + alpha: float = 2.5, + n_step_return_horizon: int = 1, + ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer factory for the policy's model. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer factory for the first critic network. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param gamma: the discount factor in [0, 1] for future rewards. 
+            This determines how much future rewards are valued compared to immediate ones.
+            Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic"
+            behavior. Higher values (closer to 1) make the agent value long-term rewards more,
+            potentially improving performance in tasks where delayed rewards are important but
+            increasing training variance by incorporating more environmental stochasticity.
+            Typically set between 0.9 and 0.99 for most reinforcement learning tasks
+        :param policy_noise: scaling factor for the Gaussian noise added to target policy actions.
+            This parameter implements target policy smoothing, a regularization technique described in the TD3 paper.
+            The noise is sampled from a normal distribution and multiplied by this value before being added to actions.
+            Higher values increase exploration in the target policy, helping to address function approximation error.
+            The added noise is optionally clipped to a range determined by the noise_clip parameter.
+            Typically set between 0.1 and 0.5 relative to the action scale of the environment.
+        :param update_actor_freq: the frequency of actor network updates relative to critic network updates
+            (the actor network is only updated once for every `update_actor_freq` critic updates).
+            This implements the "delayed" policy updates from the TD3 algorithm, where the actor is
+            updated less frequently than the critics.
+            Higher values (e.g., 2-5) help stabilize training by allowing the critic to become more
+            accurate before updating the policy.
+            The default value of 2 follows the original TD3 paper's recommendation of updating the
+            policy at half the rate of the Q-functions.
+        :param noise_clip: defines the maximum absolute value of the noise added to target policy actions, i.e. noise values
+            are clipped to the range [-noise_clip, noise_clip] (after generating and scaling the noise
+            via `policy_noise`).
+            This parameter implements bounded target policy smoothing as described in the TD3 paper.
+            It prevents extreme noise values from causing unrealistic target values during training.
+            Setting it to 0.0 (or a negative value) disables clipping entirely.
+            It is typically set to about twice the `policy_noise` value (e.g. 0.5 when `policy_noise` is 0.2).
+        :param alpha: the value of alpha, which controls the weight for TD3 learning
+            relative to behavior cloning.
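+
+        The actor objective (mirroring `_update_with_batch`) combines the TD3 term with a
+        behavior-cloning term; `alpha` enters through the adaptive scale `lmbda`::
+
+            lmbda      = alpha / mean(|Q(s, pi(s))|)
+            actor_loss = -lmbda * mean(Q(s, pi(s))) + MSE(pi(s), a_data)
+
+        Larger values of `alpha` therefore weight the RL objective more strongly relative
+        to behavior cloning.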
+ """ + TD3.__init__( + self, + policy=policy, + policy_optim=policy_optim, + critic=critic, + critic_optim=critic_optim, + critic2=critic2, + critic2_optim=critic2_optim, + tau=tau, + gamma=gamma, + policy_noise=policy_noise, + noise_clip=noise_clip, + update_actor_freq=update_actor_freq, + n_step_return_horizon=n_step_return_horizon, + ) + self.alpha = alpha + + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TD3TrainingStats: + # critic 1&2 + td1, critic1_loss = self._minimize_critic_squared_loss( + batch, self.critic, self.critic_optim + ) + td2, critic2_loss = self._minimize_critic_squared_loss( + batch, self.critic2, self.critic2_optim + ) + batch.weight = (td1 + td2) / 2.0 # prio-buffer + + # actor + if self._cnt % self.update_actor_freq == 0: + act = self.policy(batch, eps=0.0).act + q_value = self.critic(batch.obs, act) + lmbda = self.alpha / q_value.abs().mean().detach() + actor_loss = -lmbda * q_value.mean() + F.mse_loss(act, to_torch_as(batch.act, act)) + self._last = actor_loss.item() + self.policy_optim.step(actor_loss) + self._update_lagged_network_weights() + self._cnt += 1 + + return TD3TrainingStats( + actor_loss=self._last, + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + ) diff --git a/tianshou/policy/modelbased/__init__.py b/tianshou/algorithm/modelbased/__init__.py similarity index 100% rename from tianshou/policy/modelbased/__init__.py rename to tianshou/algorithm/modelbased/__init__.py diff --git a/tianshou/algorithm/modelbased/icm.py b/tianshou/algorithm/modelbased/icm.py new file mode 100644 index 000000000..72a3beedf --- /dev/null +++ b/tianshou/algorithm/modelbased/icm.py @@ -0,0 +1,261 @@ +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.algorithm import Algorithm +from tianshou.algorithm.algorithm_base import ( + OffPolicyAlgorithm, + OffPolicyWrapperAlgorithm, + OnPolicyAlgorithm, + OnPolicyWrapperAlgorithm, + TPolicy, + TrainingStats, + TrainingStatsWrapper, +) +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import RolloutBatchProtocol +from tianshou.utils.net.discrete import IntrinsicCuriosityModule + + +class ICMTrainingStats(TrainingStatsWrapper): + def __init__( + self, + wrapped_stats: TrainingStats, + *, + icm_loss: float, + icm_forward_loss: float, + icm_inverse_loss: float, + ) -> None: + self.icm_loss = icm_loss + self.icm_forward_loss = icm_forward_loss + self.icm_inverse_loss = icm_inverse_loss + super().__init__(wrapped_stats) + + +class _ICMMixin: + """Implementation of the Intrinsic Curiosity Module (ICM) algorithm. arXiv:1705.05363.""" + + def __init__( + self, + *, + model: IntrinsicCuriosityModule, + optim: Algorithm.Optimizer, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + ) -> None: + """ + :param model: the ICM model. + :param optim: the optimizer factory. + :param lr_scale: a multiplier that effectively scales the learning rate for the ICM model updates. + Higher values increase the step size during optimization of the intrinsic curiosity module. + Lower values decrease the step size, leading to more gradual learning of the curiosity mechanism. + This parameter offers an alternative to directly adjusting the base learning rate in the optimizer. 
+ :param reward_scale: a multiplier that controls the magnitude of intrinsic rewards (curiosity-driven + rewards generated by the agent itself) relative to extrinsic rewards (task-specific rewards provided + by the environment). + Scales the prediction error (curiosity signal) before adding it to the environment rewards. + Higher values increase the agent's motivation to explore novel states. + Lower values decrease the influence of curiosity relative to task-specific rewards. + Setting to zero effectively disables intrinsic motivation while still learning the ICM model. + :param forward_loss_weight: relative importance in [0, 1] of the forward model loss in relation to + the inverse model loss. + Controls the trade-off between state prediction and action prediction in the ICM algorithm. + Higher values (> 0.5) prioritize learning to predict next states given current states and actions. + Lower values (< 0.5) prioritize learning to predict actions given current and next states. + The total loss combines both components: + (1-forward_loss_weight)*inverse_loss + forward_loss_weight*forward_loss. + """ + self.model = model + self.optim = optim + self.lr_scale = lr_scale + self.reward_scale = reward_scale + self.forward_loss_weight = forward_loss_weight + + def _icm_preprocess_batch( + self, + batch: RolloutBatchProtocol, + ) -> None: + mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next) + batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss) + batch.rew += to_numpy(mse_loss * self.reward_scale) + + @staticmethod + def _icm_postprocess_batch(batch: BatchProtocol) -> None: + # restore original reward + batch.rew = batch.policy.orig_rew + + def _icm_update( + self, + batch: RolloutBatchProtocol, + original_stats: TrainingStats, + ) -> ICMTrainingStats: + act_hat = batch.policy.act_hat + act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) + inverse_loss = F.cross_entropy(act_hat, act).mean() + forward_loss = batch.policy.mse_loss.mean() + loss = ( + (1 - self.forward_loss_weight) * inverse_loss + self.forward_loss_weight * forward_loss + ) * self.lr_scale + self.optim.step(loss) + + return ICMTrainingStats( + original_stats, + icm_loss=loss.item(), + icm_forward_loss=forward_loss.item(), + icm_inverse_loss=inverse_loss.item(), + ) + + +class ICMOffPolicyWrapper(OffPolicyWrapperAlgorithm[TPolicy], _ICMMixin): + """Implementation of the Intrinsic Curiosity Module (ICM) algorithm for off-policy learning. arXiv:1705.05363.""" + + def __init__( + self, + *, + wrapped_algorithm: OffPolicyAlgorithm[TPolicy], + model: IntrinsicCuriosityModule, + optim: OptimizerFactory, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + ) -> None: + """ + :param wrapped_algorithm: the base algorithm to which we want to add the ICM. + :param model: the ICM model. + :param optim: the optimizer factory for the ICM model. + :param lr_scale: a multiplier that effectively scales the learning rate for the ICM model updates. + Higher values increase the step size during optimization of the intrinsic curiosity module. + Lower values decrease the step size, leading to more gradual learning of the curiosity mechanism. + This parameter offers an alternative to directly adjusting the base learning rate in the optimizer. + :param reward_scale: a multiplier that controls the magnitude of intrinsic rewards (curiosity-driven + rewards generated by the agent itself) relative to extrinsic rewards (task-specific rewards provided + by the environment). 
+ Scales the prediction error (curiosity signal) before adding it to the environment rewards. + Higher values increase the agent's motivation to explore novel states. + Lower values decrease the influence of curiosity relative to task-specific rewards. + Setting to zero effectively disables intrinsic motivation while still learning the ICM model. + :param forward_loss_weight: relative importance in [0, 1] of the forward model loss in relation to + the inverse model loss. + Controls the trade-off between state prediction and action prediction in the ICM algorithm. + Higher values (> 0.5) prioritize learning to predict next states given current states and actions. + Lower values (< 0.5) prioritize learning to predict actions given current and next states. + The total loss combines both components: + (1-forward_loss_weight)*inverse_loss + forward_loss_weight*forward_loss. + """ + OffPolicyWrapperAlgorithm.__init__( + self, + wrapped_algorithm=wrapped_algorithm, + ) + _ICMMixin.__init__( + self, + model=model, + optim=self._create_optimizer(model, optim), + lr_scale=lr_scale, + reward_scale=reward_scale, + forward_loss_weight=forward_loss_weight, + ) + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + self._icm_preprocess_batch(batch) + return super()._preprocess_batch(batch, buffer, indices) + + def _postprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> None: + super()._postprocess_batch(batch, buffer, indices) + self._icm_postprocess_batch(batch) + + def _wrapper_update_with_batch( + self, + batch: RolloutBatchProtocol, + original_stats: TrainingStats, + ) -> ICMTrainingStats: + return self._icm_update(batch, original_stats) + + +class ICMOnPolicyWrapper(OnPolicyWrapperAlgorithm[TPolicy], _ICMMixin): + """Implementation of the Intrinsic Curiosity Module (ICM) algorithm for on-policy learning. arXiv:1705.05363.""" + + def __init__( + self, + *, + wrapped_algorithm: OnPolicyAlgorithm[TPolicy], + model: IntrinsicCuriosityModule, + optim: OptimizerFactory, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + ) -> None: + """ + :param wrapped_algorithm: the base algorithm to which we want to add the ICM. + :param model: the ICM model. + :param optim: the optimizer factory for the ICM model. + :param lr_scale: a multiplier that effectively scales the learning rate for the ICM model updates. + Higher values increase the step size during optimization of the intrinsic curiosity module. + Lower values decrease the step size, leading to more gradual learning of the curiosity mechanism. + This parameter offers an alternative to directly adjusting the base learning rate in the optimizer. + :param reward_scale: a multiplier that controls the magnitude of intrinsic rewards (curiosity-driven + rewards generated by the agent itself) relative to extrinsic rewards (task-specific rewards provided + by the environment). + Scales the prediction error (curiosity signal) before adding it to the environment rewards. + Higher values increase the agent's motivation to explore novel states. + Lower values decrease the influence of curiosity relative to task-specific rewards. + Setting to zero effectively disables intrinsic motivation while still learning the ICM model. + :param forward_loss_weight: relative importance in [0, 1] of the forward model loss in relation to + the inverse model loss. 
+ Controls the trade-off between state prediction and action prediction in the ICM algorithm. + Higher values (> 0.5) prioritize learning to predict next states given current states and actions. + Lower values (< 0.5) prioritize learning to predict actions given current and next states. + The total loss combines both components: + (1-forward_loss_weight)*inverse_loss + forward_loss_weight*forward_loss. + """ + OnPolicyWrapperAlgorithm.__init__( + self, + wrapped_algorithm=wrapped_algorithm, + ) + _ICMMixin.__init__( + self, + model=model, + optim=self._create_optimizer(model, optim), + lr_scale=lr_scale, + reward_scale=reward_scale, + forward_loss_weight=forward_loss_weight, + ) + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + self._icm_preprocess_batch(batch) + return super()._preprocess_batch(batch, buffer, indices) + + def _postprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> None: + super()._postprocess_batch(batch, buffer, indices) + self._icm_postprocess_batch(batch) + + def _wrapper_update_with_batch( + self, + batch: RolloutBatchProtocol, + batch_size: int | None, + repeat: int, + original_stats: TrainingStats, + ) -> ICMTrainingStats: + return self._icm_update(batch, original_stats) diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/algorithm/modelbased/psrl.py similarity index 59% rename from tianshou/policy/modelbased/psrl.py rename to tianshou/algorithm/modelbased/psrl.py index 95b9527d2..097d0dff9 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/algorithm/modelbased/psrl.py @@ -1,15 +1,18 @@ from dataclasses import dataclass -from typing import Any, TypeVar, cast +from typing import Any, cast import gymnasium as gym import numpy as np import torch +from tianshou.algorithm.algorithm_base import ( + OnPolicyAlgorithm, + Policy, + TrainingStats, +) from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats @dataclass(kw_only=True) @@ -18,39 +21,40 @@ class PSRLTrainingStats(TrainingStats): psrl_rew_std: float = 0.0 -TPSRLTrainingStats = TypeVar("TPSRLTrainingStats", bound=PSRLTrainingStats) - - class PSRLModel: - """Implementation of Posterior Sampling Reinforcement Learning Model. - - :param trans_count_prior: dirichlet prior (alphas), with shape - (n_state, n_action, n_state). - :param rew_mean_prior: means of the normal priors of rewards, - with shape (n_state, n_action). - :param rew_std_prior: standard deviations of the normal priors - of rewards, with shape (n_state, n_action). - :param discount_factor: in [0, 1]. - :param epsilon: for precision control in value iteration. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in - optimizer in each policy.update(). Default to None (no lr_scheduler). - """ + """Implementation of Posterior Sampling Reinforcement Learning Model.""" def __init__( self, trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, - discount_factor: float, + gamma: float, epsilon: float, ) -> None: + """ + :param trans_count_prior: dirichlet prior (alphas), with shape + (n_state, n_action, n_state). + :param rew_mean_prior: means of the normal priors of rewards, + with shape (n_state, n_action). 
+ :param rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param epsilon: for precision control in value iteration. + """ self.trans_count = trans_count_prior self.n_state, self.n_action = rew_mean_prior.shape self.rew_mean = rew_mean_prior self.rew_std = rew_std_prior self.rew_square_sum = np.zeros_like(rew_mean_prior) self.rew_std_prior = rew_std_prior - self.discount_factor = discount_factor + self.gamma = gamma self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight self.eps = epsilon self.policy: np.ndarray @@ -104,7 +108,7 @@ def solve_policy(self) -> None: self.policy, self.value = self.value_iteration( self.sample_trans_prob(), self.sample_reward(), - self.discount_factor, + self.gamma, self.eps, self.value, ) @@ -113,7 +117,7 @@ def solve_policy(self) -> None: def value_iteration( trans_prob: np.ndarray, rew: np.ndarray, - discount_factor: float, + gamma: float, eps: float, value: np.ndarray, ) -> tuple[np.ndarray, np.ndarray]: @@ -123,17 +127,23 @@ def value_iteration( (n_state, n_action, n_state). :param rew: rewards, with shape (n_state, n_action). :param eps: for precision control. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param value: the initialize value of value array, with shape (n_state, ). :return: the optimal policy with shape (n_state, ). """ - Q = rew + discount_factor * trans_prob.dot(value) + Q = rew + gamma * trans_prob.dot(value) new_value = Q.max(axis=1) while not np.allclose(new_value, value, eps): value = new_value - Q = rew + discount_factor * trans_prob.dot(value) + Q = rew + gamma * trans_prob.dot(value) new_value = Q.max(axis=1) # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly Q += eps * np.random.randn(*Q.shape) @@ -150,32 +160,7 @@ def __call__( return self.policy[obs] -class PSRLPolicy(BasePolicy[TPSRLTrainingStats]): - """Implementation of Posterior Sampling Reinforcement Learning. - - Reference: Strens M. A Bayesian framework for reinforcement learning [C] - //ICML. 2000, 2000: 943-950. - - :param trans_count_prior: dirichlet prior (alphas), with shape - (n_state, n_action, n_state). - :param rew_mean_prior: means of the normal priors of rewards, - with shape (n_state, n_action). - :param rew_std_prior: standard deviations of the normal priors - of rewards, with shape (n_state, n_action). - :param action_space: Env's action_space. 
- :param discount_factor: in [0, 1]. - :param epsilon: for precision control in value iteration. - :param add_done_loop: whether to add an extra self-loop for the - terminal state in MDP. Default to False. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - +class PSRLPolicy(Policy): def __init__( self, *, @@ -185,18 +170,25 @@ def __init__( action_space: gym.spaces.Discrete, discount_factor: float = 0.99, epsilon: float = 0.01, - add_done_loop: bool = False, observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param trans_count_prior: dirichlet prior (alphas), with shape + (n_state, n_action, n_state). + :param rew_mean_prior: means of the normal priors of rewards, + with shape (n_state, n_action). + :param rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). + :param action_space: the environment's action_space. + :param epsilon: for precision control in value iteration. + :param observation_space: the environment's observation space + """ super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=False, action_bound_method=None, - lr_scheduler=lr_scheduler, ) - assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" self.model = PSRLModel( trans_count_prior, rew_mean_prior, @@ -204,7 +196,6 @@ def __init__( discount_factor, epsilon, ) - self._add_done_loop = add_done_loop def forward( self, @@ -216,19 +207,50 @@ def forward( :return: A :class:`~tianshou.data.Batch` with "act" key containing the action. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. """ assert isinstance(batch.obs, np.ndarray), "only support np.ndarray observation" # TODO: shouldn't the model output a state as well if state is passed (i.e. RNNs are involved)? act = self.model(batch.obs, state=state, info=batch.info) return cast(ActBatchProtocol, Batch(act=act)) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRLTrainingStats: - n_s, n_a = self.model.n_state, self.model.n_action + +class PSRL(OnPolicyAlgorithm[PSRLPolicy]): + """Implementation of Posterior Sampling Reinforcement Learning (PSRL). + + Reference: Strens M., A Bayesian Framework for Reinforcement Learning, ICML, 2000. + """ + + def __init__( + self, + *, + policy: PSRLPolicy, + add_done_loop: bool = False, + ) -> None: + """ + :param policy: the policy + :param add_done_loop: a flag indicating whether to add a self-loop transition for terminal states + in the MDP. + If True, whenever an episode terminates, an artificial transition from the terminal state + back to itself is added to the transition counts for all actions. + This modification can help stabilize learning for terminal states that have limited samples. + Setting to True can be beneficial in environments where episodes frequently terminate, + ensuring that terminal states receive sufficient updates to their value estimates. + Default is False, which preserves the standard MDP formulation without artificial self-loops. 
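# Illustrative sketch (editor's annotation, not part of the diff): the posterior-sampling idea
# behind PSRL as documented above. A tabular MDP is drawn from the current posterior
# (Dirichlet over transitions, Normal over rewards) and the agent acts greedily w.r.t. that
# sample. Names, shapes, and numbers here are assumptions for the example only.
import numpy as np

rng = np.random.default_rng(0)
n_state, n_action, gamma = 3, 2, 0.99
trans_count = np.ones((n_state, n_action, n_state))   # Dirichlet alphas (prior + counts)
rew_mean = np.zeros((n_state, n_action))              # posterior means of rewards
rew_std = np.ones((n_state, n_action))                # posterior std devs of rewards

# sample one plausible MDP from the posterior
trans_prob = np.array(
    [[rng.dirichlet(trans_count[s, a]) for a in range(n_action)] for s in range(n_state)]
)
rew = rng.normal(rew_mean, rew_std)

# solve the sampled MDP with value iteration and act greedily under it
value = np.zeros(n_state)
for _ in range(200):
    q = rew + gamma * trans_prob.dot(value)           # shape (n_state, n_action)
    value = q.max(axis=1)
policy = q.argmax(axis=1)                             # greedy policy for the sampled MDP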
+ """ + super().__init__( + policy=policy, + ) + self._add_done_loop = add_done_loop + + def _update_with_batch( + self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int + ) -> PSRLTrainingStats: + # NOTE: In contrast to other on-policy algorithms, this algorithm ignores + # the batch_size and repeat arguments. + # PSRL, being a Bayesian approach, updates its posterior distribution of + # the MDP parameters based on the collected transition data as a whole, + # rather than performing gradient-based updates that benefit from mini-batching. + n_s, n_a = self.policy.model.n_state, self.policy.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) rew_square_sum = np.zeros((n_s, n_a)) @@ -246,9 +268,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRL # special operation for terminal states: add a self-loop trans_count[obs_next, :, obs_next] += 1 rew_count[obs_next, :] += 1 - self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) + self.policy.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) - return PSRLTrainingStats( # type: ignore[return-value] - psrl_rew_mean=float(self.model.rew_mean.mean()), - psrl_rew_std=float(self.model.rew_std.mean()), + return PSRLTrainingStats( + psrl_rew_mean=float(self.policy.model.rew_mean.mean()), + psrl_rew_std=float(self.policy.model.rew_std.mean()), ) diff --git a/tianshou/policy/modelfree/__init__.py b/tianshou/algorithm/modelfree/__init__.py similarity index 100% rename from tianshou/policy/modelfree/__init__.py rename to tianshou/algorithm/modelfree/__init__.py diff --git a/tianshou/algorithm/modelfree/a2c.py b/tianshou/algorithm/modelfree/a2c.py new file mode 100644 index 000000000..f26414f7e --- /dev/null +++ b/tianshou/algorithm/modelfree/a2c.py @@ -0,0 +1,288 @@ +from abc import ABC +from dataclasses import dataclass +from typing import cast + +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.algorithm.algorithm_base import ( + OnPolicyAlgorithm, + TrainingStats, +) +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as +from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol +from tianshou.utils import RunningMeanStd +from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic + + +@dataclass(kw_only=True) +class A2CTrainingStats(TrainingStats): + loss: SequenceSummaryStats + actor_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + ent_loss: SequenceSummaryStats + gradient_steps: int + + +class ActorCriticOnPolicyAlgorithm(OnPolicyAlgorithm[ProbabilisticActorPolicy], ABC): + """Abstract base class for actor-critic algorithms that use generalized advantage estimation (GAE).""" + + def __init__( + self, + *, + policy: ProbabilisticActorPolicy, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, + optim: OptimizerFactory, + optim_include_actor: bool, + max_grad_norm: float | None = None, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + gamma: float = 0.99, + return_scaling: bool = False, + ) -> None: + """ + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer factory. + :param optim_include_actor: whether the optimizer shall include the actor network's parameters. 
+ Pass False for algorithms that shall update only the critic via the optimizer.
+ :param max_grad_norm: the maximum L2 norm threshold for gradient clipping.
+ When not None, gradients will be rescaled to ensure their L2 norm does not
+ exceed this value. This prevents exploding gradients and stabilizes training by
+ limiting the magnitude of parameter updates.
+ Set to None to disable gradient clipping.
+ :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE).
+ Controls the bias-variance tradeoff in advantage estimates, acting as a
+ weighting factor for combining different n-step advantage estimators. Higher values
+ (closer to 1) reduce bias but increase variance by giving more weight to longer
+ trajectories, while lower values (closer to 0) reduce variance but increase bias
+ by relying more on the immediate TD error and value function estimates. At λ=0,
+ GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1,
+ it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance).
+ Intermediate values create a weighted average of n-step returns, with exponentially
+ decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for
+ most policy gradient methods.
+ :param max_batchsize: the maximum number of samples to process at once when computing
+ generalized advantage estimation (GAE) and value function predictions.
+ Controls memory usage by breaking large batches into smaller chunks processed sequentially.
+ Higher values may increase speed but require more GPU/CPU memory; lower values
+ reduce memory requirements but may increase computation time. Should be adjusted
+ based on available hardware resources and total batch size of your training data.
+ :param gamma: the discount factor in [0, 1] for future rewards.
+ This determines how much future rewards are valued compared to immediate ones.
+ Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic"
+ behavior. Higher values (closer to 1) make the agent value long-term rewards more,
+ potentially improving performance in tasks where delayed rewards are important but
+ increasing training variance by incorporating more environmental stochasticity.
+ Typically set between 0.9 and 0.99 for most reinforcement learning tasks.
+ :param return_scaling: flag indicating whether to enable scaling of estimated returns by
+ dividing them by their running standard deviation without centering the mean.
+ This reduces the magnitude variation of advantages across different episodes while
+ preserving their signs and relative ordering.
+ The use of running statistics (rather than batch-specific scaling) means that early
+ training experiences may be scaled differently than later ones as the statistics evolve.
+ When enabled, this improves training stability in environments with highly variable
+ reward scales and makes the algorithm less sensitive to learning rate settings.
+ However, it may reduce the algorithm's ability to distinguish between episodes with
+ different absolute return magnitudes.
+ Best used in environments where the relative ordering of actions is more important
+ than the absolute scale of returns.
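# Illustrative sketch (editor's annotation, not part of the diff): generalized advantage
# estimation as described for gae_lambda above. delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t)
# - V(s_t), and the advantage is the (gamma * lambda)-discounted sum of the deltas.
# Names are assumptions for the example only.
import numpy as np

def gae_advantages(rewards, values, next_values, dones, gamma=0.99, gae_lambda=0.95):
    deltas = rewards + gamma * next_values * (1.0 - dones) - values
    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * gae_lambda * (1.0 - dones[t]) * running
        advantages[t] = running
    returns = advantages + values  # regression targets for the value function
    return returns, advantages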
+ """ + super().__init__( + policy=policy, + ) + self.critic = critic + assert 0.0 <= gae_lambda <= 1.0, f"GAE lambda should be in [0, 1] but got: {gae_lambda}" + self.gae_lambda = gae_lambda + self.max_batchsize = max_batchsize + if optim_include_actor: + self.optim = self._create_optimizer( + ActorCritic(self.policy.actor, self.critic), optim, max_grad_norm=max_grad_norm + ) + else: + self.optim = self._create_optimizer(self.critic, optim, max_grad_norm=max_grad_norm) + self.gamma = gamma + self.return_scaling = return_scaling + self.ret_rms = RunningMeanStd() + self._eps = 1e-8 + + def _add_returns_and_advantages( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithAdvantagesProtocol: + """Adds the returns and advantages to the given batch.""" + v_s, v_s_ = [], [] + with torch.no_grad(): + for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): + v_s.append(self.critic(minibatch.obs)) + v_s_.append(self.critic(minibatch.obs_next)) + batch.v_s = torch.cat(v_s, dim=0).flatten() # old value + v_s = batch.v_s.cpu().numpy() + v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy() + # when normalizing values, we do not minus self.ret_rms.mean to be numerically + # consistent with OPENAI baselines' value normalization pipeline. Empirical + # study also shows that "minus mean" will harm performances a tiny little bit + # due to unknown reasons (on Mujoco envs, not confident, though). + if self.return_scaling: # unnormalize v_s & v_s_ + v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + unnormalized_returns, advantages = self.compute_episodic_return( + batch, + buffer, + indices, + v_s_, + v_s, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + ) + if self.return_scaling: + batch.returns = unnormalized_returns / np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.update(unnormalized_returns) + else: + batch.returns = unnormalized_returns + batch.returns = to_torch_as(batch.returns, batch.v_s) + batch.adv = to_torch_as(advantages, batch.v_s) + return cast(BatchWithAdvantagesProtocol, batch) + + +class A2C(ActorCriticOnPolicyAlgorithm): + """Implementation of (synchronous) Advantage Actor-Critic (A2C). arXiv:1602.01783.""" + + def __init__( + self, + *, + policy: ProbabilisticActorPolicy, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, + optim: OptimizerFactory, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + max_grad_norm: float | None = None, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + gamma: float = 0.99, + return_scaling: bool = False, + ) -> None: + """ + :param policy: the policy containing the actor network. + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer factory. + :param vf_coef: coefficient that weights the value loss relative to the actor loss in + the overall loss function. + Higher values prioritize accurate value function estimation over policy improvement. + Controls the trade-off between policy optimization and value function fitting. + Typically set between 0.5 and 1.0 for most actor-critic implementations. + :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. + Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. 
+ Typically set between 0.01 and 0.05 for most actor-critic implementations.
+ :param max_grad_norm: the maximum L2 norm threshold for gradient clipping.
+ When not None, gradients will be rescaled to ensure their L2 norm does not
+ exceed this value. This prevents exploding gradients and stabilizes training by
+ limiting the magnitude of parameter updates.
+ Set to None to disable gradient clipping.
+ :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE).
+ Controls the bias-variance tradeoff in advantage estimates, acting as a
+ weighting factor for combining different n-step advantage estimators. Higher values
+ (closer to 1) reduce bias but increase variance by giving more weight to longer
+ trajectories, while lower values (closer to 0) reduce variance but increase bias
+ by relying more on the immediate TD error and value function estimates. At λ=0,
+ GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1,
+ it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance).
+ Intermediate values create a weighted average of n-step returns, with exponentially
+ decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for
+ most policy gradient methods.
+ :param max_batchsize: the maximum size of the batch when computing GAE.
+ :param gamma: the discount factor in [0, 1] for future rewards.
+ This determines how much future rewards are valued compared to immediate ones.
+ Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic"
+ behavior. Higher values (closer to 1) make the agent value long-term rewards more,
+ potentially improving performance in tasks where delayed rewards are important but
+ increasing training variance by incorporating more environmental stochasticity.
+ Typically set between 0.9 and 0.99 for most reinforcement learning tasks.
+ :param return_scaling: flag indicating whether to enable scaling of estimated returns by
+ dividing them by their running standard deviation without centering the mean.
+ This reduces the magnitude variation of advantages across different episodes while
+ preserving their signs and relative ordering.
+ The use of running statistics (rather than batch-specific scaling) means that early
+ training experiences may be scaled differently than later ones as the statistics evolve.
+ When enabled, this improves training stability in environments with highly variable
+ reward scales and makes the algorithm less sensitive to learning rate settings.
+ However, it may reduce the algorithm's ability to distinguish between episodes with
+ different absolute return magnitudes.
+ Best used in environments where the relative ordering of actions is more important
+ than the absolute scale of returns.
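# Illustrative sketch (editor's annotation, not part of the diff): how vf_coef and ent_coef
# combine the three loss terms into the A2C objective that _update_with_batch below minimizes.
# The numeric values are placeholders for the example only.
import torch

actor_loss = torch.tensor(0.80)   # -(log_prob * advantage).mean()
vf_loss = torch.tensor(1.50)      # MSE between critic prediction and return target
entropy = torch.tensor(1.20)      # mean policy entropy (subtracted, i.e. entropy is encouraged)
vf_coef, ent_coef = 0.5, 0.01
loss = actor_loss + vf_coef * vf_loss - ent_coef * entropy  # tensor(1.5380)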
+ """ + super().__init__( + policy=policy, + critic=critic, + optim=optim, + optim_include_actor=True, + max_grad_norm=max_grad_norm, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + gamma=gamma, + return_scaling=return_scaling, + ) + self.vf_coef = vf_coef + self.ent_coef = ent_coef + self.max_grad_norm = max_grad_norm + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithAdvantagesProtocol: + batch = self._add_returns_and_advantages(batch, buffer, indices) + batch.act = to_torch_as(batch.act, batch.v_s) + return batch + + def _update_with_batch( # type: ignore[override] + self, + batch: BatchWithAdvantagesProtocol, + batch_size: int | None, + repeat: int, + ) -> A2CTrainingStats: + losses, actor_losses, vf_losses, ent_losses = [], [], [], [] + split_batch_size = batch_size or -1 + gradient_steps = 0 + for _ in range(repeat): + for minibatch in batch.split(split_batch_size, merge_last=True): + gradient_steps += 1 + + # calculate loss for actor + dist = self.policy(minibatch).dist + log_prob = dist.log_prob(minibatch.act) + log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1) + actor_loss = -(log_prob * minibatch.adv).mean() + # calculate loss for critic + value = self.critic(minibatch.obs).flatten() + vf_loss = F.mse_loss(minibatch.returns, value) + # calculate regularization and overall loss + ent_loss = dist.entropy().mean() + loss = actor_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss + self.optim.step(loss) + actor_losses.append(actor_loss.item()) + vf_losses.append(vf_loss.item()) + ent_losses.append(ent_loss.item()) + losses.append(loss.item()) + + loss_summary_stat = SequenceSummaryStats.from_sequence(losses) + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + ent_loss_summary_stat = SequenceSummaryStats.from_sequence(ent_losses) + + return A2CTrainingStats( + loss=loss_summary_stat, + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + ent_loss=ent_loss_summary_stat, + gradient_steps=gradient_steps, + ) diff --git a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py new file mode 100644 index 000000000..9ab9b34c3 --- /dev/null +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -0,0 +1,224 @@ +from typing import cast + +import gymnasium as gym +import numpy as np +import torch +from sensai.util.helper import mark_used + +from tianshou.algorithm.algorithm_base import TArrOrActBatch +from tianshou.algorithm.modelfree.dqn import ( + DiscreteQLearningPolicy, + QLearningOffPolicyAlgorithm, +) +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + BatchWithReturnsProtocol, + ModelOutputBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.utils.net.common import BranchingNet + +mark_used(ActBatchProtocol) + + +class BDQNPolicy(DiscreteQLearningPolicy[BranchingNet]): + def __init__( + self, + *, + model: BranchingNet, + action_space: gym.spaces.Discrete, + observation_space: gym.Space | None = None, + eps_training: float = 0.0, + eps_inference: float = 0.0, + ): + """ + :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. 
+ :param action_space: the environment's action space + :param observation_space: the environment's observation space. + :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + super().__init__( + model=model, + action_space=action_space, + observation_space=observation_space, + eps_training=eps_training, + eps_inference=eps_inference, + ) + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + model: torch.nn.Module | None = None, + ) -> ModelOutputBatchProtocol: + if model is None: + model = self.model + assert model is not None + obs = batch.obs + # TODO: this is very contrived, see also iqn.py + obs_next_BO = obs.obs if hasattr(obs, "obs") else obs + action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info) + act_B = to_numpy(action_values_BA.argmax(dim=-1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) + return cast(ModelOutputBatchProtocol, result) + + def add_exploration_noise( + self, + act: TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> TArrOrActBatch: + eps = self.eps_training if self.is_within_training_step else self.eps_inference + if np.isclose(eps, 0.0): + return act + if isinstance(act, np.ndarray): + bsz = len(act) + rand_mask = np.random.rand(bsz) < eps + rand_act = np.random.randint( + low=0, + high=self.model.action_per_branch, + size=(bsz, act.shape[-1]), + ) + if hasattr(batch.obs, "mask"): + rand_act += batch.obs.mask + act[rand_mask] = rand_act[rand_mask] + return act # type: ignore[return-value] + else: + raise NotImplementedError( + f"Currently only numpy arrays are supported, got {type(act)=}." + ) + + +class BDQN(QLearningOffPolicyAlgorithm[BDQNPolicy]): + """Implementation of the Branching Dueling Q-Network (BDQN) algorithm arXiv:1711.08946.""" + + def __init__( + self, + *, + policy: BDQNPolicy, + optim: OptimizerFactory, + gamma: float = 0.99, + target_update_freq: int = 0, + is_double: bool = True, + ) -> None: + """ + :param policy: policy + :param optim: the optimizer factory for the policy's model. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. 
+ A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. + :param is_double: flag indicating whether to use Double Q-learning for target value calculation. + If True, the algorithm uses the online network to select actions and the target network to evaluate their Q-values. + This decoupling helps reduce the overestimation bias that standard Q-learning is prone to. + If False, the algorithm selects actions by directly taking the maximum Q-value from the target network. + Note: This parameter is most effective when used with a target network (target_update_freq > 0). + """ + super().__init__( + policy=policy, + optim=optim, + gamma=gamma, + # BDQN implements its own returns computation (below), which supports only 1-step returns + n_step_return_horizon=1, + target_update_freq=target_update_freq, + ) + self.is_double = is_double + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + result = self.policy(obs_next_batch) + if self.use_target_network: + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + target_q = self.policy(obs_next_batch, model=self.model_old).logits + else: + target_q = result.logits + if self.is_double: + act = np.expand_dims(self.policy(obs_next_batch).act, -1) + act = to_torch(act, dtype=torch.long, device=target_q.device) + else: + act = target_q.max(-1).indices.unsqueeze(-1) + return torch.gather(target_q, -1, act).squeeze() + + def _compute_return( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indice: np.ndarray, + gamma: float = 0.99, + ) -> BatchWithReturnsProtocol: + rew = batch.rew + with torch.no_grad(): + target_q_torch = self._target_q(buffer, indice) # (bsz, ?) 
+ target_q = to_numpy(target_q_torch) + end_flag = buffer.done.copy() + end_flag[buffer.unfinished_index()] = True + end_flag = end_flag[indice] + mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q + _target_q = rew + gamma * mean_target_q * (1 - end_flag) + target_q = np.repeat(_target_q[..., None], self.policy.model.num_branches, axis=-1) + target_q = np.repeat(target_q[..., None], self.policy.model.action_per_branch, axis=-1) + + batch.returns = to_torch_as(target_q, target_q_torch) + if hasattr(batch, "weight"): # prio buffer update + batch.weight = to_torch_as(batch.weight, target_q_torch) + return cast(BatchWithReturnsProtocol, batch) + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + """Compute the 1-step return for BDQ targets.""" + return self._compute_return(batch, buffer, indices) + + def _update_with_batch( # type: ignore[override] + self, + batch: BatchWithReturnsProtocol, + ) -> SimpleLossTrainingStats: + self._periodically_update_lagged_network_weights() + weight = batch.pop("weight", 1.0) + act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) + q = self.policy(batch).logits + act_mask = torch.zeros_like(q) + act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1) + act_q = q * act_mask + returns = batch.returns + returns = returns * act_mask + td_error = returns - act_q + loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean() + batch.weight = td_error.sum(-1).sum(-1) # prio-buffer + self.optim.step(loss) + + return SimpleLossTrainingStats(loss=loss.item()) diff --git a/tianshou/algorithm/modelfree/c51.py b/tianshou/algorithm/modelfree/c51.py new file mode 100644 index 000000000..8ca11f37d --- /dev/null +++ b/tianshou/algorithm/modelfree/c51.py @@ -0,0 +1,160 @@ +import gymnasium as gym +import numpy as np +import torch + +from tianshou.algorithm.modelfree.dqn import ( + DiscreteQLearningPolicy, + QLearningOffPolicyAlgorithm, +) +from tianshou.algorithm.modelfree.reinforce import LossSequenceTrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.types import RolloutBatchProtocol +from tianshou.utils.net.common import Net + + +class C51Policy(DiscreteQLearningPolicy): + def __init__( + self, + model: torch.nn.Module | Net, + action_space: gym.spaces.Space, + observation_space: gym.Space | None = None, + num_atoms: int = 51, + v_min: float = -10.0, + v_max: float = 10.0, + eps_training: float = 0.0, + eps_inference: float = 0.0, + ): + """ + :param model: a model following the rules (s_B -> action_values_BA) + :param num_atoms: the number of atoms in the support set of the + value distribution. Default to 51. + :param v_min: the value of the smallest atom in the support set. + Default to -10.0. + :param v_max: the value of the largest atom in the support set. + Default to 10.0. + :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). 
+ The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + assert isinstance(action_space, gym.spaces.Discrete) + super().__init__( + model=model, + action_space=action_space, + observation_space=observation_space, + eps_training=eps_training, + eps_inference=eps_inference, + ) + assert num_atoms > 1, f"num_atoms should be greater than 1 but got: {num_atoms}" + assert v_min < v_max, f"v_max should be larger than v_min, but got {v_min=} and {v_max=}" + self.num_atoms = num_atoms + self.v_min = v_min + self.v_max = v_max + self.support = torch.nn.Parameter( + torch.linspace(self.v_min, self.v_max, self.num_atoms), + requires_grad=False, + ) + + def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + return super().compute_q_value((logits * self.support).sum(2), mask) + + +class C51(QLearningOffPolicyAlgorithm[C51Policy]): + """Implementation of Categorical Deep Q-Network. arXiv:1707.06887.""" + + def __init__( + self, + *, + policy: C51Policy, + optim: OptimizerFactory, + gamma: float = 0.99, + n_step_return_horizon: int = 1, + target_update_freq: int = 0, + ) -> None: + """ + :param policy: a policy following the rules (s -> action_values_BA) + :param optim: the optimizer factory for the policy's model. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. 
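# Illustrative sketch (editor's annotation, not part of the diff): how the categorical
# distribution over num_atoms support points yields scalar Q-values, mirroring
# C51Policy.compute_q_value above. Shapes are assumptions: probs has shape
# (batch, n_actions, num_atoms).
import torch

num_atoms, v_min, v_max = 51, -10.0, 10.0
support = torch.linspace(v_min, v_max, num_atoms)             # atom locations z_1 .. z_51
probs = torch.softmax(torch.randn(4, 6, num_atoms), dim=-1)   # per-action value distributions
q_values = (probs * support).sum(-1)                          # E[Z(s, a)], shape (4, 6)
greedy_act = q_values.argmax(dim=-1)                          # action selection uses expected values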
+ """ + super().__init__( + policy=policy, + optim=optim, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + target_update_freq=target_update_freq, + ) + self.delta_z = (policy.v_max - policy.v_min) / (policy.num_atoms - 1) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + return self.policy.support.repeat(len(indices), 1) # shape: [bsz, num_atoms] + + def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: + obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) + if self.use_target_network: + act = self.policy(obs_next_batch).act + next_dist = self.policy(obs_next_batch, model=self.model_old).logits + else: + next_batch = self.policy(obs_next_batch) + act = next_batch.act + next_dist = next_batch.logits + next_dist = next_dist[np.arange(len(act)), act, :] + target_support = batch.returns.clamp(self.policy.v_min, self.policy.v_max) + # An amazing trick for calculating the projection gracefully. + # ref: https://github.com/ShangtongZhang/DeepRL + target_dist = ( + 1 + - (target_support.unsqueeze(1) - self.policy.support.view(1, -1, 1)).abs() + / self.delta_z + ).clamp(0, 1) * next_dist.unsqueeze(1) + return target_dist.sum(-1) + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> LossSequenceTrainingStats: + self._periodically_update_lagged_network_weights() + with torch.no_grad(): + target_dist = self._target_dist(batch) + weight = batch.pop("weight", 1.0) + curr_dist = self.policy(batch).logits + act = batch.act + curr_dist = curr_dist[np.arange(len(act)), act, :] + cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1) + loss = (cross_entropy * weight).mean() + # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100 + batch.weight = cross_entropy.detach() # prio-buffer + self.optim.step(loss) + + return LossSequenceTrainingStats(loss=loss.item()) diff --git a/tianshou/algorithm/modelfree/ddpg.py b/tianshou/algorithm/modelfree/ddpg.py new file mode 100644 index 000000000..2520fd8d1 --- /dev/null +++ b/tianshou/algorithm/modelfree/ddpg.py @@ -0,0 +1,410 @@ +import warnings +from abc import ABC +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +from sensai.util.helper import mark_used + +from tianshou.algorithm import Algorithm +from tianshou.algorithm.algorithm_base import ( + LaggedNetworkPolyakUpdateAlgorithmMixin, + OffPolicyAlgorithm, + Policy, + TArrOrActBatch, + TPolicy, + TrainingStats, +) +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + ActStateBatchProtocol, + BatchWithReturnsProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.utils.net.continuous import ( + AbstractContinuousActorDeterministic, + ContinuousCritic, +) + +mark_used(ActBatchProtocol) + + +@dataclass(kw_only=True) +class DDPGTrainingStats(TrainingStats): + actor_loss: float + critic_loss: float + + +class ContinuousPolicyWithExplorationNoise(Policy, ABC): + def __init__( + self, + *, + exploration_noise: BaseNoise | Literal["default"] | None = None, + action_space: gym.Space, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + ): + """ + :param exploration_noise: noise model for 
adding noise to continuous actions + for exploration. This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). + :param action_space: the environment's action_space. + :param observation_space: the environment's observation space + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. + """ + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + ) + if exploration_noise == "default": + exploration_noise = GaussianNoise(sigma=0.1) + self.exploration_noise = exploration_noise + + def set_exploration_noise(self, noise: BaseNoise | None) -> None: + """Set the exploration noise.""" + self.exploration_noise = noise + + def add_exploration_noise( + self, + act: TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> TArrOrActBatch: + if self.exploration_noise is None: + return act + if isinstance(act, np.ndarray): + return act + self.exploration_noise(act.shape) + warnings.warn("Cannot add exploration noise to non-numpy_array action.") + return act + + +class ContinuousDeterministicPolicy(ContinuousPolicyWithExplorationNoise): + """A policy for continuous action spaces that uses an actor which directly maps states to actions.""" + + def __init__( + self, + *, + actor: AbstractContinuousActorDeterministic, + exploration_noise: BaseNoise | Literal["default"] | None = None, + action_space: gym.Space, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + ): + """ + :param actor: The actor network following the rules (s -> actions) + :param exploration_noise: add noise to continuous actions for exploration; + set to None for discrete action spaces. 
+ This is useful when solving "hard exploration" problems.
+ "default" is equivalent to GaussianNoise(sigma=0.1).
+ :param action_space: the environment's action space.
+ :param observation_space: the environment's observation space.
+ :param action_scaling: flag indicating whether, for continuous action spaces, actions
+ should be scaled from the standard neural network output range [-1, 1] to the
+ environment's action space range [action_space.low, action_space.high].
+ This applies to continuous action spaces only (gym.spaces.Box) and has no effect
+ for discrete spaces.
+ When enabled, policy outputs are expected to be in the normalized range [-1, 1]
+ (after bounding), and are then linearly transformed to the actual required range.
+ This improves neural network training stability, allows the same algorithm to work
+ across environments with different action ranges, and standardizes exploration
+ strategies.
+ Should be disabled if the actor model already produces outputs in the correct range.
+ :param action_bound_method: method to bound action to range [-1, 1].
+ """
+ if action_scaling and not np.isclose(actor.max_action, 1.0):
+ warnings.warn(
+ "action_scaling and action_bound_method are only intended to deal "
+ "with an unbounded action space, but found that the actor model bounds the "
+ f"action space with max_action={actor.max_action}. "
+ "Consider using the unbounded=True option of the actor model, "
+ "or set action_scaling to False and action_bound_method to None.",
+ )
+ super().__init__(
+ exploration_noise=exploration_noise,
+ action_space=action_space,
+ observation_space=observation_space,
+ action_scaling=action_scaling,
+ action_bound_method=action_bound_method,
+ )
+ self.actor = actor
+
+ def forward(
+ self,
+ batch: ObsBatchProtocol,
+ state: dict | BatchProtocol | np.ndarray | None = None,
+ model: torch.nn.Module | None = None,
+ **kwargs: Any,
+ ) -> ActStateBatchProtocol:
+ """Compute action over the given batch data.
+
+ :return: A :class:`~tianshou.data.Batch` which has 2 keys:
+
+ * ``act`` the action.
+ * ``state`` the hidden state.
+ """
+ if model is None:
+ model = self.actor
+ actions, hidden = model(batch.obs, state=state, info=batch.info)
+ return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden))
+
+
+TActBatchProtocol = TypeVar("TActBatchProtocol", bound=ActBatchProtocol)
+
+
+class ActorCriticOffPolicyAlgorithm(
+ OffPolicyAlgorithm[TPolicy],
+ LaggedNetworkPolyakUpdateAlgorithmMixin,
+ Generic[TPolicy, TActBatchProtocol],
+ ABC,
+):
+ """Base class for actor-critic off-policy algorithms that use a lagged critic
+ as a target network.
+
+ Its implementation of `_preprocess_batch` adds the n-step return to the batch, using the
+ Q-values computed by the target network (lagged critic, `critic_old`) in order to compute the
+ reward-to-go.
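# Illustrative sketch (editor's annotation, not part of the diff): the n-step bootstrapped
# target the paragraph above refers to,
#   G_t = r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * Q_target(s_{t+n}, a),
# with bootstrapping truncated at episode ends. Plain-Python helper; names are assumptions.
def n_step_target(rewards, dones, bootstrap_q, gamma=0.99):
    """rewards/dones: the n transitions starting at t; bootstrap_q: target-critic value at s_{t+n}."""
    target = bootstrap_q
    for r, d in zip(reversed(rewards), reversed(dones)):
        target = r + gamma * (1.0 - d) * target
    return target

# e.g. a 3-step target with no terminations and Q_target = 5.0:
# n_step_target([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 5.0) == 1 + 0.99*(1 + 0.99*(1 + 0.99*5))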
+ + Specializations can override the action computation (`_target_q_compute_action`) or the + Q-value computation based on these actions (`_target_q_compute_value`) to customize the + target Q-value computation. + The default implementation assumes a continuous action space where a critic receives a + state/observation and an action; for discrete action spaces, where the critic receives only + a state/observation, the method `_target_q_compute_value` must be overridden. + """ + + def __init__( + self, + *, + policy: TPolicy, + policy_optim: OptimizerFactory, + critic: torch.nn.Module, + critic_optim: OptimizerFactory, + tau: float = 0.005, + gamma: float = 0.99, + n_step_return_horizon: int = 1, + ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer factory for the policy's model. + :param critic: the critic network. + For continuous action spaces: (s, a -> Q(s, a)). + For discrete action spaces: (s -> ). + **NOTE**: The default implementation of `_target_q_compute_value` assumes + a continuous action space; override this method if using discrete actions. + :param critic_optim: the optimizer factory for the critic network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + """ + assert 0.0 <= tau <= 1.0, f"tau should be in [0, 1] but got: {tau}" + assert 0.0 <= gamma <= 1.0, f"gamma should be in [0, 1] but got: {gamma}" + super().__init__( + policy=policy, + ) + LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) + self.policy_optim = self._create_optimizer(policy, policy_optim) + self.critic = critic + self.critic_old = self._add_lagged_network(self.critic) + self.critic_optim = self._create_optimizer(self.critic, critic_optim) + self.gamma = gamma + self.n_step_return_horizon = n_step_return_horizon + + @staticmethod + def _minimize_critic_squared_loss( + batch: RolloutBatchProtocol, + critic: torch.nn.Module, + optimizer: Algorithm.Optimizer, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Takes an optimizer step to minimize the squared loss of the critic given a batch of data. + + :param batch: the batch containing the observations, actions, returns, and (optionally) weights. + :param critic: the critic network to minimize the loss for. + :param optimizer: the optimizer for the critic's parameters. + :return: a pair (`td`, `loss`), where `td` is the tensor of errors (current - target) and `loss` is the MSE loss. 
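# Illustrative sketch (editor's annotation, not part of the diff): the Polyak (soft) update
# that the tau parameter above controls, target <- tau * source + (1 - tau) * target, applied
# parameter-wise. A minimal stand-alone helper; in the library this is the responsibility of
# the LaggedNetworkPolyakUpdateAlgorithmMixin used by this class.
import torch

def polyak_update(target: torch.nn.Module, source: torch.nn.Module, tau: float = 0.005) -> None:
    with torch.no_grad():
        for tgt, src in zip(target.parameters(), source.parameters()):
            tgt.mul_(1.0 - tau).add_(src, alpha=tau)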
+ """ + weight = getattr(batch, "weight", 1.0) + current_q = critic(batch.obs, batch.act).flatten() + target_q = batch.returns.flatten() + td = current_q - target_q + critic_loss = (td.pow(2) * weight).mean() + optimizer.step(critic_loss) + return td, critic_loss + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol | BatchWithReturnsProtocol: + # add the n-step return to the batch, which the critic (Q-functions) seeks to match, + # based the Q-values computed by the target network (lagged critic) + return self.compute_nstep_return( + batch=batch, + buffer=buffer, + indices=indices, + target_q_fn=self._target_q, + gamma=self.gamma, + n_step=self.n_step_return_horizon, + ) + + def _target_q_compute_action(self, obs_batch: Batch) -> TActBatchProtocol: + """ + Computes the action to be taken for the given batch (containing the observations) + within the context of Q-value target computation. + + :param obs_batch: the batch containing the observations. + :return: batch containing the actions to be taken. + """ + return self.policy(obs_batch) + + def _target_q_compute_value( + self, obs_batch: Batch, act_batch: TActBatchProtocol + ) -> torch.Tensor: + """ + Computes the target Q-value given a batch with observations and actions taken. + + :param obs_batch: the batch containing the observations. + :param act_batch: the batch containing the actions taken. + :return: a tensor containing the target Q-values. + """ + # compute the target Q-value using the lagged critic network (target network) + return self.critic_old(obs_batch.obs, act_batch.act) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + """ + Computes the target Q-value for the given buffer and indices. + + :param buffer: the replay buffer + :param indices: the indices within the buffer to compute the target Q-value for + """ + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + act_batch = self._target_q_compute_action(obs_next_batch) + return self._target_q_compute_value(obs_next_batch, act_batch) + + +class DDPG( + ActorCriticOffPolicyAlgorithm[ContinuousDeterministicPolicy, ActBatchProtocol], +): + """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.""" + + def __init__( + self, + *, + policy: ContinuousDeterministicPolicy, + policy_optim: OptimizerFactory, + critic: torch.nn.Module | ContinuousCritic, + critic_optim: OptimizerFactory, + tau: float = 0.005, + gamma: float = 0.99, + n_step_return_horizon: int = 1, + ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer factory for the policy's model. + :param critic: the critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer factory for the critic network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param gamma: the discount factor in [0, 1] for future rewards. 
+ This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + """ + super().__init__( + policy=policy, + policy_optim=policy_optim, + critic=critic, + critic_optim=critic_optim, + tau=tau, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + ) + self.actor_old = self._add_lagged_network(self.policy.actor) + + def _target_q_compute_action(self, obs_batch: Batch) -> ActBatchProtocol: + # compute the action using the lagged actor network + return self.policy(obs_batch, model=self.actor_old) + + def _update_with_batch(self, batch: RolloutBatchProtocol) -> DDPGTrainingStats: + # critic + td, critic_loss = self._minimize_critic_squared_loss(batch, self.critic, self.critic_optim) + batch.weight = td # prio-buffer + # actor + actor_loss = -self.critic(batch.obs, self.policy(batch).act).mean() + self.policy_optim.step(actor_loss) + self._update_lagged_network_weights() + + return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) diff --git a/tianshou/algorithm/modelfree/discrete_sac.py b/tianshou/algorithm/modelfree/discrete_sac.py new file mode 100644 index 000000000..99cf2c01c --- /dev/null +++ b/tianshou/algorithm/modelfree/discrete_sac.py @@ -0,0 +1,196 @@ +from dataclasses import dataclass +from typing import Any, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +from torch.distributions import Categorical + +from tianshou.algorithm.algorithm_base import Policy +from tianshou.algorithm.modelfree.sac import Alpha, SACTrainingStats +from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + DistBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.utils.net.discrete import DiscreteCritic + + +@dataclass +class DiscreteSACTrainingStats(SACTrainingStats): + pass + + +TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteSACTrainingStats) + + +class DiscreteSACPolicy(Policy): + def __init__( + self, + *, + actor: torch.nn.Module, + deterministic_eval: bool = True, + action_space: gym.Space, + observation_space: gym.Space | None = None, + ): + """ + :param actor: the actor network following the rules (s -> dist_input_BD), + where the distribution input is for a `Categorical` distribution. 
+ :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. + :param action_space: the environment's action_space. + :param observation_space: the environment's observation space + """ + assert isinstance(action_space, gym.spaces.Discrete) + super().__init__( + action_space=action_space, + observation_space=observation_space, + ) + self.actor = actor + self.deterministic_eval = deterministic_eval + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> Batch: + logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Categorical(logits=logits_BA) + act_B = ( + dist.mode + if self.deterministic_eval and not self.is_within_training_step + else dist.sample() + ) + return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) + + +class DiscreteSAC(ActorDualCriticsOffPolicyAlgorithm[DiscreteSACPolicy, DistBatchProtocol]): + """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.""" + + def __init__( + self, + *, + policy: DiscreteSACPolicy, + policy_optim: OptimizerFactory, + critic: torch.nn.Module | DiscreteCritic, + critic_optim: OptimizerFactory, + critic2: torch.nn.Module | DiscreteCritic | None = None, + critic2_optim: OptimizerFactory | None = None, + tau: float = 0.005, + gamma: float = 0.99, + alpha: float | Alpha = 0.2, + n_step_return_horizon: int = 1, + ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer factory for the policy's model. + :param critic: the first critic network. (s -> ). + :param critic_optim: the optimizer factory for the first critic network. + :param critic2: the second critic network. (s -> ). + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. 
+ Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param alpha: the entropy regularization coefficient alpha or an object + which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + """ + super().__init__( + policy=policy, + policy_optim=policy_optim, + critic=critic, + critic_optim=critic_optim, + critic2=critic2, + critic2_optim=critic2_optim, + tau=tau, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + ) + self.alpha = Alpha.from_float_or_instance(alpha) + + def _target_q_compute_value( + self, obs_batch: Batch, act_batch: DistBatchProtocol + ) -> torch.Tensor: + dist = cast(Categorical, act_batch.dist) + target_q = dist.probs * torch.min( + self.critic_old(obs_batch.obs), + self.critic2_old(obs_batch.obs), + ) + return target_q.sum(dim=-1) + self.alpha.value * dist.entropy() + + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TDiscreteSACTrainingStats: # type: ignore + weight = batch.pop("weight", 1.0) + target_q = batch.returns.flatten() + act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) + + # critic 1 + current_q1 = self.critic(batch.obs).gather(1, act).flatten() + td1 = current_q1 - target_q + critic1_loss = (td1.pow(2) * weight).mean() + self.critic_optim.step(critic1_loss) + + # critic 2 + current_q2 = self.critic2(batch.obs).gather(1, act).flatten() + td2 = current_q2 - target_q + critic2_loss = (td2.pow(2) * weight).mean() + self.critic2_optim.step(critic2_loss) + + batch.weight = (td1 + td2) / 2.0 # prio-buffer + + # actor + dist = self.policy(batch).dist + entropy = dist.entropy() + with torch.no_grad(): + current_q1a = self.critic(batch.obs) + current_q2a = self.critic2(batch.obs) + q = torch.min(current_q1a, current_q2a) + actor_loss = -(self.alpha.value * entropy + (dist.probs * q).sum(dim=-1)).mean() + self.policy_optim.step(actor_loss) + + alpha_loss = self.alpha.update(entropy.detach()) + + self._update_lagged_network_weights() + + return DiscreteSACTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + alpha=self.alpha.value, + alpha_loss=alpha_loss, + ) diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py new file mode 100644 index 000000000..530de6ef0 --- /dev/null +++ b/tianshou/algorithm/modelfree/dqn.py @@ -0,0 +1,404 @@ +import logging +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +from gymnasium.spaces.discrete import Discrete 
+from sensai.util.helper import mark_used + +from tianshou.algorithm import Algorithm +from tianshou.algorithm.algorithm_base import ( + LaggedNetworkFullUpdateAlgorithmMixin, + OffPolicyAlgorithm, + Policy, + TArrOrActBatch, +) +from tianshou.algorithm.modelfree.reinforce import ( + SimpleLossTrainingStats, +) +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as +from tianshou.data.types import ( + ActBatchProtocol, + BatchWithReturnsProtocol, + ModelOutputBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.utils.lagged_network import EvalModeModuleWrapper +from tianshou.utils.net.common import Net + +mark_used(ActBatchProtocol) + +TModel = TypeVar("TModel", bound=torch.nn.Module | Net) +log = logging.getLogger(__name__) + + +class DiscreteQLearningPolicy(Policy, Generic[TModel]): + def __init__( + self, + *, + model: TModel, + action_space: gym.spaces.Space, + observation_space: gym.Space | None = None, + eps_training: float = 0.0, + eps_inference: float = 0.0, + ) -> None: + """ + :param model: a model mapping (obs, state, info) to action_values_BA. + :param action_space: the environment's action space + :param observation_space: the environment's observation space. + :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=False, + action_bound_method=None, + ) + self.action_space = cast(Discrete, self.action_space) + self.model = model + self.eps_training = eps_training + self.eps_inference = eps_inference + + def set_eps_training(self, eps: float) -> None: + """ + Sets the epsilon value for epsilon-greedy exploration during training. + + :param eps: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + self.eps_training = eps + + def set_eps_inference(self, eps: float) -> None: + """ + Sets the epsilon value for epsilon-greedy exploration during inference. + + :param eps: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + self.eps_inference = eps + + def forward( + self, + batch: ObsBatchProtocol, + state: Any | None = None, + model: torch.nn.Module | None = None, + ) -> ModelOutputBatchProtocol: + """Compute action over the given batch data. 
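+
+        The action returned here is the greedy (argmax) action with respect to the
+        computed Q-values. Epsilon-greedy exploration is not applied in this method;
+        it is handled separately by :meth:`add_exploration_noise`, which uses
+        `eps_training` or `eps_inference` depending on whether the policy is
+        currently within a training step.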
+ + If you need to mask the action, please add a "mask" into batch.obs, for + example, if we have an environment that has "0/1/2" three actions: + :: + + batch == Batch( + obs=Batch( + obs="original obs, with batch_size=1 for demonstration", + mask=np.array([[False, True, False]]), + # action 1 is available + # action 0 and 2 are unavailable + ), + ... + ) + + :param batch: + :param state: optional hidden state (for RNNs) + :param model: if not passed will use `self.model`. Typically used to pass + the lagged target network instead of using the current model. + :return: A :class:`~tianshou.data.Batch` which has 3 keys: + + * ``act`` the action. + * ``logits`` the network's raw output. + * ``state`` the hidden state. + """ + if model is None: + model = self.model + obs = batch.obs + mask = getattr(obs, "mask", None) + # TODO: this is convoluted! See also other places where this is done. + obs_arr = obs.obs if hasattr(obs, "obs") else obs + action_values_BA, hidden_BH = model(obs_arr, state=state, info=batch.info) + q = self.compute_q_value(action_values_BA, mask) + act_B = to_numpy(q.argmax(dim=1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) + return cast(ModelOutputBatchProtocol, result) + + def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + """Compute the q value based on the network's raw output and action mask.""" + if mask is not None: + # the masked q value should be smaller than logits.min() + min_value = logits.min() - logits.max() - 1.0 + logits = logits + to_torch_as(1 - mask, logits) * min_value + return logits + + def add_exploration_noise( + self, + act: TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> TArrOrActBatch: + eps = self.eps_training if self.is_within_training_step else self.eps_inference + if np.isclose(eps, 0.0): + return act + if isinstance(act, np.ndarray): + batch_size = len(act) + rand_mask = np.random.rand(batch_size) < eps + self.action_space = cast(Discrete, self.action_space) # for mypy + action_num = int(self.action_space.n) + q = np.random.rand(batch_size, action_num) # [0, 1] + if hasattr(batch.obs, "mask"): + q += batch.obs.mask + rand_act = q.argmax(axis=1) + act[rand_mask] = rand_act[rand_mask] + return act # type: ignore[return-value] + raise NotImplementedError( + f"Currently only numpy array is supported for action, but got {type(act)}" + ) + + +TDQNPolicy = TypeVar("TDQNPolicy", bound=DiscreteQLearningPolicy) + + +class QLearningOffPolicyAlgorithm( + OffPolicyAlgorithm[TDQNPolicy], LaggedNetworkFullUpdateAlgorithmMixin, ABC +): + """ + Base class for Q-learning off-policy algorithms that use a Q-function to compute the + n-step return. + It optionally uses a lagged model, which is used as a target network and which is + fully updated periodically. + """ + + def __init__( + self, + *, + policy: TDQNPolicy, + optim: OptimizerFactory, + gamma: float = 0.99, + n_step_return_horizon: int = 1, + target_update_freq: int = 0, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory for the policy's model. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. 
Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. + """ + super().__init__( + policy=policy, + ) + self.optim = self._create_policy_optimizer(optim) + LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) + assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" + self.gamma = gamma + assert ( + n_step_return_horizon > 0 + ), f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" + self.n_step = n_step_return_horizon + self.target_update_freq = target_update_freq + # TODO: 1 would be a more reasonable initialization given how it is incremented + self._iter = 0 + self.model_old: EvalModeModuleWrapper | None = ( + self._add_lagged_network(self.policy.model) if self.use_target_network else None + ) + + def _create_policy_optimizer(self, optim: OptimizerFactory) -> Algorithm.Optimizer: + return self._create_optimizer(self.policy, optim) + + @property + def use_target_network(self) -> bool: + return self.target_update_freq > 0 + + @abstractmethod + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + pass + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + """Compute the n-step return for Q-learning targets. + + More details can be found at + :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. + """ + return self.compute_nstep_return( + batch=batch, + buffer=buffer, + indices=indices, + target_q_fn=self._target_q, + gamma=self.gamma, + n_step=self.n_step, + ) + + def _periodically_update_lagged_network_weights(self) -> None: + """ + Periodically updates the parameters of the lagged target network (if any), i.e. + every n-th call (where n=`target_update_freq`), the target network's parameters + are fully updated with the model's parameters. 
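+        For example, with `target_update_freq=500`, the target network is synchronised
+        with the online model on the first call and then again after every further 500
+        calls (the internal counter starts at 0 and is incremented after the check).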
+ """ + if self.use_target_network and self._iter % self.target_update_freq == 0: + self._update_lagged_network_weights() + self._iter += 1 + + +class DQN( + QLearningOffPolicyAlgorithm[TDQNPolicy], + Generic[TDQNPolicy], +): + """Implementation of Deep Q Network. arXiv:1312.5602. + + Implementation of Double Q-Learning. arXiv:1509.06461. + + Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is + implemented in the network side, not here). + """ + + def __init__( + self, + *, + policy: TDQNPolicy, + optim: OptimizerFactory, + gamma: float = 0.99, + n_step_return_horizon: int = 1, + target_update_freq: int = 0, + is_double: bool = True, + huber_loss_delta: float | None = None, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory for the policy's model. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. + :param is_double: flag indicating whether to use the Double DQN algorithm for target value computation. + If True, the algorithm uses the online network to select actions and the target network to + evaluate their Q-values. This approach helps reduce the overestimation bias in Q-learning + by decoupling action selection from action evaluation. + If False, the algorithm follows the vanilla DQN method that directly takes the maximum Q-value + from the target network. + Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0). + :param huber_loss_delta: controls whether to use the Huber loss instead of the MSE loss for the TD error + and the threshold for the Huber loss. + If None, the MSE loss is used. + If not None, uses the Huber loss as described in the Nature DQN paper (nature14236) with the given delta, + which limits the influence of outliers. 
+ Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber + loss causes the gradients to plateau at a constant value for large errors, providing more stable training. + NOTE: The magnitude of delta should depend on the scale of the returns obtained in the environment. + """ + super().__init__( + policy=policy, + optim=optim, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + target_update_freq=target_update_freq, + ) + self.is_double = is_double + self.huber_loss_delta = huber_loss_delta + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + result = self.policy(obs_next_batch) + if self.use_target_network: + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + target_q = self.policy(obs_next_batch, model=self.model_old).logits + else: + target_q = result.logits + if self.is_double: + return target_q[np.arange(len(result.act)), result.act] + # Nature DQN, over estimate + return target_q.max(dim=1)[0] + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> SimpleLossTrainingStats: + self._periodically_update_lagged_network_weights() + weight = batch.pop("weight", 1.0) + q = self.policy(batch).logits + q = q[np.arange(len(q)), batch.act] + returns = to_torch_as(batch.returns.flatten(), q) + td_error = returns - q + + if self.huber_loss_delta is not None: + y = q.reshape(-1, 1) + t = returns.reshape(-1, 1) + loss = torch.nn.functional.huber_loss( + y, t, delta=self.huber_loss_delta, reduction="mean" + ) + else: + loss = (td_error.pow(2) * weight).mean() + + batch.weight = td_error # prio-buffer + self.optim.step(loss) + + return SimpleLossTrainingStats(loss=loss.item()) diff --git a/tianshou/algorithm/modelfree/fqf.py b/tianshou/algorithm/modelfree/fqf.py new file mode 100644 index 000000000..8ebc6fa93 --- /dev/null +++ b/tianshou/algorithm/modelfree/fqf.py @@ -0,0 +1,256 @@ +from dataclasses import dataclass +from typing import Any, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F +from overrides import override + +from tianshou.algorithm import QRDQN, Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction + + +@dataclass(kw_only=True) +class FQFTrainingStats(SimpleLossTrainingStats): + quantile_loss: float + fraction_loss: float + entropy_loss: float + + +class FQFPolicy(QRDQNPolicy): + def __init__( + self, + *, + model: FullQuantileFunction, + fraction_model: FractionProposalNetwork, + action_space: gym.spaces.Space, + observation_space: gym.Space | None = None, + eps_training: float = 0.0, + eps_inference: float = 0.0, + ): + """ + :param model: a model following the rules (s_B -> action_values_BA) + :param fraction_model: a FractionProposalNetwork for + proposing fractions/quantiles given state. + :param action_space: the environment's action space + :param observation_space: the environment's observation space. 
+ :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + assert isinstance(action_space, gym.spaces.Discrete) + super().__init__( + model=model, + action_space=action_space, + observation_space=observation_space, + eps_training=eps_training, + eps_inference=eps_inference, + ) + self.fraction_model = fraction_model + + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + model: FullQuantileFunction | None = None, + fractions: Batch | None = None, + **kwargs: Any, + ) -> FQFBatchProtocol: + if model is None: + model = self.model + obs = batch.obs + # TODO: this is convoluted! See also other places where this is done + obs_next = obs.obs if hasattr(obs, "obs") else obs + if fractions is None: + (logits, fractions, quantiles_tau), hidden = model( + obs_next, + propose_model=self.fraction_model, + state=state, + info=batch.info, + ) + else: + (logits, _, quantiles_tau), hidden = model( + obs_next, + propose_model=self.fraction_model, + fractions=fractions, + state=state, + info=batch.info, + ) + weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits + q = DiscreteQLearningPolicy.compute_q_value( + self, weighted_logits.sum(2), getattr(obs, "mask", None) + ) + act = to_numpy(q.max(dim=1)[1]) + result = Batch( + logits=logits, + act=act, + state=hidden, + fractions=fractions, + quantiles_tau=quantiles_tau, + ) + return cast(FQFBatchProtocol, result) + + +class FQF(QRDQN[FQFPolicy]): + """Implementation of Fully Parameterized Quantile Function for Distributional Reinforcement Learning. arXiv:1911.02140.""" + + def __init__( + self, + *, + policy: FQFPolicy, + optim: OptimizerFactory, + fraction_optim: OptimizerFactory, + gamma: float = 0.99, + num_fractions: int = 32, + ent_coef: float = 0.0, + n_step_return_horizon: int = 1, + target_update_freq: int = 0, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory for the policy's main Q-function model + :param fraction_optim: the optimizer factory for the policy's fraction model + :param action_space: the environment's action space. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param num_fractions: the number of fractions to use. + :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. 
+ Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. + Typically set between 0.01 and 0.05 for most actor-critic implementations. + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. + """ + super().__init__( + policy=policy, + optim=optim, + gamma=gamma, + num_quantiles=num_fractions, + n_step_return_horizon=n_step_return_horizon, + target_update_freq=target_update_freq, + ) + self.ent_coef = ent_coef + self.fraction_optim = self._create_optimizer(self.policy.fraction_model, fraction_optim) + + @override + def _create_policy_optimizer(self, optim: OptimizerFactory) -> Algorithm.Optimizer: + # Override to leave out the fraction model (use main model only), as we want + # to use a separate optimizer for the fraction model + return self._create_optimizer(self.policy.model, optim) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + if self.use_target_network: + result = self.policy(obs_next_batch) + act, fractions = result.act, result.fractions + next_dist = self.policy( + obs_next_batch, model=self.model_old, fractions=fractions + ).logits + else: + next_batch = self.policy(obs_next_batch) + act = next_batch.act + next_dist = next_batch.logits + return next_dist[np.arange(len(act)), act, :] + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> FQFTrainingStats: + self._periodically_update_lagged_network_weights() + weight = batch.pop("weight", 1.0) + out = self.policy(batch) + curr_dist_orig = out.logits + taus, tau_hats = out.fractions.taus, out.fractions.tau_hats + act = batch.act + curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = ( + ( + dist_diff + * (tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.0).float()).abs() + ) + .sum(-1) + .mean(1) + ) + quantile_loss = (huber_loss * weight).mean() + # ref: 
https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer + # calculate fraction loss + with torch.no_grad(): + sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :] + sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :] + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169 + values_1 = sa_quantiles - sa_quantile_hats[:, :-1] + signs_1 = sa_quantiles > torch.cat( + [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], + dim=1, + ) + + values_2 = sa_quantiles - sa_quantile_hats[:, 1:] + signs_2 = sa_quantiles < torch.cat( + [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], + dim=1, + ) + + gradient_of_taus = torch.where(signs_1, values_1, -values_1) + torch.where( + signs_2, + values_2, + -values_2, + ) + fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean() + # calculate entropy loss + entropy_loss = out.fractions.entropies.mean() + fraction_entropy_loss = fraction_loss - self.ent_coef * entropy_loss + self.fraction_optim.step(fraction_entropy_loss, retain_graph=True) + self.optim.step(quantile_loss) + + return FQFTrainingStats( + loss=quantile_loss.item() + fraction_entropy_loss.item(), + quantile_loss=quantile_loss.item(), + fraction_loss=fraction_loss.item(), + entropy_loss=entropy_loss.item(), + ) diff --git a/tianshou/algorithm/modelfree/iqn.py b/tianshou/algorithm/modelfree/iqn.py new file mode 100644 index 000000000..4cd69b80e --- /dev/null +++ b/tianshou/algorithm/modelfree/iqn.py @@ -0,0 +1,183 @@ +from typing import Any, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.algorithm import QRDQN +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, to_numpy +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ObsBatchProtocol, + QuantileRegressionBatchProtocol, + RolloutBatchProtocol, +) + + +class IQNPolicy(QRDQNPolicy): + def __init__( + self, + *, + model: torch.nn.Module, + action_space: gym.spaces.Space, + sample_size: int = 32, + online_sample_size: int = 8, + target_sample_size: int = 8, + observation_space: gym.Space | None = None, + eps_training: float = 0.0, + eps_inference: float = 0.0, + ) -> None: + """ + :param model: + :param action_space: the environment's action space + :param sample_size: + :param online_sample_size: + :param target_sample_size: + :param observation_space: the environment's observation space + :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). 
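+
+        A minimal construction sketch (illustrative only; `net` and `env` are assumed to
+        exist, with `net` being an implicit quantile network module compatible with the
+        `model(obs, sample_size=..., state=..., info=...)` call used in :meth:`forward`)::
+
+            policy = IQNPolicy(
+                model=net,
+                action_space=env.action_space,
+                sample_size=32,
+                online_sample_size=8,
+                target_sample_size=8,
+            )
+            # The policy is then passed to the IQN algorithm together with an
+            # OptimizerFactory instance (here called `optim_factory`, also assumed to exist).
+            algorithm = IQN(policy=policy, optim=optim_factory)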
+ """ + assert isinstance(action_space, gym.spaces.Discrete) + assert sample_size > 1, f"sample_size should be greater than 1 but got: {sample_size}" + assert ( + online_sample_size > 1 + ), f"online_sample_size should be greater than 1 but got: {online_sample_size}" + assert ( + target_sample_size > 1 + ), f"target_sample_size should be greater than 1 but got: {target_sample_size}" + super().__init__( + model=model, + action_space=action_space, + observation_space=observation_space, + eps_training=eps_training, + eps_inference=eps_inference, + ) + self.sample_size = sample_size + self.online_sample_size = online_sample_size + self.target_sample_size = target_sample_size + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + model: torch.nn.Module | None = None, + **kwargs: Any, + ) -> QuantileRegressionBatchProtocol: + is_model_old = model is not None + if is_model_old: + sample_size = self.target_sample_size + elif self.training: + sample_size = self.online_sample_size + else: + sample_size = self.sample_size + if model is None: + model = self.model + obs = batch.obs + # TODO: this seems very contrived! + obs_next = obs.obs if hasattr(obs, "obs") else obs + (logits, taus), hidden = model( + obs_next, + sample_size=sample_size, + state=state, + info=batch.info, + ) + q = self.compute_q_value(logits, getattr(obs, "mask", None)) + act = to_numpy(q.max(dim=1)[1]) + result = Batch(logits=logits, act=act, state=hidden, taus=taus) + return cast(QuantileRegressionBatchProtocol, result) + + +class IQN(QRDQN[IQNPolicy]): + """Implementation of Implicit Quantile Network. arXiv:1806.06923.""" + + def __init__( + self, + *, + policy: IQNPolicy, + optim: OptimizerFactory, + gamma: float = 0.99, + num_quantiles: int = 200, + n_step_return_horizon: int = 1, + target_update_freq: int = 0, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory for the policy's model + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value. + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. 
+ A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. + """ + super().__init__( + policy=policy, + optim=optim, + gamma=gamma, + num_quantiles=num_quantiles, + n_step_return_horizon=n_step_return_horizon, + target_update_freq=target_update_freq, + ) + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> SimpleLossTrainingStats: + self._periodically_update_lagged_network_weights() + weight = batch.pop("weight", 1.0) + action_batch = self.policy(batch) + curr_dist, taus = action_batch.logits, action_batch.taus + act = batch.act + curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = ( + ( + dist_diff + * (taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.0).float()).abs() + ) + .sum(-1) + .mean(1) + ) + loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer + self.optim.step(loss) + + return SimpleLossTrainingStats(loss=loss.item()) diff --git a/tianshou/algorithm/modelfree/npg.py b/tianshou/algorithm/modelfree/npg.py new file mode 100644 index 000000000..637850031 --- /dev/null +++ b/tianshou/algorithm/modelfree/npg.py @@ -0,0 +1,236 @@ +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.distributions import kl_divergence + +from tianshou.algorithm.algorithm_base import TrainingStats +from tianshou.algorithm.modelfree.a2c import ActorCriticOnPolicyAlgorithm +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as +from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic + + +@dataclass(kw_only=True) +class NPGTrainingStats(TrainingStats): + actor_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + kl: SequenceSummaryStats + + +class NPG(ActorCriticOnPolicyAlgorithm): + """Implementation of Natural Policy Gradient. + + https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf + """ + + def __init__( + self, + *, + policy: ProbabilisticActorPolicy, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, + optim: OptimizerFactory, + optim_critic_iters: int = 5, + trust_region_size: float = 0.5, + advantage_normalization: bool = True, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + gamma: float = 0.99, + return_scaling: bool = False, + ) -> None: + """ + :param policy: the policy containing the actor network. + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer factory for the critic network. 
+ :param optim_critic_iters: the number of optimization steps performed on the critic network + for each policy (actor) update. + Controls the learning rate balance between critic and actor. + Higher values prioritize critic accuracy by training the value function more + extensively before each policy update, which can improve stability but slow down + training. Lower values maintain a more even learning pace between policy and value + function but may lead to less reliable advantage estimates. + Typically set between 1 and 10, depending on the complexity of the value function. + :param trust_region_size: the parameter delta - a scalar multiplier for policy updates in the natural gradient direction. + The mathematical meaning is the trust region size, which is the maximum KL divergence + allowed between the old and new policy distributions. + Controls how far the policy parameters move in the calculated direction + during each update. Higher values allow for faster learning but may cause instability + or policy deterioration; lower values provide more stable but slower learning. Unlike + regular policy gradients, natural gradients already account for the local geometry of + the parameter space, making this step size more robust to different parameterizations. + Typically set between 0.1 and 1.0 for most reinforcement learning tasks. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. + :param max_batchsize: the maximum number of samples to process at once when computing + generalized advantage estimation (GAE) and value function predictions. + Controls memory usage by breaking large batches into smaller chunks processed sequentially. + Higher values may increase speed but require more GPU/CPU memory; lower values + reduce memory requirements but may increase computation time. Should be adjusted + based on available hardware resources and total batch size of your training data. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. 
+ Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. + """ + super().__init__( + policy=policy, + critic=critic, + optim=optim, + optim_include_actor=False, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + gamma=gamma, + return_scaling=return_scaling, + ) + self.advantage_normalization = advantage_normalization + self.optim_critic_iters = optim_critic_iters + self.trust_region_size = trust_region_size + # adjusts Hessian-vector product calculation for numerical stability + self._damping = 0.1 + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithAdvantagesProtocol: + batch = self._add_returns_and_advantages(batch, buffer, indices) + batch.act = to_torch_as(batch.act, batch.v_s) + old_log_prob = [] + with torch.no_grad(): + for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): + old_log_prob.append(self.policy(minibatch).dist.log_prob(minibatch.act)) + batch.logp_old = torch.cat(old_log_prob, dim=0) + if self.advantage_normalization: + batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() + return batch + + def _update_with_batch( # type: ignore[override] + self, + batch: BatchWithAdvantagesProtocol, + batch_size: int | None, + repeat: int, + ) -> NPGTrainingStats: + actor_losses, vf_losses, kls = [], [], [] + split_batch_size = batch_size or -1 + for _ in range(repeat): + for minibatch in batch.split(split_batch_size, merge_last=True): + # optimize actor + # direction: calculate villia gradient + dist = self.policy(minibatch).dist + log_prob = dist.log_prob(minibatch.act) + log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1) + actor_loss = -(log_prob * minibatch.adv).mean() + flat_grads = self._get_flat_grad( + actor_loss, self.policy.actor, retain_graph=True + ).detach() + + # direction: calculate natural gradient + with torch.no_grad(): + old_dist = self.policy(minibatch).dist + + kl = kl_divergence(old_dist, dist).mean() + # calculate first order gradient of kl with respect to theta + flat_kl_grad = self._get_flat_grad(kl, self.policy.actor, create_graph=True) + search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) + + # step + with torch.no_grad(): + flat_params = torch.cat( + [param.data.view(-1) for param in self.policy.actor.parameters()], + ) + new_flat_params = flat_params + self.trust_region_size * search_direction + self._set_from_flat_params(self.policy.actor, new_flat_params) + new_dist = self.policy(minibatch).dist + kl = kl_divergence(old_dist, new_dist).mean() + + # optimize critic + for _ in 
range(self.optim_critic_iters): + value = self.critic(minibatch.obs).flatten() + vf_loss = F.mse_loss(minibatch.returns, value) + self.optim.step(vf_loss) + + actor_losses.append(actor_loss.item()) + vf_losses.append(vf_loss.item()) + kls.append(kl.item()) + + return NPGTrainingStats( + actor_loss=SequenceSummaryStats.from_sequence(actor_losses), + vf_loss=SequenceSummaryStats.from_sequence(vf_losses), + kl=SequenceSummaryStats.from_sequence(kls), + ) + + def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: + """Matrix vector product.""" + # caculate second order gradient of kl with respect to theta + kl_v = (flat_kl_grad * v).sum() + flat_kl_grad_grad = self._get_flat_grad(kl_v, self.policy.actor, retain_graph=True).detach() + return flat_kl_grad_grad + v * self._damping + + def _conjugate_gradients( + self, + minibatch: torch.Tensor, + flat_kl_grad: torch.Tensor, + nsteps: int = 10, + residual_tol: float = 1e-10, + ) -> torch.Tensor: + x = torch.zeros_like(minibatch) + r, p = minibatch.clone(), minibatch.clone() + # Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0. + # Change if doing warm start. + rdotr = r.dot(r) + for _ in range(nsteps): + z = self._MVP(p, flat_kl_grad) + alpha = rdotr / p.dot(z) + x += alpha * p + r -= alpha * z + new_rdotr = r.dot(r) + if new_rdotr < residual_tol: + break + p = r + new_rdotr / rdotr * p + rdotr = new_rdotr + return x + + def _get_flat_grad(self, y: torch.Tensor, model: nn.Module, **kwargs: Any) -> torch.Tensor: + grads = torch.autograd.grad(y, model.parameters(), **kwargs) # type: ignore + return torch.cat([grad.reshape(-1) for grad in grads]) + + def _set_from_flat_params(self, model: nn.Module, flat_params: torch.Tensor) -> nn.Module: + prev_ind = 0 + for param in model.parameters(): + flat_size = int(np.prod(list(param.size()))) + param.data.copy_(flat_params[prev_ind : prev_ind + flat_size].view(param.size())) + prev_ind += flat_size + return model diff --git a/tianshou/algorithm/modelfree/ppo.py b/tianshou/algorithm/modelfree/ppo.py new file mode 100644 index 000000000..ede1d3418 --- /dev/null +++ b/tianshou/algorithm/modelfree/ppo.py @@ -0,0 +1,224 @@ +from typing import cast + +import numpy as np +import torch + +from tianshou.algorithm import A2C +from tianshou.algorithm.modelfree.a2c import A2CTrainingStats +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as +from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic + + +class PPO(A2C): + """Implementation of Proximal Policy Optimization. arXiv:1707.06347.""" + + def __init__( + self, + *, + policy: ProbabilisticActorPolicy, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, + optim: OptimizerFactory, + eps_clip: float = 0.2, + dual_clip: float | None = None, + value_clip: bool = False, + advantage_normalization: bool = True, + recompute_advantage: bool = False, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + max_grad_norm: float | None = None, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + gamma: float = 0.99, + return_scaling: bool = False, + ) -> None: + r""" + :param policy: the policy containing the actor network. + :param critic: the critic network. 
(s -> V(s)) + :param optim: the optimizer factory for the policy's actor network and the critic networks. + :param eps_clip: determines the range of allowed change in the policy during a policy update: + The ratio of action probabilities indicated by the new and old policy is + constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. + Small values thus force the new policy to stay close to the old policy. + Typical values range between 0.1 and 0.3, the value of 0.2 is recommended + in the original PPO paper. + The optimal value depends on the environment; more stochastic environments may + need larger values. + :param dual_clip: a clipping parameter (denoted as c in the literature) that prevents + excessive pessimism in policy updates for negative-advantage actions. + Excessive pessimism occurs when the policy update too strongly reduces the probability + of selecting actions that led to negative advantages, potentially eliminating useful + actions based on limited negative experiences. + When enabled (c > 1), the objective for negative advantages becomes: + max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) + is the original single-clipping objective determined by `eps_clip`. + This creates a floor on negative policy gradients, maintaining some probability + of exploring actions despite initial negative outcomes. + Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer + to 1.0 provide less protection against pessimistic updates. + Set to None to disable dual clipping. + :param value_clip: flag indicating whether to enable clipping for value function updates. + When enabled, restricts how much the value function estimate can change from its + previous prediction, using the same clipping range as the policy updates (eps_clip). + This stabilizes training by preventing large fluctuations in value estimates, + particularly useful in environments with high reward variance. + The clipped value loss uses a pessimistic approach, taking the maximum of the + original and clipped value errors: + max((returns - value)², (returns - v_clipped)²) + Setting to True often improves training stability but may slow convergence. + Implementation follows the approach mentioned in arXiv:1811.02553v3 Sec. 4.1. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param recompute_advantage: whether to recompute advantage every update + repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. + :param vf_coef: coefficient that weights the value loss relative to the actor loss in + the overall loss function. + Higher values prioritize accurate value function estimation over policy improvement. + Controls the trade-off between policy optimization and value function fitting. + Typically set between 0.5 and 1.0 for most actor-critic implementations. + :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. + Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. + Typically set between 0.01 and 0.05 for most actor-critic implementations. + :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. 
This prevents exploding gradients and stabilizes training by + limiting the magnitude of parameter updates. + Set to None to disable gradient clipping. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. + """ + assert ( + dual_clip is None or dual_clip > 1.0 + ), f"Dual-clip PPO parameter should greater than 1.0 but got {dual_clip}" + + super().__init__( + policy=policy, + critic=critic, + optim=optim, + vf_coef=vf_coef, + ent_coef=ent_coef, + max_grad_norm=max_grad_norm, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + gamma=gamma, + return_scaling=return_scaling, + ) + self.eps_clip = eps_clip + self.dual_clip = dual_clip + self.value_clip = value_clip + self.advantage_normalization = advantage_normalization + self.recompute_adv = recompute_advantage + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> LogpOldProtocol: + if self.recompute_adv: + # buffer input `buffer` and `indices` to be used in `_update_with_batch()`. 
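+            # (when recompute_advantage is enabled, the advantages are recomputed from this
+            # buffer at the start of every update repeat except the first, following
+            # https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5)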
+ self._buffer, self._indices = buffer, indices + batch = self._add_returns_and_advantages(batch, buffer, indices) + batch.act = to_torch_as(batch.act, batch.v_s) + logp_old = [] + with torch.no_grad(): + for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): + logp_old.append(self.policy(minibatch).dist.log_prob(minibatch.act)) + batch.logp_old = torch.cat(logp_old, dim=0).flatten() + return cast(LogpOldProtocol, batch) + + def _update_with_batch( # type: ignore[override] + self, + batch: LogpOldProtocol, + batch_size: int | None, + repeat: int, + ) -> A2CTrainingStats: + losses, clip_losses, vf_losses, ent_losses = [], [], [], [] + gradient_steps = 0 + split_batch_size = batch_size or -1 + for step in range(repeat): + if self.recompute_adv and step > 0: + batch = cast( + LogpOldProtocol, + self._add_returns_and_advantages(batch, self._buffer, self._indices), + ) + for minibatch in batch.split(split_batch_size, merge_last=True): + gradient_steps += 1 + # calculate loss for actor + advantages = minibatch.adv + dist = self.policy(minibatch).dist + if self.advantage_normalization: + mean, std = advantages.mean(), advantages.std() + advantages = (advantages - mean) / (std + self._eps) # per-batch norm + ratios = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() + ratios = ratios.reshape(ratios.size(0), -1).transpose(0, 1) + surr1 = ratios * advantages + surr2 = ratios.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip) * advantages + if self.dual_clip: + clip1 = torch.min(surr1, surr2) + clip2 = torch.max(clip1, self.dual_clip * advantages) + clip_loss = -torch.where(advantages < 0, clip2, clip1).mean() + else: + clip_loss = -torch.min(surr1, surr2).mean() + # calculate loss for critic + value = self.critic(minibatch.obs).flatten() + if self.value_clip: + v_clip = minibatch.v_s + (value - minibatch.v_s).clamp( + -self.eps_clip, + self.eps_clip, + ) + vf1 = (minibatch.returns - value).pow(2) + vf2 = (minibatch.returns - v_clip).pow(2) + vf_loss = torch.max(vf1, vf2).mean() + else: + vf_loss = (minibatch.returns - value).pow(2).mean() + # calculate regularization and overall loss + ent_loss = dist.entropy().mean() + loss = clip_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss + self.optim.step(loss) + clip_losses.append(clip_loss.item()) + vf_losses.append(vf_loss.item()) + ent_losses.append(ent_loss.item()) + losses.append(loss.item()) + + return A2CTrainingStats( + loss=SequenceSummaryStats.from_sequence(losses), + actor_loss=SequenceSummaryStats.from_sequence(clip_losses), + vf_loss=SequenceSummaryStats.from_sequence(vf_losses), + ent_loss=SequenceSummaryStats.from_sequence(ent_losses), + gradient_steps=gradient_steps, + ) diff --git a/tianshou/algorithm/modelfree/qrdqn.py b/tianshou/algorithm/modelfree/qrdqn.py new file mode 100644 index 000000000..af1cb416a --- /dev/null +++ b/tianshou/algorithm/modelfree/qrdqn.py @@ -0,0 +1,131 @@ +import warnings +from typing import Generic, TypeVar + +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.algorithm.modelfree.dqn import ( + DiscreteQLearningPolicy, + QLearningOffPolicyAlgorithm, +) +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.types import RolloutBatchProtocol + + +class QRDQNPolicy(DiscreteQLearningPolicy): + def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + return 
super().compute_q_value(logits.mean(2), mask) + + +TQRDQNPolicy = TypeVar("TQRDQNPolicy", bound=QRDQNPolicy) + + +class QRDQN( + QLearningOffPolicyAlgorithm[TQRDQNPolicy], + Generic[TQRDQNPolicy], +): + """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.""" + + def __init__( + self, + *, + policy: TQRDQNPolicy, + optim: OptimizerFactory, + gamma: float = 0.99, + num_quantiles: int = 200, + n_step_return_horizon: int = 1, + target_update_freq: int = 0, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory for the policy's model. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param num_quantiles: the number of quantiles used to represent the return distribution for each action. + Determines the granularity of the approximated distribution function. + Higher values provide a more fine-grained approximation of the true return distribution but + increase computational and memory requirements. + Lower values reduce computational cost but may not capture the distribution accurately enough. + The original QRDQN paper used 200 quantiles for Atari environments. + Must be greater than 1, as at least two quantiles are needed to represent a distribution. + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. 
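The quantile-regression loss that `QRDQN._update_with_batch` below assembles from these quantities can be hard to read off the broadcasting alone. The following standalone sketch (toy tensors only; `B` and `N` are illustrative sizes, not values from this diff) shows how the quantile midpoints `tau_hat` are formed and how the asymmetric quantile Huber loss is computed:

```python
# Standalone sketch of the quantile Huber loss; smooth-L1 is written out explicitly
# to avoid the shape warning that the library code silences via warnings.filterwarnings.
import torch

B, N = 4, 8  # toy batch size and number of quantiles
tau = torch.linspace(0, 1, N + 1)
tau_hat = ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1)   # quantile midpoints, shape (1, N, 1)

curr_quantiles = torch.randn(B, N, 1)     # predicted quantiles of the taken actions
target_quantiles = torch.randn(B, 1, N)   # stand-in for the n-step target quantiles

diff = target_quantiles - curr_quantiles                    # pairwise differences, (B, N, N)
huber = torch.where(diff.abs() < 1.0, 0.5 * diff.pow(2), diff.abs() - 0.5)
# asymmetric weight |tau_hat - 1{diff <= 0}| penalizes over- and under-estimation differently
weight = (tau_hat - diff.detach().le(0.0).float()).abs()
loss = (huber * weight).sum(-1).mean(1).mean()              # reduce over targets, quantiles, batch
print(loss.item())
```

The implementation below computes the same quantity via `F.smooth_l1_loss` with broadcasting, which is why the constructor suppresses the "Using a target size" warning.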
+ """ + assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" + super().__init__( + policy=policy, + optim=optim, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + target_update_freq=target_update_freq, + ) + self.num_quantiles = num_quantiles + tau = torch.linspace(0, 1, self.num_quantiles + 1) + self.tau_hat = torch.nn.Parameter( + ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), + requires_grad=False, + ) + warnings.filterwarnings("ignore", message="Using a target size") + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + if self.use_target_network: + act = self.policy(obs_next_batch).act + next_dist = self.policy(obs_next_batch, model=self.model_old).logits + else: + next_batch = self.policy(obs_next_batch) + act = next_batch.act + next_dist = next_batch.logits + return next_dist[np.arange(len(act)), act, :] + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> SimpleLossTrainingStats: + self._periodically_update_lagged_network_weights() + weight = batch.pop("weight", 1.0) + curr_dist = self.policy(batch).logits + act = batch.act + curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = ( + (dist_diff * (self.tau_hat - (target_dist - curr_dist).detach().le(0.0).float()).abs()) + .sum(-1) + .mean(1) + ) + loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer + self.optim.step(loss) + + return SimpleLossTrainingStats(loss=loss.item()) diff --git a/tianshou/algorithm/modelfree/rainbow.py b/tianshou/algorithm/modelfree/rainbow.py new file mode 100644 index 000000000..e04c6876e --- /dev/null +++ b/tianshou/algorithm/modelfree/rainbow.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass + +from torch import nn + +from tianshou.algorithm.modelfree.c51 import C51, C51Policy +from tianshou.algorithm.modelfree.reinforce import LossSequenceTrainingStats +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data.types import RolloutBatchProtocol +from tianshou.utils.lagged_network import EvalModeModuleWrapper +from tianshou.utils.net.discrete import NoisyLinear + + +@dataclass(kw_only=True) +class RainbowTrainingStats: + loss: float + + +class RainbowDQN(C51): + """Implementation of Rainbow DQN. arXiv:1710.02298.""" + + def __init__( + self, + *, + policy: C51Policy, + optim: OptimizerFactory, + gamma: float = 0.99, + n_step_return_horizon: int = 1, + target_update_freq: int = 0, + ) -> None: + """ + :param policy: a policy following the rules (s -> action_values_BA) + :param optim: the optimizer factory for the policy's model. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. 
Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. + """ + super().__init__( + policy=policy, + optim=optim, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + target_update_freq=target_update_freq, + ) + + self.model_old: nn.Module | None # type: ignore[assignment] + # Remove the wrapper that forces eval mode for the target network, + # because Rainbow requires it to be set to train mode for sampling noise + # in NoisyLinear layers to take effect. + # (minor violation of Liskov Substitution Principle) + if self.use_target_network: + assert isinstance(self.model_old, EvalModeModuleWrapper) + self.model_old = self.model_old.module + + @staticmethod + def _sample_noise(model: nn.Module) -> bool: + """Sample the random noises of NoisyLinear modules in the model. + + Returns True if at least one NoisyLinear submodule was found. + + :param model: a PyTorch module which may have NoisyLinear submodules. + :returns: True if model has at least one NoisyLinear submodule; + otherwise, False. 
+ """ + sampled_any_noise = False + for m in model.modules(): + if isinstance(m, NoisyLinear): + m.sample() + sampled_any_noise = True + return sampled_any_noise + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> LossSequenceTrainingStats: + self._sample_noise(self.policy.model) + if self.use_target_network: + assert self.model_old is not None + self._sample_noise(self.model_old) + return super()._update_with_batch(batch) diff --git a/tianshou/algorithm/modelfree/redq.py b/tianshou/algorithm/modelfree/redq.py new file mode 100644 index 000000000..786db84d9 --- /dev/null +++ b/tianshou/algorithm/modelfree/redq.py @@ -0,0 +1,304 @@ +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +from torch.distributions import Independent, Normal + +from tianshou.algorithm.modelfree.ddpg import ( + ActorCriticOffPolicyAlgorithm, + ContinuousPolicyWithExplorationNoise, + DDPGTrainingStats, +) +from tianshou.algorithm.modelfree.sac import Alpha +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch +from tianshou.data.types import ( + DistLogProbBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.exploration import BaseNoise +from tianshou.utils.net.continuous import ContinuousActorProbabilistic + + +@dataclass +class REDQTrainingStats(DDPGTrainingStats): + """A data structure for storing loss statistics of the REDQ learn step.""" + + alpha: float | None = None + alpha_loss: float | None = None + + +TREDQTrainingStats = TypeVar("TREDQTrainingStats", bound=REDQTrainingStats) + + +class REDQPolicy(ContinuousPolicyWithExplorationNoise): + def __init__( + self, + *, + actor: torch.nn.Module | ContinuousActorProbabilistic, + exploration_noise: BaseNoise | Literal["default"] | None = None, + action_space: gym.spaces.Space, + deterministic_eval: bool = True, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + observation_space: gym.Space | None = None, + ): + """ + :param actor: The actor network following the rules (s -> model_output) + :param action_space: the environment's action_space. + :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. + :param observation_space: the environment's observation space + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. 
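Referring back to the Rainbow logic above: `_sample_noise` simply walks the module tree and resamples every `NoisyLinear` it finds, which is also why the target network must stay in train mode. A minimal sketch of that walk using a toy noisy layer (the `ToyNoisyLinear` class and its `sample()` method are illustrative stand-ins, not Tianshou's `NoisyLinear`):

```python
# Toy stand-in for a noisy layer; Tianshou's NoisyLinear differs in detail.
import torch
from torch import nn

class ToyNoisyLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, sigma: float = 0.5) -> None:
        super().__init__(in_features, out_features)
        self.sigma = sigma
        self.register_buffer("weight_noise", torch.zeros_like(self.weight))
        self.sample()

    def sample(self) -> None:
        # draw fresh Gaussian noise for the weights
        self.weight_noise.normal_(0.0, self.sigma)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight + self.weight_noise, self.bias)

def resample_noise(model: nn.Module) -> bool:
    """Mirror of the _sample_noise walk: resample every noisy submodule found."""
    found = False
    for m in model.modules():
        if isinstance(m, ToyNoisyLinear):
            m.sample()
            found = True
    return found

net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), ToyNoisyLinear(16, 2))
x = torch.randn(1, 4)
resample_noise(net)
y1 = net(x)
resample_noise(net)  # fresh noise -> generally different outputs for the same input
y2 = net(x)
print(torch.allclose(y1, y2))  # typically False
```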
+ This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. + """ + super().__init__( + exploration_noise=exploration_noise, + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + ) + self.actor = actor + self.deterministic_eval = deterministic_eval + self._eps = np.finfo(np.float32).eps.item() + + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + **kwargs: Any, + ) -> DistLogProbBatchProtocol: + (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc_B, scale_B), 1) + if self.deterministic_eval and not self.is_within_training_step: + act_B = dist.mode + else: + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) + # apply correction for Tanh squashing when computing logprob from Gaussian + # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. + # in appendix C to get some understanding of this equation. + squashed_action = torch.tanh(act_B) + log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self._eps).sum( + -1, + keepdim=True, + ) + result = Batch( + logits=(loc_B, scale_B), + act=squashed_action, + state=h_BH, + dist=dist, + log_prob=log_prob, + ) + return cast(DistLogProbBatchProtocol, result) + + +class REDQ(ActorCriticOffPolicyAlgorithm[REDQPolicy, DistLogProbBatchProtocol]): + """Implementation of REDQ. arXiv:2101.05982.""" + + def __init__( + self, + *, + policy: REDQPolicy, + policy_optim: OptimizerFactory, + critic: torch.nn.Module, + critic_optim: OptimizerFactory, + ensemble_size: int = 10, + subset_size: int = 2, + tau: float = 0.005, + gamma: float = 0.99, + alpha: float | Alpha = 0.2, + n_step_return_horizon: int = 1, + actor_delay: int = 20, + deterministic_eval: bool = True, + target_mode: Literal["mean", "min"] = "min", + ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer factory for the policy's model. + :param critic: the critic network. 
(s, a -> Q(s, a)) + :param critic_optim: the optimizer factory for the critic network. + :param ensemble_size: the total number of critic networks in the ensemble. + This parameter implements the randomized ensemble approach described in REDQ. + The algorithm maintains `ensemble_size` different critic networks that all share the same + architecture. During target value computation, a random subset of these networks (determined + by `subset_size`) is used. + Larger values increase the diversity of the ensemble but require more memory and computation. + The original paper recommends a value of 10 for most tasks, balancing performance and + computational efficiency. + :param subset_size: the number of critic networks randomly selected from the ensemble for + computing target Q-values. + During each update, the algorithm samples `subset_size` networks from the ensemble of + `ensemble_size` networks without replacement. + The target Q-value is then calculated as either the minimum or mean (based on `target_mode`) + of the predictions from this subset. + Smaller values increase randomization and sample efficiency but may introduce more variance. + Larger values provide more stable estimates but reduce the benefits of randomization. + The REDQ paper recommends a value of 2 for optimal sample efficiency. + Must satisfy 0 < subset_size <= ensemble_size. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param alpha: the entropy regularization coefficient, which balances exploration and exploitation. + This coefficient controls how much the agent values randomness in its policy versus + pursuing higher rewards. + Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent + for maintaining diverse action choices, even if this means selecting some lower-value actions. + Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become + more focused on the highest-value actions. + A value of 0 would completely remove entropy regularization, potentially leading to + premature convergence to suboptimal deterministic policies. + Can be provided as a fixed float (0.2 is a reasonable default) or as an instance of, + in particular, class `AutoAlpha` for automatic tuning during training. + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. 
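The `tau` parameter documented above (and for the other off-policy algorithms in this diff) controls a Polyak soft update of the lagged networks. A minimal standalone sketch of that rule, using toy linear modules rather than the REDQ critics:

```python
# Soft target update: target = tau * source + (1 - tau) * target.
import torch
from torch import nn

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    with torch.no_grad():
        for tgt_param, src_param in zip(target.parameters(), source.parameters()):
            tgt_param.mul_(1.0 - tau).add_(src_param, alpha=tau)

source = nn.Linear(3, 1)
target = nn.Linear(3, 1)
target.load_state_dict(source.state_dict())  # start from identical weights
# ... after gradient steps on `source`, the target slowly tracks it:
soft_update(target, source, tau=0.005)
```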
Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param actor_delay: the number of critic updates performed before each actor update. + The actor network is only updated once for every actor_delay critic updates, implementing + a delayed policy update strategy similar to TD3. + Larger values stabilize training by allowing critics to become more accurate before policy updates. + Smaller values allow the policy to adapt more quickly but may lead to less stable learning. + The REDQ paper recommends a value of 20 for most tasks. + :param target_mode: the method used to aggregate Q-values from the subset of critic networks. + Can be either "min" or "mean". + If "min", uses the minimum Q-value across the selected subset of critics for each state-action pair. + If "mean", uses the average Q-value across the selected subset of critics. + Using "min" helps prevent overestimation bias but may lead to more conservative value estimates. + Using "mean" provides more optimistic value estimates but may suffer from overestimation bias. + Default is "min" following the conservative value estimation approach common in recent Q-learning + algorithms. + """ + if target_mode not in ("min", "mean"): + raise ValueError(f"Unsupported target_mode: {target_mode}") + if not 0 < subset_size <= ensemble_size: + raise ValueError( + f"Invalid choice of ensemble size or subset size. " + f"Should be 0 < {subset_size=} <= {ensemble_size=}", + ) + super().__init__( + policy=policy, + policy_optim=policy_optim, + critic=critic, + critic_optim=critic_optim, + tau=tau, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + ) + self.ensemble_size = ensemble_size + self.subset_size = subset_size + + self.target_mode = target_mode + self.critic_gradient_step = 0 + self.actor_delay = actor_delay + self.deterministic_eval = deterministic_eval + self.__eps = np.finfo(np.float32).eps.item() + + self._last_actor_loss = 0.0 # only for logging purposes + + self.alpha = Alpha.from_float_or_instance(alpha) + + def _target_q_compute_value( + self, obs_batch: Batch, act_batch: DistLogProbBatchProtocol + ) -> torch.Tensor: + a_ = act_batch.act + sample_ensemble_idx = np.random.choice(self.ensemble_size, self.subset_size, replace=False) + qs = self.critic_old(obs_batch.obs, a_)[sample_ensemble_idx, ...] 
+ if self.target_mode == "min": + target_q, _ = torch.min(qs, dim=0) + elif self.target_mode == "mean": + target_q = torch.mean(qs, dim=0) + else: + raise ValueError(f"Invalid target_mode: {self.target_mode}") + + target_q -= self.alpha.value * act_batch.log_prob + + return target_q + + def _update_with_batch(self, batch: RolloutBatchProtocol) -> REDQTrainingStats: # type: ignore + # critic ensemble + weight = getattr(batch, "weight", 1.0) + current_qs = self.critic(batch.obs, batch.act).flatten(1) + target_q = batch.returns.flatten() + td = current_qs - target_q + critic_loss = (td.pow(2) * weight).mean() + self.critic_optim.step(critic_loss) + batch.weight = torch.mean(td, dim=0) # prio-buffer + self.critic_gradient_step += 1 + + alpha_loss = None + # actor + if self.critic_gradient_step % self.actor_delay == 0: + obs_result = self.policy(batch) + a = obs_result.act + current_qa = self.critic(batch.obs, a).mean(dim=0).flatten() + actor_loss = (self.alpha.value * obs_result.log_prob.flatten() - current_qa).mean() + self.policy_optim.step(actor_loss) + + # The entropy of a Gaussian policy can be expressed as -log_prob + a constant (which we ignore) + entropy = -obs_result.log_prob.detach() + alpha_loss = self.alpha.update(entropy) + + self._last_actor_loss = actor_loss.item() + + self._update_lagged_network_weights() + + return REDQTrainingStats( + actor_loss=self._last_actor_loss, + critic_loss=critic_loss.item(), + alpha=self.alpha.value, + alpha_loss=alpha_loss, + ) diff --git a/tianshou/algorithm/modelfree/reinforce.py b/tianshou/algorithm/modelfree/reinforce.py new file mode 100644 index 000000000..60fc91cae --- /dev/null +++ b/tianshou/algorithm/modelfree/reinforce.py @@ -0,0 +1,382 @@ +import logging +import warnings +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch + +from tianshou.algorithm import Algorithm +from tianshou.algorithm.algorithm_base import ( + OnPolicyAlgorithm, + Policy, + TrainingStats, +) +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ( + Batch, + ReplayBuffer, + SequenceSummaryStats, + to_torch, + to_torch_as, +) +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + BatchWithReturnsProtocol, + DistBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.utils import RunningMeanStd +from tianshou.utils.net.common import ( + AbstractContinuousActorProbabilistic, + AbstractDiscreteActor, + ActionReprNet, +) +from tianshou.utils.net.discrete import dist_fn_categorical_from_logits + +log = logging.getLogger(__name__) + + +# Dimension Naming Convention +# B - Batch Size +# A - Action +# D - Dist input (usually 2, loc and scale) +# H - Dimension of hidden, can be None + +TDistFnContinuous = Callable[ + [tuple[torch.Tensor, torch.Tensor]], + torch.distributions.Distribution, +] +TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Distribution] + +TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete + + +@dataclass(kw_only=True) +class LossSequenceTrainingStats(TrainingStats): + loss: SequenceSummaryStats + + +@dataclass(kw_only=True) +class SimpleLossTrainingStats(TrainingStats): + loss: float + + +class ProbabilisticActorPolicy(Policy): + """ + A policy that outputs (representations of) probability distributions from which + actions can be sampled. 
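To make the randomized target computation in `REDQ._target_q_compute_value` above concrete, here is a standalone sketch; the `(ensemble_size, batch)` Q-tensor shape is an assumption for illustration, standing in for the output of the ensemble critic:

```python
# Toy illustration of REDQ's subset sampling and min/mean aggregation.
import numpy as np
import torch

ensemble_size, subset_size, batch = 10, 2, 5
q_ensemble = torch.randn(ensemble_size, batch)        # stand-in for critic_old(obs, act)

idx = np.random.choice(ensemble_size, subset_size, replace=False)
qs = q_ensemble[torch.as_tensor(idx), ...]            # (subset_size, batch)

target_min, _ = torch.min(qs, dim=0)                  # "min" mode: conservative target
target_mean = torch.mean(qs, dim=0)                   # "mean" mode: less pessimistic target
print(target_min.shape, target_mean.shape)            # both torch.Size([5])
```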
+ """ + + def __init__( + self, + *, + actor: AbstractContinuousActorProbabilistic | AbstractDiscreteActor | ActionReprNet, + dist_fn: TDistFnDiscrOrCont, + deterministic_eval: bool = False, + action_space: gym.Space, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + ) -> None: + """ + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). + :param dist_fn: the function/type which creates a distribution from the actor output, + i.e. it maps the tensor(s) generated by the actor to a torch distribution. + For continuous action spaces, the output is typically a pair of tensors + (mean, std) and the distribution is a Gaussian distribution. + For discrete action spaces, the output is typically a tensor of unnormalized + log probabilities ("logits" in PyTorch terminology) or a tensor of probabilities + which can serve as the parameters of a Categorical distribution. + Note that if the actor uses softmax activation in its final layer, it will produce + probabilities, whereas if it uses no activation, it can be considered as producing + "logits". + As a user, you are responsible for ensuring that the distribution + is compatible with the output of the actor model and the action space. + :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. + :param action_space: the environment's action space. + :param observation_space: the environment's observation space. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). 
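The `action_scaling` / `action_bound_method` behaviour described above amounts to a bound-then-rescale transformation. A toy sketch of that mapping (the function name and bounds are made up for illustration; this is not Tianshou's internal implementation):

```python
# Toy sketch: bound a raw policy output to [-1, 1], then rescale linearly to [low, high].
import numpy as np

def bound_and_scale(raw_action: np.ndarray, low: np.ndarray, high: np.ndarray,
                    method: str | None = "clip") -> np.ndarray:
    if method == "clip":
        bounded = np.clip(raw_action, -1.0, 1.0)
    elif method == "tanh":
        bounded = np.tanh(raw_action)   # smooth bounding; preserves gradients in the torch analogue
    else:
        bounded = raw_action            # method=None: actor already produces bounded outputs
    return low + (bounded + 1.0) * (high - low) / 2.0

low, high = np.array([-2.0, 0.0]), np.array([2.0, 1.0])
print(bound_and_scale(np.array([3.0, -0.5]), low, high))  # first component clips to 1.0 -> maps to 2.0
```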
+ This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. + """ + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + ) + if action_scaling: + try: + max_action = float(actor.max_action) + if np.isclose(max_action, 1.0): + warnings.warn( + "action_scaling and action_bound_method are only intended " + "to deal with unbounded model action space, but found actor model " + f"bound action space with max_action={actor.max_action}. " + "Consider using unbounded=True option of the actor model, " + "or set action_scaling to False and action_bound_method to None.", + ) + except BaseException: + pass + + self.actor = actor + self.dist_fn = dist_fn + self._eps = 1e-8 + self.deterministic_eval = deterministic_eval + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + ) -> DistBatchProtocol: + """Compute action over the given batch data by applying the actor. + + Will sample from the dist_fn, if appropriate. + Returns a new object representing the processed batch data + (contrary to other methods that modify the input batch inplace). + """ + action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A + # therefore action_dist_input_BD is equivalent to logits_BA + # If discrete, dist_fn will typically map loc, scale to a distribution (usually a Gaussian) + # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked + dist = self.dist_fn(action_dist_input_BD) + + act_B = ( + dist.mode + if self.deterministic_eval and not self.is_within_training_step + else dist.sample() + ) + # act is of dimension BA in continuous case and of dimension B in discrete + result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) + return cast(DistBatchProtocol, result) + + +class DiscreteActorPolicy(ProbabilisticActorPolicy): + def __init__( + self, + *, + actor: AbstractDiscreteActor | ActionReprNet, + dist_fn: TDistFnDiscrete = dist_fn_categorical_from_logits, + deterministic_eval: bool = False, + action_space: gym.Space, + observation_space: gym.Space | None = None, + ) -> None: + """ + :param actor: the actor network following the rules: (`s_B` -> `dist_input_BD`). + :param dist_fn: the function/type which creates a distribution from the actor output, + i.e. it maps the tensor(s) generated by the actor to a torch distribution. + For discrete action spaces, the output is typically a tensor of unnormalized + log probabilities ("logits" in PyTorch terminology) or a tensor of probabilities + which serve as the parameters of a Categorical distribution. 
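Since the `dist_fn` contract (actor output -> torch distribution) is central to these policy classes, two minimal callables matching the `TDistFnContinuous` and `TDistFnDiscrete` aliases defined earlier in this file may help; they are illustrative sketches rather than the library's built-in helpers such as `dist_fn_categorical_from_logits`:

```python
# Illustrative dist_fn implementations for the two action-space cases.
import torch
from torch.distributions import Categorical, Independent, Normal

def continuous_dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
    loc, scale = loc_scale  # actor output for continuous actions: (mean, std)
    return Independent(Normal(loc, scale), 1)

def discrete_dist_fn(logits: torch.Tensor) -> torch.distributions.Distribution:
    # unnormalized log-probabilities produced by the actor
    return Categorical(logits=logits)

cont_dist = continuous_dist_fn((torch.zeros(4, 2), torch.ones(4, 2)))
disc_dist = discrete_dist_fn(torch.randn(4, 5))
print(cont_dist.sample().shape, disc_dist.sample().shape)  # torch.Size([4, 2]) torch.Size([4])
```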
+ Note that if the actor uses softmax activation in its final layer, it will produce + probabilities, whereas if it uses no activation, it can be considered as producing + "logits". + As a user, you are responsible for ensuring that the distribution + is compatible with the output of the actor model and the action space. + :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. + :param action_space: the environment's (discrete) action space. + :param observation_space: the environment's observation space. + """ + if not isinstance(action_space, gym.spaces.Discrete): + raise ValueError(f"Action space must be an instance of Discrete; got {action_space}") + super().__init__( + actor=actor, + dist_fn=dist_fn, + deterministic_eval=deterministic_eval, + action_space=action_space, + observation_space=observation_space, + action_scaling=False, + action_bound_method=None, + ) + + +TActorPolicy = TypeVar("TActorPolicy", bound=ProbabilisticActorPolicy) + + +class DiscountedReturnComputation: + def __init__( + self, + gamma: float = 0.99, + return_standardization: bool = False, + ): + """ + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param return_standardization: whether to standardize episode returns + by subtracting the running mean and dividing by the running standard deviation. + Note that this is known to be detrimental to performance in many cases! + """ + assert 0.0 <= gamma <= 1.0, "discount factor gamma should be in [0, 1]" + self.gamma = gamma + self.return_standardization = return_standardization + self.ret_rms = RunningMeanStd() + self.eps = 1e-8 + + def add_discounted_returns( + self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray + ) -> BatchWithReturnsProtocol: + r"""Compute the discounted returns (Monte Carlo estimates) for each transition. + + They are added to the batch under the field `returns`. + Note: this function will modify the input batch! + + .. math:: + G_t = \sum_{i=t}^T \gamma^{i-t}r_i + + where :math:`T` is the terminal time step, :math:`\gamma` is the + discount factor, :math:`\gamma \in [0, 1]`. 
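As a worked example of the return formula above, the following standalone snippet computes Monte Carlo returns for one toy episode; it ignores the buffer bookkeeping and running-statistics standardization that `add_discounted_returns` handles:

```python
# Toy Monte Carlo returns for one finished episode (no bootstrapping).
import numpy as np

def discounted_returns(rewards: np.ndarray, gamma: float) -> np.ndarray:
    returns = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

print(discounted_returns(np.array([1.0, 0.0, 2.0]), gamma=0.9))
# G_2 = 2.0, G_1 = 0 + 0.9 * 2.0 = 1.8, G_0 = 1 + 0.9 * 1.8 = 2.62
```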
+ + :param batch: a data batch which contains several episodes of data in + sequential order. Mind that the end of each finished episode of batch + should be marked by done flag, unfinished (or collecting) episodes will be + recognized by buffer.unfinished_index(). + :param buffer: the corresponding replay buffer. + :param indices: tell batch's location in buffer, batch is equal + to buffer[indices]. + """ + v_s_ = np.full(indices.shape, self.ret_rms.mean) + # gae_lambda = 1.0 means we use Monte Carlo estimate + unnormalized_returns, _ = Algorithm.compute_episodic_return( + batch, + buffer, + indices, + v_s_=v_s_, + gamma=self.gamma, + gae_lambda=1.0, + ) + if self.return_standardization: + batch.returns = (unnormalized_returns - self.ret_rms.mean) / np.sqrt( + self.ret_rms.var + self.eps, + ) + self.ret_rms.update(unnormalized_returns) + else: + batch.returns = unnormalized_returns + return cast(BatchWithReturnsProtocol, batch) + + +class Reinforce(OnPolicyAlgorithm[ProbabilisticActorPolicy]): + """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm.""" + + def __init__( + self, + *, + policy: ProbabilisticActorPolicy, + gamma: float = 0.99, + return_standardization: bool = False, + optim: OptimizerFactory, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory for the policy's model. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param return_standardization: if True, will scale/standardize returns + by subtracting the running mean and dividing by the running standard deviation. + Can be detrimental to performance! + """ + super().__init__( + policy=policy, + ) + self.discounted_return_computation = DiscountedReturnComputation( + gamma=gamma, + return_standardization=return_standardization, + ) + self.optim = self._create_optimizer(self.policy, optim) + + def _preprocess_batch( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + return self.discounted_return_computation.add_discounted_returns( + batch, + buffer, + indices, + ) + + # Needs BatchWithReturnsProtocol, which violates the substitution principle. 
But not a problem since it's a private method and + # the remainder of the class was adjusted to provide the correct batch + def _update_with_batch( # type: ignore[override] + self, + batch: BatchWithReturnsProtocol, + batch_size: int | None, + repeat: int, + ) -> LossSequenceTrainingStats: + losses = [] + split_batch_size = batch_size or -1 + for _ in range(repeat): + for minibatch in batch.split(split_batch_size, merge_last=True): + result = self.policy(minibatch) + dist = result.dist + act = to_torch_as(minibatch.act, result.act) + ret = to_torch(minibatch.returns, torch.float, result.act.device) + log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1) + loss = -(log_prob * ret).mean() + self.optim.step(loss) + losses.append(loss.item()) + + return LossSequenceTrainingStats(loss=SequenceSummaryStats.from_sequence(losses)) diff --git a/tianshou/algorithm/modelfree/sac.py b/tianshou/algorithm/modelfree/sac.py new file mode 100644 index 000000000..a63583445 --- /dev/null +++ b/tianshou/algorithm/modelfree/sac.py @@ -0,0 +1,336 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, Union, cast + +import gymnasium as gym +import numpy as np +import torch +from torch.distributions import Independent, Normal + +from tianshou.algorithm.algorithm_base import TrainingStats +from tianshou.algorithm.modelfree.ddpg import ContinuousPolicyWithExplorationNoise +from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch +from tianshou.data.types import ( + DistLogProbBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.exploration import BaseNoise +from tianshou.utils.conversion import to_optional_float +from tianshou.utils.net.continuous import ContinuousActorProbabilistic + + +def correct_log_prob_gaussian_tanh( + log_prob: torch.Tensor, + tanh_squashed_action: torch.Tensor, + eps: float = np.finfo(np.float32).eps.item(), +) -> torch.Tensor: + """Apply correction for Tanh squashing when computing `log_prob` from Gaussian. + + See equation 21 in the original `SAC paper `_. + + :param log_prob: log probability of the action + :param tanh_squashed_action: action squashed to values in (-1, 1) range by tanh + :param eps: epsilon for numerical stability + """ + log_prob_correction = torch.log(1 - tanh_squashed_action.pow(2) + eps).sum(-1, keepdim=True) + return log_prob - log_prob_correction + + +@dataclass(kw_only=True) +class SACTrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + alpha: float | None = None + alpha_loss: float | None = None + + +TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats) + + +class SACPolicy(ContinuousPolicyWithExplorationNoise): + def __init__( + self, + *, + actor: torch.nn.Module | ContinuousActorProbabilistic, + exploration_noise: BaseNoise | Literal["default"] | None = None, + deterministic_eval: bool = True, + action_scaling: bool = True, + action_space: gym.Space, + observation_space: gym.Space | None = None, + ): + """ + :param actor: the actor network following the rules (s -> dist_input_BD) + :param exploration_noise: add noise to action for exploration. + This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). 
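The correction applied by `correct_log_prob_gaussian_tanh` above can be sanity-checked against torch's own change-of-variables machinery. The snippet re-implements the correction locally on toy shapes rather than importing it, so the shapes and tolerances here are assumptions:

```python
# Verify the tanh log-prob correction against TransformedDistribution + TanhTransform.
import numpy as np
import torch
from torch.distributions import Independent, Normal, TanhTransform, TransformedDistribution

eps = np.finfo(np.float32).eps.item()
loc, scale = torch.zeros(3, 2), torch.ones(3, 2)
gaussian = Independent(Normal(loc, scale), 1)

x = gaussian.rsample()                               # pre-squash Gaussian sample, shape (3, 2)
a = torch.tanh(x)                                    # squashed action in (-1, 1)
corrected = gaussian.log_prob(x).unsqueeze(-1) - torch.log(1 - a.pow(2) + eps).sum(-1, keepdim=True)

# reference density of the squashed action via torch's change-of-variables support
squashed = Independent(TransformedDistribution(Normal(loc, scale), [TanhTransform(cache_size=1)]), 1)
print(torch.allclose(corrected.squeeze(-1), squashed.log_prob(a), atol=1e-3))  # True (up to eps)
```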
+ :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_space: the environment's action_space. 
+ :param observation_space: the environment's observation space + """ + super().__init__( + exploration_noise=exploration_noise, + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + # actions already squashed by tanh + action_bound_method=None, + ) + self.actor = actor + self.deterministic_eval = deterministic_eval + + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + **kwargs: Any, + ) -> DistLogProbBatchProtocol: + (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) + if self.deterministic_eval and not self.is_within_training_step: + act_B = dist.mode + else: + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) + + squashed_action = torch.tanh(act_B) + log_prob = correct_log_prob_gaussian_tanh(log_prob, squashed_action) + result = Batch( + logits=(loc_B, scale_B), + act=squashed_action, + state=hidden_BH, + dist=dist, + log_prob=log_prob, + ) + return cast(DistLogProbBatchProtocol, result) + + +class Alpha(ABC): + """Defines the interface for the entropy regularization coefficient alpha.""" + + @staticmethod + def from_float_or_instance(alpha: Union[float, "Alpha"]) -> "Alpha": + if isinstance(alpha, float): + return FixedAlpha(alpha) + elif isinstance(alpha, Alpha): + return alpha + else: + raise ValueError(f"Expected float or Alpha instance, but got {alpha=}") + + @property + @abstractmethod + def value(self) -> float: + """Retrieves the current value of alpha.""" + + @abstractmethod + def update(self, entropy: torch.Tensor) -> float | None: + """ + Updates the alpha value based on the entropy. + + :param entropy: the entropy of the policy. + :return: the loss value if alpha is auto-tuned, otherwise None. + """ + return None + + +class FixedAlpha(Alpha): + """Represents a fixed entropy regularization coefficient alpha.""" + + def __init__(self, alpha: float): + self._value = alpha + + @property + def value(self) -> float: + return self._value + + def update(self, entropy: torch.Tensor) -> float | None: + return None + + +class AutoAlpha(torch.nn.Module, Alpha): + """Represents an entropy regularization coefficient alpha that is automatically tuned.""" + + def __init__(self, target_entropy: float, log_alpha: float, optim: OptimizerFactory): + """ + :param target_entropy: the target entropy value. + For discrete action spaces, it is usually `-log(|A|)` for a balance between stochasticity + and determinism or `-log(1/|A|)=log(|A|)` for maximum stochasticity or, more generally, + `lambda*log(|A|)`, e.g. with `lambda` close to 1 (e.g. 0.98) for pronounced stochasticity. + For continuous action spaces, it is usually `-dim(A)` for a balance between stochasticity + and determinism, with similar generalizations as for discrete action spaces. + :param log_alpha: the (initial) value of the log of the entropy regularization coefficient alpha. + :param optim: the factory with which to create the optimizer for `log_alpha`. 
+ """ + super().__init__() + self._target_entropy = target_entropy + self._log_alpha = torch.nn.Parameter(torch.tensor(log_alpha)) + self._optim, lr_scheduler = optim.create_instances(self) + if lr_scheduler is not None: + raise ValueError( + f"Learning rate schedulers are not supported by {self.__class__.__name__}" + ) + + @property + def value(self) -> float: + return self._log_alpha.detach().exp().item() + + def update(self, entropy: torch.Tensor) -> float: + entropy_deficit = self._target_entropy - entropy + alpha_loss = -(self._log_alpha * entropy_deficit).mean() + self._optim.zero_grad() + alpha_loss.backward() + self._optim.step() + return alpha_loss.item() + + +class SAC( + ActorDualCriticsOffPolicyAlgorithm[SACPolicy, DistLogProbBatchProtocol], + Generic[TSACTrainingStats], +): + """Implementation of Soft Actor-Critic. arXiv:1812.05905.""" + + def __init__( + self, + *, + policy: SACPolicy, + policy_optim: OptimizerFactory, + critic: torch.nn.Module, + critic_optim: OptimizerFactory, + critic2: torch.nn.Module | None = None, + critic2_optim: OptimizerFactory | None = None, + tau: float = 0.005, + gamma: float = 0.99, + alpha: float | Alpha = 0.2, + n_step_return_horizon: int = 1, + deterministic_eval: bool = True, + ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer factory for the policy's model. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer factory for the first critic network. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param alpha: the entropy regularization coefficient, which balances exploration and exploitation. + This coefficient controls how much the agent values randomness in its policy versus + pursuing higher rewards. + Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent + for maintaining diverse action choices, even if this means selecting some lower-value actions. + Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become + more focused on the highest-value actions. 
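To illustrate the automatic entropy tuning performed by `AutoAlpha.update` above, here is a standalone sketch of a single `log_alpha` update on toy entropy values; the plain SGD optimizer and the numbers are assumptions, with the target entropy following the `-dim(A)` heuristic from the docstring:

```python
# One auto-alpha update step: alpha grows when the policy's entropy falls below the target.
import torch

action_dim = 2
target_entropy = -float(action_dim)              # common heuristic for continuous action spaces
log_alpha = torch.zeros(1, requires_grad=True)   # alpha starts at exp(0) = 1.0
optim = torch.optim.SGD([log_alpha], lr=1e-2)

entropy = torch.tensor([-2.5, -2.2, -2.4])       # toy per-sample entropies (below the target)
entropy_deficit = target_entropy - entropy       # positive: the policy is too deterministic
alpha_loss = -(log_alpha * entropy_deficit.detach()).mean()

optim.zero_grad()
alpha_loss.backward()
optim.step()
print(log_alpha.exp().item())                    # slightly above 1.0 -> stronger entropy bonus
```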
+ A value of 0 would completely remove entropy regularization, potentially leading to + premature convergence to suboptimal deterministic policies. + Can be provided as a fixed float (0.2 is a reasonable default) or as an instance of, + in particular, class `AutoAlpha` for automatic tuning during training. + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + """ + super().__init__( + policy=policy, + policy_optim=policy_optim, + critic=critic, + critic_optim=critic_optim, + critic2=critic2, + critic2_optim=critic2_optim, + tau=tau, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + ) + self.deterministic_eval = deterministic_eval + self.alpha = Alpha.from_float_or_instance(alpha) + self._check_field_validity() + + def _check_field_validity(self) -> None: + if not isinstance(self.policy.action_space, gym.spaces.Box): + raise ValueError( + f"SACPolicy only supports gym.spaces.Box, but got {self.action_space=}." + f"Please use DiscreteSACPolicy for discrete action spaces.", + ) + + def _target_q_compute_value( + self, obs_batch: Batch, act_batch: DistLogProbBatchProtocol + ) -> torch.Tensor: + min_q_value = super()._target_q_compute_value(obs_batch, act_batch) + return min_q_value - self.alpha.value * act_batch.log_prob + + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TSACTrainingStats: # type: ignore + # critic 1&2 + td1, critic1_loss = self._minimize_critic_squared_loss( + batch, self.critic, self.critic_optim + ) + td2, critic2_loss = self._minimize_critic_squared_loss( + batch, self.critic2, self.critic2_optim + ) + batch.weight = (td1 + td2) / 2.0 # prio-buffer + + # actor + obs_result = self.policy(batch) + act = obs_result.act + current_q1a = self.critic(batch.obs, act).flatten() + current_q2a = self.critic2(batch.obs, act).flatten() + actor_loss = ( + self.alpha.value * obs_result.log_prob.flatten() - torch.min(current_q1a, current_q2a) + ).mean() + self.policy_optim.step(actor_loss) + + # The entropy of a Gaussian policy can be expressed as -log_prob + a constant (which we ignore) + entropy = -obs_result.log_prob.detach() + alpha_loss = self.alpha.update(entropy) + + self._update_lagged_network_weights() + + return SACTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + alpha=to_optional_float(self.alpha.value), + alpha_loss=to_optional_float(alpha_loss), + ) diff --git a/tianshou/algorithm/modelfree/td3.py b/tianshou/algorithm/modelfree/td3.py new file mode 100644 index 000000000..4c273cc63 --- /dev/null +++ b/tianshou/algorithm/modelfree/td3.py @@ -0,0 +1,226 @@ +from abc import ABC +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +import torch + +from tianshou.algorithm.algorithm_base import ( + TPolicy, + TrainingStats, +) +from tianshou.algorithm.modelfree.ddpg import ( + ActorCriticOffPolicyAlgorithm, + ContinuousDeterministicPolicy, + TActBatchProtocol, +) +from tianshou.algorithm.optim import 
OptimizerFactory +from tianshou.data import Batch +from tianshou.data.types import ( + ActStateBatchProtocol, + RolloutBatchProtocol, +) + + +@dataclass(kw_only=True) +class TD3TrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + + +class ActorDualCriticsOffPolicyAlgorithm( + ActorCriticOffPolicyAlgorithm[TPolicy, TActBatchProtocol], + ABC, +): + """A base class for off-policy algorithms with two critics, where the target Q-value is computed as the minimum + of the two lagged critics' values. + """ + + def __init__( + self, + *, + policy: Any, + policy_optim: OptimizerFactory, + critic: torch.nn.Module, + critic_optim: OptimizerFactory, + critic2: torch.nn.Module | None = None, + critic2_optim: OptimizerFactory | None = None, + tau: float = 0.005, + gamma: float = 0.99, + n_step_return_horizon: int = 1, + ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer factory for the policy's model. + :param critic: the first critic network. + For continuous action spaces: (s, a -> Q(s, a)). + **NOTE**: The default implementation of `_target_q_compute_value` assumes + a continuous action space; override this method if using discrete actions. + :param critic_optim: the optimizer factory for the first critic network. + :param critic2: the second critic network (analogous functionality to the first). + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. 
+ Typically set between 0.9 and 0.99 for most reinforcement learning tasks + """ + super().__init__( + policy=policy, + policy_optim=policy_optim, + critic=critic, + critic_optim=critic_optim, + tau=tau, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + ) + self.critic2 = critic2 or deepcopy(critic) + self.critic2_old = self._add_lagged_network(self.critic2) + self.critic2_optim = self._create_optimizer(self.critic2, critic2_optim or critic_optim) + + def _target_q_compute_value( + self, obs_batch: Batch, act_batch: TActBatchProtocol + ) -> torch.Tensor: + # compute the Q-value as the minimum of the two lagged critics + act = act_batch.act + return torch.min( + self.critic_old(obs_batch.obs, act), + self.critic2_old(obs_batch.obs, act), + ) + + +class TD3( + ActorDualCriticsOffPolicyAlgorithm[ContinuousDeterministicPolicy, ActStateBatchProtocol], +): + """Implementation of TD3, arXiv:1802.09477.""" + + def __init__( + self, + *, + policy: ContinuousDeterministicPolicy, + policy_optim: OptimizerFactory, + critic: torch.nn.Module, + critic_optim: OptimizerFactory, + critic2: torch.nn.Module | None = None, + critic2_optim: OptimizerFactory | None = None, + tau: float = 0.005, + gamma: float = 0.99, + policy_noise: float = 0.2, + update_actor_freq: int = 2, + noise_clip: float = 0.5, + n_step_return_horizon: int = 1, + ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer factory for the policy's model. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer factory for the first critic network. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param policy_noise: scaling factor for the Gaussian noise added to target policy actions. + This parameter implements target policy smoothing, a regularization technique described in the TD3 paper. + The noise is sampled from a normal distribution and multiplied by this value before being added to actions. + Higher values increase exploration in the target policy, helping to address function approximation error. + The added noise is optionally clipped to a range determined by the noise_clip parameter. 
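A standalone sketch combining the two mechanisms described above: target policy smoothing (`policy_noise`, `noise_clip`) and the clipped double-Q target taken as the minimum over the two lagged critics. The linear modules and shapes are toy stand-ins, and the done mask is omitted for brevity:

```python
# Toy sketch: smoothed target action + min over two target critics.
import torch
from torch import nn

policy_noise, noise_clip, gamma = 0.2, 0.5, 0.99
batch, obs_dim, act_dim = 8, 6, 2

actor_old = nn.Linear(obs_dim, act_dim)                  # stand-in for the lagged actor
critic1_old = nn.Linear(obs_dim + act_dim, 1)            # stand-ins for the lagged critics
critic2_old = nn.Linear(obs_dim + act_dim, 1)

obs_next = torch.randn(batch, obs_dim)
rew = torch.randn(batch)

with torch.no_grad():
    act_next = actor_old(obs_next)
    noise = (torch.randn_like(act_next) * policy_noise).clamp(-noise_clip, noise_clip)
    act_next = act_next + noise                          # target policy smoothing
    q_in = torch.cat([obs_next, act_next], dim=-1)
    target_q = torch.min(critic1_old(q_in), critic2_old(q_in)).squeeze(-1)
    td_target = rew + gamma * target_q                   # 1-step target (no done mask for brevity)
print(td_target.shape)  # torch.Size([8])
```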
+ Typically set between 0.1 and 0.5 relative to the action scale of the environment. + :param update_actor_freq: the frequency of actor network updates relative to critic network updates + (the actor network is only updated once for every `update_actor_freq` critic updates). + This implements the "delayed" policy updates from the TD3 algorithm, where the actor is + updated less frequently than the critics. + Higher values (e.g., 2-5) help stabilize training by allowing the critic to become more + accurate before updating the policy. + The default value of 2 follows the original TD3 paper's recommendation of updating the + policy at half the rate of the Q-functions. + :param noise_clip: defines the maximum absolute value of the noise added to target policy actions, i.e. noise values + are clipped to the range [-noise_clip, noise_clip] (after generating and scaling the noise + via `policy_noise`). + This parameter implements bounded target policy smoothing as described in the TD3 paper. + It prevents extreme noise values from causing unrealistic target values during training. + Setting it 0.0 (or a negative value) disables clipping entirely. + It is typically set to about twice the `policy_noise` value (e.g. 0.5 when `policy_noise` is 0.2). + """ + super().__init__( + policy=policy, + policy_optim=policy_optim, + critic=critic, + critic_optim=critic_optim, + critic2=critic2, + critic2_optim=critic2_optim, + tau=tau, + gamma=gamma, + n_step_return_horizon=n_step_return_horizon, + ) + self.actor_old = self._add_lagged_network(self.policy.actor) + self.policy_noise = policy_noise + self.update_actor_freq = update_actor_freq + self.noise_clip = noise_clip + self._cnt = 0 + self._last = 0 + + def _target_q_compute_action(self, obs_batch: Batch) -> ActStateBatchProtocol: + # compute action using lagged actor + act_batch = self.policy(obs_batch, model=self.actor_old) + act_ = act_batch.act + + # add noise + noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise + if self.noise_clip > 0.0: + noise = noise.clamp(-self.noise_clip, self.noise_clip) + act_ += noise + + act_batch.act = act_ + return act_batch + + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TD3TrainingStats: + # critic 1&2 + td1, critic1_loss = self._minimize_critic_squared_loss( + batch, self.critic, self.critic_optim + ) + td2, critic2_loss = self._minimize_critic_squared_loss( + batch, self.critic2, self.critic2_optim + ) + batch.weight = (td1 + td2) / 2.0 # prio-buffer + + # actor + if self._cnt % self.update_actor_freq == 0: + actor_loss = -self.critic(batch.obs, self.policy(batch, eps=0.0).act).mean() + self._last = actor_loss.item() + self.policy_optim.step(actor_loss) + self._update_lagged_network_weights() + self._cnt += 1 + + return TD3TrainingStats( + actor_loss=self._last, + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + ) diff --git a/tianshou/algorithm/modelfree/trpo.py b/tianshou/algorithm/modelfree/trpo.py new file mode 100644 index 000000000..450fdde54 --- /dev/null +++ b/tianshou/algorithm/modelfree/trpo.py @@ -0,0 +1,214 @@ +import warnings +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch.distributions import kl_divergence + +from tianshou.algorithm import NPG +from tianshou.algorithm.modelfree.npg import NPGTrainingStats +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import SequenceSummaryStats +from 
tianshou.data.types import BatchWithAdvantagesProtocol +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic + + +@dataclass(kw_only=True) +class TRPOTrainingStats(NPGTrainingStats): + step_size: SequenceSummaryStats + + +class TRPO(NPG): + """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.""" + + def __init__( + self, + *, + policy: ProbabilisticActorPolicy, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, + optim: OptimizerFactory, + max_kl: float = 0.01, + backtrack_coeff: float = 0.8, + max_backtracks: int = 10, + optim_critic_iters: int = 5, + trust_region_size: float = 0.5, + advantage_normalization: bool = True, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + gamma: float = 0.99, + return_scaling: bool = False, + ) -> None: + """ + :param policy: the policy + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer factory for the critic network. + :param max_kl: max kl-divergence used to constrain each actor network update. + :param backtrack_coeff: Coefficient to be multiplied by step size when + constraints are not met. + :param max_backtracks: Max number of backtracking times in linesearch. + :param optim_critic_iters: the number of optimization steps performed on the critic network + for each policy (actor) update. + Controls the learning rate balance between critic and actor. + Higher values prioritize critic accuracy by training the value function more + extensively before each policy update, which can improve stability but slow down + training. Lower values maintain a more even learning pace between policy and value + function but may lead to less reliable advantage estimates. + Typically set between 1 and 10, depending on the complexity of the value function. + :param trust_region_size: the parameter delta - a scalar multiplier for policy updates in the natural gradient direction. + The mathematical meaning is the trust region size, which is the maximum KL divergence + allowed between the old and new policy distributions. + Controls how far the policy parameters move in the calculated direction + during each update. Higher values allow for faster learning but may cause instability + or policy deterioration; lower values provide more stable but slower learning. Unlike + regular policy gradients, natural gradients already account for the local geometry of + the parameter space, making this step size more robust to different parameterizations. + Typically set between 0.1 and 1.0 for most reinforcement learning tasks. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. 
Typically set between 0.9 and 0.99 for + most policy gradient methods. + :param max_batchsize: the maximum number of samples to process at once when computing + generalized advantage estimation (GAE) and value function predictions. + Controls memory usage by breaking large batches into smaller chunks processed sequentially. + Higher values may increase speed but require more GPU/CPU memory; lower values + reduce memory requirements but may increase computation time. Should be adjusted + based on available hardware resources and total batch size of your training data. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. 
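# --- Editor's illustrative sketch (not part of this diff) ---------------------
# How the trust-region bound `max_kl` (delta) turns into the initial step size
# used in `_update_with_batch` below: step_size = sqrt(2 * delta / (s^T H s)),
# where s is the conjugate-gradient search direction and H s is the KL
# Hessian-vector product (computed by `_MVP` in the actual code). A diagonal
# stand-in for H is used here only to keep the snippet runnable.
import torch

delta = 0.01                            # corresponds to `max_kl`
s = torch.tensor([0.3, -0.1, 0.2])      # search direction from conjugate gradients
h_diag = torch.tensor([2.0, 1.0, 4.0])  # diagonal stand-in for the KL Hessian
s_h_s = (s * (h_diag * s)).sum()
step_size = torch.sqrt(2 * delta / s_h_s)
# during backtracking, the step size is repeatedly multiplied by `backtrack_coeff`
# (at most `max_backtracks` times) until the KL bound holds and the surrogate
# loss actually improves
# ------------------------------------------------------------------------------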
+ """ + super().__init__( + policy=policy, + critic=critic, + optim=optim, + optim_critic_iters=optim_critic_iters, + trust_region_size=trust_region_size, + advantage_normalization=advantage_normalization, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + gamma=gamma, + return_scaling=return_scaling, + ) + self.max_backtracks = max_backtracks + self.max_kl = max_kl + self.backtrack_coeff = backtrack_coeff + + def _update_with_batch( # type: ignore[override] + self, + batch: BatchWithAdvantagesProtocol, + batch_size: int | None, + repeat: int, + ) -> TRPOTrainingStats: + actor_losses, vf_losses, step_sizes, kls = [], [], [], [] + split_batch_size = batch_size or -1 + for _ in range(repeat): + for minibatch in batch.split(split_batch_size, merge_last=True): + # optimize actor + # direction: calculate villia gradient + dist = self.policy(minibatch).dist + ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() + ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) + actor_loss = -(ratio * minibatch.adv).mean() + flat_grads = self._get_flat_grad( + actor_loss, self.policy.actor, retain_graph=True + ).detach() + + # direction: calculate natural gradient + with torch.no_grad(): + old_dist = self.policy(minibatch).dist + + kl = kl_divergence(old_dist, dist).mean() + # calculate first order gradient of kl with respect to theta + flat_kl_grad = self._get_flat_grad(kl, self.policy.actor, create_graph=True) + search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) + + # stepsize: calculate max stepsize constrained by kl bound + step_size = torch.sqrt( + 2 + * self.max_kl + / (search_direction * self._MVP(search_direction, flat_kl_grad)).sum( + 0, + keepdim=True, + ), + ) + + # stepsize: linesearch stepsize + with torch.no_grad(): + flat_params = torch.cat( + [param.data.view(-1) for param in self.policy.actor.parameters()], + ) + for i in range(self.max_backtracks): + new_flat_params = flat_params + step_size * search_direction + self._set_from_flat_params(self.policy.actor, new_flat_params) + # calculate kl and if in bound, loss actually down + new_dist = self.policy(minibatch).dist + new_dratio = ( + (new_dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() + ) + new_dratio = new_dratio.reshape(new_dratio.size(0), -1).transpose(0, 1) + new_actor_loss = -(new_dratio * minibatch.adv).mean() + kl = kl_divergence(old_dist, new_dist).mean() + + if kl < self.max_kl and new_actor_loss < actor_loss: + if i > 0: + warnings.warn(f"Backtracking to step {i}.") + break + if i < self.max_backtracks - 1: + step_size = step_size * self.backtrack_coeff + else: + self._set_from_flat_params(self.policy.actor, new_flat_params) + step_size = torch.tensor([0.0]) + warnings.warn( + "Line search failed! 
It seems hyperparamters" + " are poor and need to be changed.", + ) + + # optimize critic + for _ in range(self.optim_critic_iters): + value = self.critic(minibatch.obs).flatten() + vf_loss = F.mse_loss(minibatch.returns, value) + self.optim.step(vf_loss) + + actor_losses.append(actor_loss.item()) + vf_losses.append(vf_loss.item()) + step_sizes.append(step_size.item()) + kls.append(kl.item()) + + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + kl_summary_stat = SequenceSummaryStats.from_sequence(kls) + step_size_stat = SequenceSummaryStats.from_sequence(step_sizes) + + return TRPOTrainingStats( + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + kl=kl_summary_stat, + step_size=step_size_stat, + ) diff --git a/tianshou/policy/multiagent/__init__.py b/tianshou/algorithm/multiagent/__init__.py similarity index 100% rename from tianshou/policy/multiagent/__init__.py rename to tianshou/algorithm/multiagent/__init__.py diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/algorithm/multiagent/marl.py similarity index 61% rename from tianshou/policy/multiagent/mapolicy.py rename to tianshou/algorithm/multiagent/marl.py index 05cc8db8f..1c30a1cbc 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/algorithm/multiagent/marl.py @@ -1,13 +1,21 @@ -from typing import Any, Literal, Protocol, Self, TypeVar, cast, overload +from collections.abc import Callable +from typing import Any, Generic, Literal, Protocol, Self, TypeVar, cast, overload import numpy as np from overrides import override +from sensai.util.helper import mark_used +from torch.nn import ModuleList +from tianshou.algorithm import Algorithm +from tianshou.algorithm.algorithm_base import ( + OffPolicyAlgorithm, + OnPolicyAlgorithm, + Policy, + TrainingStats, +) from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol, IndexType from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats try: from tianshou.env.pettingzoo_env import PettingZooEnv @@ -15,6 +23,9 @@ PettingZooEnv = None # type: ignore +mark_used(ActBatchProtocol) + + class MapTrainingStats(TrainingStats): def __init__( self, @@ -63,106 +74,21 @@ def __getitem__(self, index: str | IndexType) -> Any: ... -class MultiAgentPolicyManager(BasePolicy): - """Multi-agent policy manager for MARL. - - This multi-agent policy manager accepts a list of - :class:`~tianshou.policy.BasePolicy`. It dispatches the batch data to each - of these policies when the "forward" is called. The same as "process_fn" - and "learn": it splits the data and feeds them to each policy. A figure in - :ref:`marl_example` can help you better understand this procedure. - - :param policies: a list of policies. - :param env: a PettingZooEnv. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - """ - - def __init__( - self, - *, - policies: list[BasePolicy], - # TODO: 1 why restrict to PettingZooEnv? - # TODO: 2 This is the only policy that takes an env in init, is it really needed? 
- env: PettingZooEnv, - action_scaling: bool = False, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: +class MultiAgentPolicy(Policy): + def __init__(self, policies: dict[str | int, Policy]): + p0 = next(iter(policies.values())) super().__init__( - action_space=env.action_space, - observation_space=env.observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, + action_space=p0.action_space, + observation_space=p0.observation_space, + action_scaling=False, + action_bound_method=None, ) - assert len(policies) == len(env.agents), "One policy must be assigned for each agent." - - self.agent_idx = env.agent_idx - for i, policy in enumerate(policies): - # agent_id 0 is reserved for the environment proxy - # (this MultiAgentPolicyManager) - policy.set_agent_id(env.agents[i]) - - self.policies: dict[str | int, BasePolicy] = dict(zip(env.agents, policies, strict=True)) - """Maps agent_id to policy.""" - - # TODO: unused - remove it? - def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: - """Replace the "agent_id"th policy in this manager.""" - policy.set_agent_id(agent_id) - self.policies[agent_id] = policy - - # TODO: violates Liskov substitution principle - def process_fn( # type: ignore - self, - batch: MAPRolloutBatchProtocol, - buffer: ReplayBuffer, - indice: np.ndarray, - ) -> MAPRolloutBatchProtocol: - """Dispatch batch data from `obs.agent_id` to every policy's process_fn. - - Save original multi-dimensional rew in "save_rew", set rew to the - reward of each agent during their "process_fn", and restore the - original reward afterwards. - """ - # TODO: maybe only str is actually allowed as agent_id? See MAPRolloutBatchProtocol - results: dict[str | int, RolloutBatchProtocol] = {} - assert isinstance( - batch.obs, - BatchProtocol, - ), f"here only observations of type Batch are permitted, but got {type(batch.obs)}" - # reward can be empty Batch (after initial reset) or nparray. - has_rew = isinstance(buffer.rew, np.ndarray) - if has_rew: # save the original reward in save_rew - # Since we do not override buffer.__setattr__, here we use _meta to - # change buffer.rew, otherwise buffer.rew = Batch() has no effect. 
- save_rew, buffer._meta.rew = buffer.rew, Batch() # type: ignore - for agent, policy in self.policies.items(): - agent_index = np.nonzero(batch.obs.agent_id == agent)[0] - if len(agent_index) == 0: - results[agent] = cast(RolloutBatchProtocol, Batch()) - continue - tmp_batch, tmp_indice = batch[agent_index], indice[agent_index] - if has_rew: - tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]] - buffer._meta.rew = save_rew[:, self.agent_idx[agent]] - if not hasattr(tmp_batch.obs, "mask"): - if hasattr(tmp_batch.obs, "obs"): - tmp_batch.obs = tmp_batch.obs.obs - if hasattr(tmp_batch.obs_next, "obs"): - tmp_batch.obs_next = tmp_batch.obs_next.obs - results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice) - if has_rew: # restore from save_rew - buffer._meta.rew = save_rew - return cast(MAPRolloutBatchProtocol, Batch(results)) + self.policies = policies + self._submodules = ModuleList(policies.values()) _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - def exploration_noise( + def add_exploration_noise( self, act: _TArrOrActBatch, batch: ObsBatchProtocol, @@ -176,7 +102,7 @@ def exploration_noise( agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: continue - act[agent_index] = policy.exploration_noise(act[agent_index], batch[agent_index]) + act[agent_index] = policy.add_exploration_noise(act[agent_index], batch[agent_index]) return act def forward( # type: ignore @@ -258,29 +184,170 @@ def forward( # type: ignore holder["state"] = state_dict return holder - # Violates Liskov substitution principle - def learn( # type: ignore + +TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) + + +class MARLDispatcher(Generic[TAlgorithm]): + """ + Supports multi-agent learning by dispatching calls to the corresponding + algorithm for each agent. + """ + + def __init__(self, algorithms: list[TAlgorithm], env: PettingZooEnv): + agent_ids = env.agents + assert len(algorithms) == len(agent_ids), "One policy must be assigned for each agent." + self.algorithms: dict[str | int, TAlgorithm] = dict(zip(agent_ids, algorithms, strict=True)) + """maps agent_id to the corresponding algorithm.""" + self.agent_idx = env.agent_idx + """maps agent_id to 0-based index.""" + + def create_policy(self) -> MultiAgentPolicy: + return MultiAgentPolicy({agent_id: a.policy for agent_id, a in self.algorithms.items()}) + + def dispatch_process_fn( self, batch: MAPRolloutBatchProtocol, - *args: Any, - **kwargs: Any, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> MAPRolloutBatchProtocol: + """Dispatch batch data from `obs.agent_id` to every algorithm's processing function. + + Save original multi-dimensional rew in "save_rew", set rew to the + reward of each agent during their "process_fn", and restore the + original reward afterwards. + """ + # TODO: maybe only str is actually allowed as agent_id? See MAPRolloutBatchProtocol + results: dict[str | int, RolloutBatchProtocol] = {} + assert isinstance( + batch.obs, + BatchProtocol, + ), f"here only observations of type Batch are permitted, but got {type(batch.obs)}" + # reward can be empty Batch (after initial reset) or nparray. + has_rew = isinstance(buffer.rew, np.ndarray) + if has_rew: # save the original reward in save_rew + # Since we do not override buffer.__setattr__, here we use _meta to + # change buffer.rew, otherwise buffer.rew = Batch() has no effect. 
+ save_rew, buffer._meta.rew = buffer.rew, Batch() # type: ignore + for agent, algorithm in self.algorithms.items(): + agent_index = np.nonzero(batch.obs.agent_id == agent)[0] + if len(agent_index) == 0: + results[agent] = cast(RolloutBatchProtocol, Batch()) + continue + tmp_batch, tmp_indice = batch[agent_index], indices[agent_index] + if has_rew: + tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]] + buffer._meta.rew = save_rew[:, self.agent_idx[agent]] + if not hasattr(tmp_batch.obs, "mask"): + if hasattr(tmp_batch.obs, "obs"): + tmp_batch.obs = tmp_batch.obs.obs + if hasattr(tmp_batch.obs_next, "obs"): + tmp_batch.obs_next = tmp_batch.obs_next.obs + results[agent] = algorithm._preprocess_batch(tmp_batch, buffer, tmp_indice) + if has_rew: # restore from save_rew + buffer._meta.rew = save_rew + return cast(MAPRolloutBatchProtocol, Batch(results)) + + def dispatch_update_with_batch( + self, + batch: MAPRolloutBatchProtocol, + algorithm_update_with_batch_fn: Callable[[TAlgorithm, RolloutBatchProtocol], TrainingStats], ) -> MapTrainingStats: - """Dispatch the data to all policies for learning. + """Dispatch the respective subset of the batch data to each algorithm. :param batch: must map agent_ids to rollout batches + :param algorithm_update_with_batch_fn: a function that performs the algorithm-specific + update with the given agent-specific batch data """ agent_id_to_stats = {} - for agent_id, policy in self.policies.items(): + for agent_id, algorithm in self.algorithms.items(): data = batch[agent_id] if len(data.get_keys()) != 0: - train_stats = policy.learn(batch=data, **kwargs) + train_stats = algorithm_update_with_batch_fn(algorithm, data) agent_id_to_stats[agent_id] = train_stats return MapTrainingStats(agent_id_to_stats) - # Need a train method that set all sub-policies to train mode. - # No need for a similar eval function, as eval internally uses the train function. - def train(self, mode: bool = True) -> Self: - """Set each internal policy in training mode.""" - for policy in self.policies.values(): - policy.train(mode) - return self + +class MultiAgentOffPolicyAlgorithm(OffPolicyAlgorithm[MultiAgentPolicy]): + """Multi-agent reinforcement learning where each agent uses off-policy learning.""" + + def __init__( + self, + *, + algorithms: list[OffPolicyAlgorithm], + env: PettingZooEnv, + ) -> None: + """ + :param algorithms: a list of off-policy algorithms. 
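# --- Editor's illustrative sketch (not part of this diff) ---------------------
# Combining one off-policy algorithm per agent into a single multi-agent
# algorithm. To stay self-contained, the per-agent algorithm used here is the
# random masked algorithm defined further down in this diff; in practice each
# agent would get its own learning algorithm instance.
from tianshou.algorithm.multiagent.marl import MultiAgentOffPolicyAlgorithm
from tianshou.algorithm.random import MARLRandomDiscreteMaskedOffPolicyAlgorithm
from tianshou.env.pettingzoo_env import PettingZooEnv


def build_marl_algorithm(env: PettingZooEnv) -> MultiAgentOffPolicyAlgorithm:
    # one algorithm per agent, in the order given by env.agents
    algorithms = [
        MARLRandomDiscreteMaskedOffPolicyAlgorithm(action_space=env.action_space)
        for _ in env.agents
    ]
    return MultiAgentOffPolicyAlgorithm(algorithms=algorithms, env=env)
# ------------------------------------------------------------------------------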
+        :param env: the multi-agent RL environment
+        """
+        self._dispatcher: MARLDispatcher[OffPolicyAlgorithm] = MARLDispatcher(algorithms, env)
+        super().__init__(
+            policy=self._dispatcher.create_policy(),
+        )
+        self._submodules = ModuleList(algorithms)
+
+    def get_algorithm(self, agent_id: str | int) -> OffPolicyAlgorithm:
+        return self._dispatcher.algorithms[agent_id]
+
+    def _preprocess_batch(
+        self,
+        batch: RolloutBatchProtocol,
+        buffer: ReplayBuffer,
+        indices: np.ndarray,
+    ) -> RolloutBatchProtocol:
+        batch = cast(MAPRolloutBatchProtocol, batch)
+        return self._dispatcher.dispatch_process_fn(batch, buffer, indices)
+
+    def _update_with_batch(
+        self,
+        batch: RolloutBatchProtocol,
+    ) -> MapTrainingStats:
+        batch = cast(MAPRolloutBatchProtocol, batch)
+
+        def update(algorithm: OffPolicyAlgorithm, data: RolloutBatchProtocol) -> TrainingStats:
+            return algorithm._update_with_batch(data)
+
+        return self._dispatcher.dispatch_update_with_batch(batch, update)
+
+
+class MultiAgentOnPolicyAlgorithm(OnPolicyAlgorithm[MultiAgentPolicy]):
+    """Multi-agent reinforcement learning where each agent uses on-policy learning."""
+
+    def __init__(
+        self,
+        *,
+        algorithms: list[OnPolicyAlgorithm],
+        env: PettingZooEnv,
+    ) -> None:
+        """
+        :param algorithms: a list of on-policy algorithms.
+        :param env: the multi-agent RL environment
+        """
+        self._dispatcher: MARLDispatcher[OnPolicyAlgorithm] = MARLDispatcher(algorithms, env)
+        super().__init__(
+            policy=self._dispatcher.create_policy(),
+        )
+        self._submodules = ModuleList(algorithms)
+
+    def get_algorithm(self, agent_id: str | int) -> OnPolicyAlgorithm:
+        return self._dispatcher.algorithms[agent_id]
+
+    def _preprocess_batch(
+        self,
+        batch: RolloutBatchProtocol,
+        buffer: ReplayBuffer,
+        indices: np.ndarray,
+    ) -> RolloutBatchProtocol:
+        batch = cast(MAPRolloutBatchProtocol, batch)
+        return self._dispatcher.dispatch_process_fn(batch, buffer, indices)
+
+    def _update_with_batch(
+        self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int
+    ) -> MapTrainingStats:
+        batch = cast(MAPRolloutBatchProtocol, batch)
+
+        def update(algorithm: OnPolicyAlgorithm, data: RolloutBatchProtocol) -> TrainingStats:
+            return algorithm._update_with_batch(data, batch_size, repeat)
+
+        return self._dispatcher.dispatch_update_with_batch(batch, update)
diff --git a/tianshou/algorithm/optim.py b/tianshou/algorithm/optim.py
new file mode 100644
index 000000000..c802f95d4
--- /dev/null
+++ b/tianshou/algorithm/optim.py
@@ -0,0 +1,140 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterable
+from typing import Any, Self, TypeAlias
+
+import numpy as np
+import torch
+from sensai.util.string import ToStringMixin
+from torch.optim import Adam, RMSprop
+from torch.optim.lr_scheduler import LambdaLR, LRScheduler
+
+ParamsType: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
+
+
+class LRSchedulerFactory(ToStringMixin, ABC):
+    """Factory for the creation of a learning rate scheduler."""
+
+    @abstractmethod
+    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
+        pass
+
+
+class LRSchedulerFactoryLinear(LRSchedulerFactory):
+    """
+    Factory for a learning rate scheduler where the learning rate linearly decays towards
+    zero for the given trainer parameters.
+ """ + + def __init__(self, max_epochs: int, epoch_num_steps: int, collection_step_num_env_steps: int): + self.num_epochs = max_epochs + self.epoch_num_steps = epoch_num_steps + self.collection_step_num_env_steps = collection_step_num_env_steps + + def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: + return LambdaLR(optim, lr_lambda=self._LRLambda(self).compute) + + class _LRLambda: + def __init__(self, parent: "LRSchedulerFactoryLinear"): + self.max_update_num = ( + np.ceil(parent.epoch_num_steps / parent.collection_step_num_env_steps) + * parent.num_epochs + ) + + def compute(self, epoch: int) -> float: + return 1.0 - epoch / self.max_update_num + + +class OptimizerFactory(ABC, ToStringMixin): + def __init__(self) -> None: + self.lr_scheduler_factory: LRSchedulerFactory | None = None + + def with_lr_scheduler_factory(self, lr_scheduler_factory: LRSchedulerFactory) -> Self: + self.lr_scheduler_factory = lr_scheduler_factory + return self + + def create_instances( + self, + module: torch.nn.Module, + ) -> tuple[torch.optim.Optimizer, LRScheduler | None]: + optimizer = self._create_optimizer_for_params(module.parameters()) + lr_scheduler = None + if self.lr_scheduler_factory is not None: + lr_scheduler = self.lr_scheduler_factory.create_scheduler(optimizer) + return optimizer, lr_scheduler + + @abstractmethod + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + pass + + +class TorchOptimizerFactory(OptimizerFactory): + """General factory for arbitrary torch optimizers.""" + + def __init__(self, optim_class: Callable[..., torch.optim.Optimizer], **kwargs: Any): + """ + + :param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`), + which will be passed the module parameters, the learning rate as `lr` and the + kwargs provided. 
+ :param kwargs: keyword arguments to provide at optimizer construction + """ + super().__init__() + self.optim_class = optim_class + self.kwargs = kwargs + + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + return self.optim_class(params, **self.kwargs) + + +class AdamOptimizerFactory(OptimizerFactory): + def __init__( + self, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-08, + weight_decay: float = 0, + ): + super().__init__() + self.lr = lr + self.weight_decay = weight_decay + self.eps = eps + self.betas = betas + + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + return Adam( + params, + lr=self.lr, + betas=self.betas, + eps=self.eps, + weight_decay=self.weight_decay, + ) + + +class RMSpropOptimizerFactory(OptimizerFactory): + def __init__( + self, + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-08, + weight_decay: float = 0, + momentum: float = 0, + centered: bool = False, + ): + super().__init__() + self.lr = lr + self.alpha = alpha + self.momentum = momentum + self.centered = centered + self.weight_decay = weight_decay + self.eps = eps + + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + return RMSprop( + params, + lr=self.lr, + alpha=self.alpha, + eps=self.eps, + weight_decay=self.weight_decay, + momentum=self.momentum, + centered=self.centered, + ) diff --git a/tianshou/algorithm/random.py b/tianshou/algorithm/random.py new file mode 100644 index 000000000..2d66040fa --- /dev/null +++ b/tianshou/algorithm/random.py @@ -0,0 +1,60 @@ +from typing import cast + +import gymnasium as gym +import numpy as np + +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, TrainingStats +from tianshou.algorithm.algorithm_base import Policy as BasePolicy +from tianshou.data import Batch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol + + +class MARLRandomTrainingStats(TrainingStats): + pass + + +class MARLRandomDiscreteMaskedOffPolicyAlgorithm(OffPolicyAlgorithm): + """A random agent used in multi-agent learning. + + It randomly chooses an action from the legal actions (according to the given mask). + """ + + class Policy(BasePolicy): + """A random agent used in multi-agent learning. + + It randomly chooses an action from the legal actions. + """ + + def __init__(self, action_space: gym.spaces.Space) -> None: + super().__init__(action_space=action_space) + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: dict, + ) -> ActBatchProtocol: + """Compute the random action over the given batch data. + + The input should contain a mask in batch.obs, with "True" to be + available and "False" to be unavailable. For example, + ``batch.obs.mask == np.array([[False, True, False]])`` means with batch + size 1, action "1" is available but action "0" and "2" are unavailable. + + :return: A :class:`~tianshou.data.Batch` with "act" key, containing + the random action. 
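# --- Editor's illustrative sketch (not part of this diff) ---------------------
# The mask semantics described above, demonstrated with plain numpy: with
# batch.obs.mask == [[False, True, False]], only action 1 is legal and is
# therefore always selected.
import numpy as np

mask = np.array([[False, True, False]])
logits = np.random.rand(*mask.shape)
logits[~mask] = -np.inf
assert (logits.argmax(axis=-1) == 1).all()
# ------------------------------------------------------------------------------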
+ """ + mask = batch.obs.mask # type: ignore + logits = np.random.rand(*mask.shape) + logits[~mask] = -np.inf + result = Batch(act=logits.argmax(axis=-1)) + return cast(ActBatchProtocol, result) + + def __init__(self, action_space: gym.spaces.Space) -> None: + """:param action_space: the environment's action space.""" + super().__init__(policy=self.Policy(action_space)) + + def _update_with_batch(self, batch: RolloutBatchProtocol) -> MARLRandomTrainingStats: # type: ignore + """Since a random agent learns nothing, it returns an empty dict.""" + return MARLRandomTrainingStats() diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index c84c2ec7d..7e1d5298d 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -4,7 +4,7 @@ from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree -from tianshou.data.buffer.base import ReplayBuffer +from tianshou.data.buffer.buffer_base import ReplayBuffer from tianshou.data.buffer.prio import PrioritizedReplayBuffer from tianshou.data.buffer.her import HERReplayBuffer from tianshou.data.buffer.manager import ( diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index c7ee7505b..8eb88e939 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -73,6 +73,7 @@ TDistribution = TypeVar("TDistribution", bound=Distribution) T = TypeVar("T") TArr = torch.Tensor | np.ndarray +TObsArr = torch.Tensor | np.ndarray log = logging.getLogger(__name__) diff --git a/tianshou/data/buffer/__init__.py b/tianshou/data/buffer/__init__.py index e69de29bb..0609fddf8 100644 --- a/tianshou/data/buffer/__init__.py +++ b/tianshou/data/buffer/__init__.py @@ -0,0 +1,10 @@ +def _backward_compatibility() -> None: + import sys + + from . 
import buffer_base + + # backward compatibility with persisted buffers from v1 for determinism tests + sys.modules["tianshou.data.buffer.base"] = buffer_base + + +_backward_compatibility() diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/buffer_base.py similarity index 100% rename from tianshou/data/buffer/base.py rename to tianshou/data/buffer/buffer_base.py diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 3c6d75d4d..640025e25 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -12,6 +12,8 @@ from overrides import override from torch.distributions import Categorical, Distribution +from tianshou.algorithm import Algorithm +from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go from tianshou.config import ENABLE_VALIDATION from tianshou.data import ( Batch, @@ -22,7 +24,7 @@ VectorReplayBuffer, to_numpy, ) -from tianshou.data.buffer.base import MalformedBufferError +from tianshou.data.buffer.buffer_base import MalformedBufferError from tianshou.data.stats import compute_dim_to_summary_stats from tianshou.data.types import ( ActBatchProtocol, @@ -31,8 +33,6 @@ RolloutBatchProtocol, ) from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.policy import BasePolicy -from tianshou.policy.base import episode_mc_return_to_go from tianshou.utils.determinism import TraceLogger from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode @@ -313,7 +313,7 @@ class BaseCollector(Generic[TCollectStats], ABC): def __init__( self, - policy: BasePolicy, + policy: Policy | Algorithm, env: BaseVectorEnv | gym.Env, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, @@ -355,7 +355,7 @@ def __init__( self.buffer: ReplayBuffer | ReplayBufferManager = buffer self.raise_on_nan_in_buffer = raise_on_nan_in_buffer - self.policy = policy + self.policy = policy.policy if isinstance(policy, Algorithm) else policy self.env = cast(BaseVectorEnv, env) self.exploration_noise = exploration_noise self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 @@ -576,7 +576,7 @@ class Collector(BaseCollector[TCollectStats], Generic[TCollectStats]): # def __init__( self, - policy: BasePolicy, + policy: Policy | Algorithm, env: gym.Env | BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, @@ -586,8 +586,7 @@ def __init__( collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: """ - :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch - of actions from a batch of observations. + :param policy: a tianshou policy or algorithm :param env: a ``gymnasium.Env`` environment or a vectorized instance of the :class:`~tianshou.env.BaseVectorEnv` class. 
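# --- Editor's illustrative sketch (not part of this diff) ---------------------
# The collector now accepts either an Algorithm or a bare Policy (see the
# isinstance check above); `my_algorithm` is a hypothetical, already constructed
# Algorithm instance for a CartPole-style discrete task.
import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data.collector import CollectStats
from tianshou.env import DummyVectorEnv

train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
buffer = VectorReplayBuffer(20_000, len(train_envs))

# passing the algorithm: the collector extracts `algorithm.policy` internally
collector = Collector[CollectStats](my_algorithm, train_envs, buffer, exploration_noise=True)
# passing the policy directly is equivalent
collector = Collector[CollectStats](my_algorithm.policy, train_envs, buffer, exploration_noise=True)
# ------------------------------------------------------------------------------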
The latter is strongly recommended, as with a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv` @@ -736,7 +735,7 @@ def _compute_action_policy_hidden( act_RA = to_numpy(act_batch_RA.act) if self.exploration_noise: - act_RA = self.policy.exploration_noise(act_RA, obs_batch_R) + act_RA = self.policy.add_exploration_noise(act_RA, obs_batch_R) act_normalized_RA = self.policy.map_action(act_RA) # TODO: cleanup the whole policy in batch thing @@ -1122,7 +1121,7 @@ class AsyncCollector(Collector[CollectStats]): def __init__( self, - policy: BasePolicy, + policy: Policy | Algorithm, env: BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index 11d64c017..4deda628f 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -8,8 +8,8 @@ from tianshou.utils.print import DataclassPPrintMixin if TYPE_CHECKING: + from tianshou.algorithm.algorithm_base import TrainingStats from tianshou.data import CollectStats, CollectStatsBase - from tianshou.policy.base import TrainingStats log = logging.getLogger(__name__) @@ -42,6 +42,10 @@ def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "Sequenc min=float(np.min(sequence)), ) + @classmethod + def from_single_value(cls, value: float | int) -> "SequenceSummaryStats": + return cls(mean=value, std=0.0, max=value, min=value) + def compute_dim_to_summary_stats( arr: Sequence[Sequence[float]] | np.ndarray, @@ -79,8 +83,8 @@ class TimingStats(DataclassPPrintMixin): class InfoStats(DataclassPPrintMixin): """A data structure for storing information about the learning process.""" - gradient_step: int - """The total gradient step.""" + update_step: int + """The total number of update steps that have been taken.""" best_score: float """The best score over the test results. The one with the highest score will be considered the best model.""" best_reward: float @@ -107,7 +111,7 @@ class EpochStats(DataclassPPrintMixin): epoch: int """The current epoch.""" - train_collect_stat: "CollectStatsBase" + train_collect_stat: Optional["CollectStatsBase"] """The statistics of the last call to the training collector.""" test_collect_stat: Optional["CollectStats"] """The statistics of the last call to the test collector.""" diff --git a/tianshou/data/types.py b/tianshou/data/types.py index fd2f6d287..b87984ea2 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -4,7 +4,9 @@ import torch from tianshou.data import Batch -from tianshou.data.batch import BatchProtocol, TArr +from tianshou.data.batch import BatchProtocol, TArr, TObsArr + +TObs = TObsArr | BatchProtocol TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"] @@ -15,14 +17,19 @@ class ObsBatchProtocol(BatchProtocol, Protocol): Typically used inside a policy's forward """ - obs: TArr | BatchProtocol - info: TArr | BatchProtocol + obs: TObs + """the observations as generated by the environment in `step`. + If it is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors)""" + info: TArr + """array of info dicts generated by the environment in `step`""" class RolloutBatchProtocol(ObsBatchProtocol, Protocol): """Typically, the outcome of sampling from a replay buffer.""" - obs_next: TArr | BatchProtocol + obs_next: TObs + """the observations after obs as generated by the environment in `step`. 
+ If it is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors)""" act: TArr rew: np.ndarray terminated: TArr @@ -32,13 +39,14 @@ class RolloutBatchProtocol(ObsBatchProtocol, Protocol): class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol): """With added returns, usually computed with GAE.""" - returns: TArr + returns: torch.Tensor class PrioBatchProtocol(RolloutBatchProtocol, Protocol): """Contains weights that can be used for prioritized replay.""" weight: np.ndarray | torch.Tensor + """can be used for prioritized replay.""" class RecurrentStateBatch(BatchProtocol, Protocol): @@ -118,7 +126,7 @@ class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol): taus: torch.Tensor -class ImitationBatchProtocol(ActBatchProtocol, Protocol): +class ImitationBatchProtocol(ModelOutputBatchProtocol, Protocol): """Similar to other batches, but contains `imitation_logits` and `q_value` fields.""" state: dict | Batch | np.ndarray | None diff --git a/examples/atari/atari_network.py b/tianshou/env/atari/atari_network.py similarity index 73% rename from examples/atari/atari_network.py rename to tianshou/env/atari/atari_network.py index 87797f760..79b862c4c 100644 --- a/examples/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -1,10 +1,13 @@ from collections.abc import Callable, Sequence -from typing import Any +from typing import Any, TypeVar import numpy as np import torch from torch import nn +from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrOrCont +from tianshou.data import Batch +from tianshou.data.types import TObs from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ActorFactory from tianshou.highlevel.module.core import ( @@ -15,41 +18,46 @@ IntermediateModuleFactory, ) from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont -from tianshou.utils.net.common import NetBase -from tianshou.utils.net.discrete import Actor, NoisyLinear +from tianshou.utils.net.common import ( + ActionReprNetWithVectorOutput, +) +from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear +from tianshou.utils.torch_utils import torch_device def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module: + """TODO.""" torch.nn.init.orthogonal_(layer.weight, std) torch.nn.init.constant_(layer.bias, bias_const) return layer -class ScaledObsInputModule(torch.nn.Module): - def __init__(self, module: NetBase, denom: float = 255.0) -> None: - super().__init__() +T = TypeVar("T") + + +class ScaledObsInputActionReprNet(ActionReprNetWithVectorOutput): + def __init__(self, module: ActionReprNetWithVectorOutput, denom: float = 255.0) -> None: + super().__init__(module.get_output_dim()) self.module = module self.denom = denom - # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim) - self.output_dim = module.output_dim def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, - info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: + obs: TObs, + state: T | None = None, + info: dict[str, T] | None = None, + ) -> tuple[torch.Tensor | Sequence[torch.Tensor], T | None]: if info is None: info = {} - return self.module.forward(obs / self.denom, state, info) - - -def scale_obs(module: NetBase, denom: float = 255.0) -> ScaledObsInputModule: - return 
ScaledObsInputModule(module, denom=denom) + scaler = lambda arr: arr / self.denom + if isinstance(obs, Batch): + scaled_obs = obs.apply_values_transform(scaler) + else: + scaled_obs = scaler(obs) + return self.module.forward(scaled_obs, state, info) -class DQN(NetBase[Any]): +class DQNet(ActionReprNetWithVectorOutput[Any]): """Reference: Human-level control through deep reinforcement learning. For advanced usage (how to customize the network), please refer to @@ -62,7 +70,6 @@ def __init__( h: int, w: int, action_shape: Sequence[int] | int, - device: str | int | torch.device = "cpu", features_only: bool = False, output_dim_added_layer: int | None = None, layer_init: Callable[[nn.Module], nn.Module] = lambda x: x, @@ -72,9 +79,7 @@ def __init__( raise ValueError( "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.", ) - super().__init__() - self.device = device - self.net = nn.Sequential( + net = nn.Sequential( layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)), nn.ReLU(inplace=True), layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)), @@ -84,39 +89,44 @@ def __init__( nn.Flatten(), ) with torch.no_grad(): - base_cnn_output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])) + base_cnn_output_dim = int(np.prod(net(torch.zeros(1, c, h, w)).shape[1:])) if not features_only: action_dim = int(np.prod(action_shape)) - self.net = nn.Sequential( - self.net, + net = nn.Sequential( + net, layer_init(nn.Linear(base_cnn_output_dim, 512)), nn.ReLU(inplace=True), layer_init(nn.Linear(512, action_dim)), ) - self.output_dim = action_dim + output_dim = action_dim elif output_dim_added_layer is not None: - self.net = nn.Sequential( - self.net, + net = nn.Sequential( + net, layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)), nn.ReLU(inplace=True), ) - self.output_dim = output_dim_added_layer + output_dim = output_dim_added_layer else: - self.output_dim = base_cnn_output_dim + output_dim = base_cnn_output_dim + super().__init__(output_dim) + self.net = net def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, - info: dict[str, Any] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: - r"""Mapping: s -> Q(s, \*).""" - obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) + obs: TObs, + state: T | None = None, + info: dict[str, T] | None = None, + ) -> tuple[torch.Tensor, T | None]: + r"""Mapping: s -> Q(s, \*). + + For more info, see docstring of parent. + """ + device = torch_device(self) + obs = torch.as_tensor(obs, device=device, dtype=torch.float32) return self.net(obs), state -class C51(DQN): +class C51Net(DQNet): """Reference: A distributional perspective on reinforcement learning. 
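# --- Editor's illustrative sketch (not part of this diff) ---------------------
# Building the DQN backbone defined above for a stacked-frame Atari observation
# (4 x 84 x 84) and wrapping it so that raw uint8 pixels are scaled to [0, 1];
# the concrete shape and layer sizes are illustrative.
from tianshou.env.atari.atari_network import DQNet, ScaledObsInputActionReprNet

feature_net = DQNet(
    c=4,
    h=84,
    w=84,
    action_shape=6,
    features_only=True,
    output_dim_added_layer=512,
)
scaled_net = ScaledObsInputActionReprNet(feature_net, denom=255.0)
# ------------------------------------------------------------------------------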
For advanced usage (how to customize the network), please refer to @@ -125,24 +135,23 @@ class C51(DQN): def __init__( self, + *, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51, - device: str | int | torch.device = "cpu", ) -> None: self.action_num = int(np.prod(action_shape)) - super().__init__(c, h, w, [self.action_num * num_atoms], device) + super().__init__(c=c, h=h, w=w, action_shape=[self.action_num * num_atoms]) self.num_atoms = num_atoms def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, - info: dict[str, Any] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: + obs: TObs, + state: T | None = None, + info: dict[str, T] | None = None, + ) -> tuple[torch.Tensor, T | None]: r"""Mapping: x -> Z(x, \*).""" obs, state = super().forward(obs) obs = obs.view(-1, self.num_atoms).softmax(dim=-1) @@ -150,7 +159,7 @@ def forward( return obs, state -class Rainbow(DQN): +class RainbowNet(DQNet): """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning. For advanced usage (how to customize the network), please refer to @@ -159,17 +168,17 @@ class Rainbow(DQN): def __init__( self, + *, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51, noisy_std: float = 0.5, - device: str | int | torch.device = "cpu", is_dueling: bool = True, is_noisy: bool = True, ) -> None: - super().__init__(c, h, w, action_shape, device, features_only=True) + super().__init__(c=c, h=h, w=w, action_shape=action_shape, features_only=True) self.action_num = int(np.prod(action_shape)) self.num_atoms = num_atoms @@ -194,12 +203,10 @@ def linear(x: int, y: int) -> NoisyLinear | nn.Linear: def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: - r"""Mapping: x -> Z(x, \*).""" + ) -> tuple[torch.Tensor, T | None]: obs, state = super().forward(obs) q = self.Q(obs) q = q.view(-1, self.action_num, self.num_atoms) @@ -213,7 +220,7 @@ def forward( return probs, state -class QRDQN(DQN): +class QRDQNet(DQNet): """Reference: Distributional Reinforcement Learning with Quantile Regression. 
For advanced usage (how to customize the network), please refer to @@ -228,20 +235,17 @@ def __init__( w: int, action_shape: Sequence[int] | int, num_quantiles: int = 200, - device: str | int | torch.device = "cpu", ) -> None: self.action_num = int(np.prod(action_shape)) - super().__init__(c, h, w, [self.action_num * num_quantiles], device) + super().__init__(c=c, h=h, w=w, action_shape=[self.action_num * num_quantiles]) self.num_quantiles = num_quantiles def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: - r"""Mapping: x -> Z(x, \*).""" + ) -> tuple[torch.Tensor, T | None]: obs, state = super().forward(obs) obs = obs.view(-1, self.action_num, self.num_quantiles) return obs, state @@ -260,28 +264,26 @@ def __init__( self.scale_obs = scale_obs self.features_only = features_only - def create_module(self, envs: Environments, device: TDevice) -> Actor: + def create_module(self, envs: Environments, device: TDevice) -> DiscreteActor: c, h, w = envs.get_observation_shape() # type: ignore # only right shape is a sequence of length 3 action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): action_shape = int(action_shape) - net: DQN | ScaledObsInputModule - net = DQN( + net: DQNet | ScaledObsInputActionReprNet + net = DQNet( c=c, h=h, w=w, action_shape=action_shape, - device=device, features_only=self.features_only, output_dim_added_layer=self.output_dim_added_layer, layer_init=layer_init, ) if self.scale_obs: - net = scale_obs(net) - return Actor( - net, - envs.get_action_shape(), - device=device, + net = ScaledObsInputActionReprNet(net) + return DiscreteActor( + preprocess_net=net, + action_shape=envs.get_action_shape(), softmax_output=self.USE_SOFTMAX_OUTPUT, ).to(device) @@ -305,12 +307,11 @@ def create_intermediate_module(self, envs: Environments, device: TDevice) -> Int action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): action_shape = int(action_shape) - dqn = DQN( + dqn = DQNet( c=c, h=h, w=w, action_shape=action_shape, - device=device, features_only=self.features_only, ).to(device) module = dqn.net if self.net_only else dqn diff --git a/examples/atari/atari_wrapper.py b/tianshou/env/atari/atari_wrapper.py similarity index 98% rename from examples/atari/atari_wrapper.py rename to tianshou/env/atari/atari_wrapper.py index 25a3b09f2..7a8d3effb 100644 --- a/examples/atari/atari_wrapper.py +++ b/tianshou/env/atari/atari_wrapper.py @@ -40,6 +40,7 @@ def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]: def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]: + """TODO.""" obs_space_dtype: type[np.integer] | type[np.floating] if np.issubdtype(obs_space.dtype, np.integer): obs_space_dtype = np.integer @@ -351,9 +352,15 @@ def wrap_deepmind( env = MaxAndSkipEnv(env, skip=4) assert hasattr(env.unwrapped, "get_action_meanings") # for mypy - wrapped_env: MaxAndSkipEnv | EpisodicLifeEnv | FireResetEnv | WarpFrame | ScaledFloatFrame | ClipRewardEnv | FrameStack = ( - env - ) + wrapped_env: ( + MaxAndSkipEnv + | EpisodicLifeEnv + | FireResetEnv + | WarpFrame + | ScaledFloatFrame + | ClipRewardEnv + | FrameStack + ) = env if episode_life: wrapped_env = EpisodicLifeEnv(wrapped_env) if "FIRE" in env.unwrapped.get_action_meanings(): @@ -372,8 +379,8 @@ def wrap_deepmind( def make_atari_env( task: str, seed: int, - training_num: int, - test_num: int, + 
num_train_envs: int, + num_test_envs: int, scale: int | bool = False, frame_stack: int = 4, ) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]: @@ -384,7 +391,7 @@ def make_atari_env( :return: a tuple of (single env, training envs, test envs). """ env_factory = AtariEnvFactory(task, frame_stack, scale=bool(scale)) - envs = env_factory.create_envs(training_num, test_num, seed=seed) + envs = env_factory.create_envs(num_train_envs, num_test_envs, seed=seed) return envs.env, envs.train_envs, envs.test_envs diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 196741453..2e8287ba9 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -18,7 +18,7 @@ class PettingZooEnv(AECEnv, ABC): - """The interface for petting zoo environments. + """The interface for petting zoo environments which support multi-agent RL. Multi-agent environments must be wrapped as :class:`~tianshou.env.PettingZooEnv`. Here is the usage: diff --git a/tianshou/env/worker/__init__.py b/tianshou/env/worker/__init__.py index 1b1f37510..5e3d21235 100644 --- a/tianshou/env/worker/__init__.py +++ b/tianshou/env/worker/__init__.py @@ -1,4 +1,6 @@ -from tianshou.env.worker.base import EnvWorker +# isort:skip_file +# NOTE: Import order is important to avoid circular import errors! +from tianshou.env.worker.worker_base import EnvWorker from tianshou.env.worker.dummy import DummyEnvWorker from tianshou.env.worker.ray import RayEnvWorker from tianshou.env.worker.subproc import SubprocEnvWorker diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/worker_base.py similarity index 100% rename from tianshou/env/worker/base.py rename to tianshou/env/worker/worker_base.py diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py index dc6205cc6..45fc450f2 100644 --- a/tianshou/evaluation/rliable_evaluation_hl.py +++ b/tianshou/evaluation/rliable_evaluation_hl.py @@ -15,7 +15,7 @@ from tianshou.highlevel.experiment import Experiment from tianshou.utils import TensorboardLogger -from tianshou.utils.logger.base import DataScope +from tianshou.utils.logger.logger_base import DataScope log = logging.getLogger(__name__) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py deleted file mode 100644 index a023f0190..000000000 --- a/tianshou/highlevel/agent.py +++ /dev/null @@ -1,646 +0,0 @@ -import logging -import typing -from abc import ABC, abstractmethod -from typing import Any, Generic, TypeVar, cast - -import gymnasium -from sensai.util.string import ToStringMixin - -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer -from tianshou.data.collector import BaseCollector, CollectStats -from tianshou.highlevel.config import SamplingConfig -from tianshou.highlevel.env import Environments -from tianshou.highlevel.module.actor import ( - ActorFactory, -) -from tianshou.highlevel.module.core import ( - ModuleFactory, - TDevice, -) -from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory -from tianshou.highlevel.module.module_opt import ( - ActorCriticOpt, -) -from tianshou.highlevel.optim import OptimizerFactory -from tianshou.highlevel.params.policy_params import ( - A2CParams, - DDPGParams, - DiscreteSACParams, - DQNParams, - IQNParams, - NPGParams, - Params, - ParamsMixinActorAndDualCritics, - ParamsMixinLearningRateWithScheduler, - ParamTransformerData, - PGParams, - PPOParams, - REDQParams, - SACParams, - TD3Params, - TRPOParams, -) -from tianshou.highlevel.params.policy_wrapper 
import PolicyWrapperFactory -from tianshou.highlevel.persistence import PolicyPersistence -from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext -from tianshou.highlevel.world import World -from tianshou.policy import ( - A2CPolicy, - BasePolicy, - DDPGPolicy, - DiscreteSACPolicy, - DQNPolicy, - IQNPolicy, - NPGPolicy, - PGPolicy, - PPOPolicy, - REDQPolicy, - SACPolicy, - TD3Policy, - TRPOPolicy, -) -from tianshou.policy.base import RandomActionPolicy -from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer -from tianshou.utils.net.common import ActorCritic - -CHECKPOINT_DICT_KEY_MODEL = "model" -CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" -TParams = TypeVar("TParams", bound=Params) -TActorCriticParams = TypeVar( - "TActorCriticParams", - bound=Params | ParamsMixinLearningRateWithScheduler, -) -TActorDualCriticsParams = TypeVar( - "TActorDualCriticsParams", - bound=Params | ParamsMixinActorAndDualCritics, -) -TDiscreteCriticOnlyParams = TypeVar( - "TDiscreteCriticOnlyParams", - bound=Params | ParamsMixinLearningRateWithScheduler, -) -TPolicy = TypeVar("TPolicy", bound=BasePolicy) -log = logging.getLogger(__name__) - - -class AgentFactory(ABC, ToStringMixin): - """Factory for the creation of an agent's policy, its trainer as well as collectors.""" - - def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory): - self.sampling_config = sampling_config - self.optim_factory = optim_factory - self.policy_wrapper_factory: PolicyWrapperFactory | None = None - self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks() - - def create_train_test_collector( - self, - policy: BasePolicy, - envs: Environments, - reset_collectors: bool = True, - ) -> tuple[BaseCollector, BaseCollector]: - """:param policy: - :param envs: - :param reset_collectors: Whether to reset the collectors before returning them. - Setting to True means that the envs will be reset as well. 
- :return: - """ - buffer_size = self.sampling_config.buffer_size - train_envs = envs.train_envs - buffer: ReplayBuffer - if len(train_envs) > 1: - buffer = VectorReplayBuffer( - buffer_size, - len(train_envs), - stack_num=self.sampling_config.replay_buffer_stack_num, - save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, - ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, - ) - else: - buffer = ReplayBuffer( - buffer_size, - stack_num=self.sampling_config.replay_buffer_stack_num, - save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, - ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, - ) - train_collector = Collector[CollectStats]( - policy, - train_envs, - buffer, - exploration_noise=True, - ) - test_collector = Collector[CollectStats](policy, envs.test_envs) - if reset_collectors: - train_collector.reset() - test_collector.reset() - return train_collector, test_collector - - def set_policy_wrapper_factory( - self, - policy_wrapper_factory: PolicyWrapperFactory | None, - ) -> None: - self.policy_wrapper_factory = policy_wrapper_factory - - def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: - self.trainer_callbacks = callbacks - - @abstractmethod - def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: - pass - - def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: - policy = self._create_policy(envs, device) - if self.policy_wrapper_factory is not None: - policy = self.policy_wrapper_factory.create_wrapped_policy( - policy, - envs, - self.optim_factory, - device, - ) - return policy - - @abstractmethod - def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> BaseTrainer: - pass - - -class OnPolicyAgentFactory(AgentFactory, ABC): - def create_trainer( - self, - world: World, - policy_persistence: PolicyPersistence, - ) -> OnpolicyTrainer: - sampling_config = self.sampling_config - callbacks = self.trainer_callbacks - context = TrainingContext(world.policy, world.envs, world.logger) - train_fn = ( - callbacks.epoch_train_callback.get_trainer_fn(context) - if callbacks.epoch_train_callback - else None - ) - test_fn = ( - callbacks.epoch_test_callback.get_trainer_fn(context) - if callbacks.epoch_test_callback - else None - ) - stop_fn = ( - callbacks.epoch_stop_callback.get_trainer_fn(context) - if callbacks.epoch_stop_callback - else None - ) - return OnpolicyTrainer( - policy=world.policy, - train_collector=world.train_collector, - test_collector=world.test_collector, - max_epoch=sampling_config.num_epochs, - step_per_epoch=sampling_config.step_per_epoch, - repeat_per_collect=sampling_config.repeat_per_collect, - episode_per_test=sampling_config.num_test_episodes, - batch_size=sampling_config.batch_size, - step_per_collect=sampling_config.step_per_collect, - save_best_fn=policy_persistence.get_save_best_fn(world), - save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world), - logger=world.logger, - test_in_train=False, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - verbose=False, - ) - - -class OffPolicyAgentFactory(AgentFactory, ABC): - def create_trainer( - self, - world: World, - policy_persistence: PolicyPersistence, - ) -> OffpolicyTrainer: - sampling_config = self.sampling_config - callbacks = self.trainer_callbacks - context = TrainingContext(world.policy, world.envs, world.logger) - train_fn = ( - callbacks.epoch_train_callback.get_trainer_fn(context) - if callbacks.epoch_train_callback - 
else None - ) - test_fn = ( - callbacks.epoch_test_callback.get_trainer_fn(context) - if callbacks.epoch_test_callback - else None - ) - stop_fn = ( - callbacks.epoch_stop_callback.get_trainer_fn(context) - if callbacks.epoch_stop_callback - else None - ) - return OffpolicyTrainer( - policy=world.policy, - train_collector=world.train_collector, - test_collector=world.test_collector, - max_epoch=sampling_config.num_epochs, - step_per_epoch=sampling_config.step_per_epoch, - step_per_collect=sampling_config.step_per_collect, - episode_per_test=sampling_config.num_test_episodes, - batch_size=sampling_config.batch_size, - save_best_fn=policy_persistence.get_save_best_fn(world), - logger=world.logger, - update_per_step=sampling_config.update_per_step, - test_in_train=False, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - verbose=False, - ) - - -class RandomActionAgentFactory(OnPolicyAgentFactory): - def _create_policy(self, envs: Environments, device: TDevice) -> RandomActionPolicy: - return RandomActionPolicy(envs.get_action_space()) - - -class PGAgentFactory(OnPolicyAgentFactory): - def __init__( - self, - params: PGParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - optim_factory: OptimizerFactory, - ): - super().__init__(sampling_config, optim_factory) - self.params = params - self.actor_factory = actor_factory - self.optim_factory = optim_factory - - def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy: - actor = self.actor_factory.create_module_opt( - envs, - device, - self.optim_factory, - self.params.lr, - ) - kwargs = self.params.create_kwargs( - ParamTransformerData( - envs=envs, - device=device, - optim=actor.optim, - optim_factory=self.optim_factory, - ), - ) - dist_fn = self.actor_factory.create_dist_fn(envs) - assert dist_fn is not None - return PGPolicy( - actor=actor.module, - optim=actor.optim, - action_space=envs.get_action_space(), - observation_space=envs.get_observation_space(), - dist_fn=dist_fn, - **kwargs, - ) - - -class ActorCriticAgentFactory( - Generic[TActorCriticParams, TPolicy], - OnPolicyAgentFactory, - ABC, -): - def __init__( - self, - params: TActorCriticParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic_factory: CriticFactory, - optimizer_factory: OptimizerFactory, - ): - super().__init__(sampling_config, optim_factory=optimizer_factory) - self.params = params - self.actor_factory = actor_factory - self.critic_factory = critic_factory - self.optim_factory = optimizer_factory - self.critic_use_action = False - - @abstractmethod - def _get_policy_class(self) -> type[TPolicy]: - pass - - @abstractmethod - def _include_actor_in_optim(self) -> bool: - pass - - def create_actor_critic_module_opt( - self, - envs: Environments, - device: TDevice, - lr: float, - ) -> ActorCriticOpt: - actor = self.actor_factory.create_module(envs, device) - critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) - actor_critic = ActorCritic(actor, critic) - if self._include_actor_in_optim(): - optim = self.optim_factory.create_optimizer(actor_critic, lr) - else: - optim = self.optim_factory.create_optimizer(critic, lr) - return ActorCriticOpt(actor_critic, optim) - - @typing.no_type_check - def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: - actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr) - kwargs = self.params.create_kwargs( - ParamTransformerData( - envs=envs, - device=device, - 
optim_factory=self.optim_factory, - optim=actor_critic.optim, - ), - ) - kwargs["actor"] = actor_critic.actor - kwargs["critic"] = actor_critic.critic - kwargs["optim"] = actor_critic.optim - kwargs["action_space"] = envs.get_action_space() - kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs) - return kwargs - - def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: - policy_class = self._get_policy_class() - return policy_class(**self._create_kwargs(envs, device)) - - -class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): - def _include_actor_in_optim(self) -> bool: - return True - - def _get_policy_class(self) -> type[A2CPolicy]: - return A2CPolicy - - -class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): - def _include_actor_in_optim(self) -> bool: - return True - - def _get_policy_class(self) -> type[PPOPolicy]: - return PPOPolicy - - -class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]): - def _include_actor_in_optim(self) -> bool: - return False - - def _get_policy_class(self) -> type[NPGPolicy]: - return NPGPolicy - - -class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): - def _include_actor_in_optim(self) -> bool: - return False - - def _get_policy_class(self) -> type[TRPOPolicy]: - return TRPOPolicy - - -class DiscreteCriticOnlyAgentFactory( - OffPolicyAgentFactory, - Generic[TDiscreteCriticOnlyParams, TPolicy], -): - def __init__( - self, - params: TDiscreteCriticOnlyParams, - sampling_config: SamplingConfig, - model_factory: ModuleFactory, - optim_factory: OptimizerFactory, - ): - super().__init__(sampling_config, optim_factory) - self.params = params - self.model_factory = model_factory - self.optim_factory = optim_factory - - @abstractmethod - def _get_policy_class(self) -> type[TPolicy]: - pass - - @typing.no_type_check - def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: - model = self.model_factory.create_module(envs, device) - optim = self.optim_factory.create_optimizer(model, self.params.lr) - kwargs = self.params.create_kwargs( - ParamTransformerData( - envs=envs, - device=device, - optim=optim, - optim_factory=self.optim_factory, - ), - ) - envs.get_type().assert_discrete(self) - action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space()) - policy_class = self._get_policy_class() - return policy_class( - model=model, - optim=optim, - action_space=action_space, - observation_space=envs.get_observation_space(), - **kwargs, - ) - - -class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DQNPolicy]): - def _get_policy_class(self) -> type[DQNPolicy]: - return DQNPolicy - - -class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]): - def _get_policy_class(self) -> type[IQNPolicy]: - return IQNPolicy - - -class DDPGAgentFactory(OffPolicyAgentFactory): - def __init__( - self, - params: DDPGParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic_factory: CriticFactory, - optim_factory: OptimizerFactory, - ): - super().__init__(sampling_config, optim_factory) - self.critic_factory = critic_factory - self.actor_factory = actor_factory - self.params = params - self.optim_factory = optim_factory - - def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: - actor = self.actor_factory.create_module_opt( - envs, - device, - self.optim_factory, - self.params.actor_lr, - ) - critic = self.critic_factory.create_module_opt( - envs, - device, - True, - self.optim_factory, - 
self.params.critic_lr, - ) - kwargs = self.params.create_kwargs( - ParamTransformerData( - envs=envs, - device=device, - optim_factory=self.optim_factory, - actor=actor, - critic1=critic, - ), - ) - return DDPGPolicy( - actor=actor.module, - actor_optim=actor.optim, - critic=critic.module, - critic_optim=critic.optim, - action_space=envs.get_action_space(), - observation_space=envs.get_observation_space(), - **kwargs, - ) - - -class REDQAgentFactory(OffPolicyAgentFactory): - def __init__( - self, - params: REDQParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic_ensemble_factory: CriticEnsembleFactory, - optim_factory: OptimizerFactory, - ): - super().__init__(sampling_config, optim_factory) - self.critic_ensemble_factory = critic_ensemble_factory - self.actor_factory = actor_factory - self.params = params - self.optim_factory = optim_factory - - def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: - envs.get_type().assert_continuous(self) - actor = self.actor_factory.create_module_opt( - envs, - device, - self.optim_factory, - self.params.actor_lr, - ) - critic_ensemble = self.critic_ensemble_factory.create_module_opt( - envs, - device, - self.params.ensemble_size, - True, - self.optim_factory, - self.params.critic_lr, - ) - kwargs = self.params.create_kwargs( - ParamTransformerData( - envs=envs, - device=device, - optim_factory=self.optim_factory, - actor=actor, - critic1=critic_ensemble, - ), - ) - action_space = cast(gymnasium.spaces.Box, envs.get_action_space()) - return REDQPolicy( - actor=actor.module, - actor_optim=actor.optim, - critic=critic_ensemble.module, - critic_optim=critic_ensemble.optim, - action_space=action_space, - observation_space=envs.get_observation_space(), - **kwargs, - ) - - -class ActorDualCriticsAgentFactory( - OffPolicyAgentFactory, - Generic[TActorDualCriticsParams, TPolicy], - ABC, -): - def __init__( - self, - params: TActorDualCriticsParams, - sampling_config: SamplingConfig, - actor_factory: ActorFactory, - critic1_factory: CriticFactory, - critic2_factory: CriticFactory, - optim_factory: OptimizerFactory, - ): - super().__init__(sampling_config, optim_factory) - self.params = params - self.actor_factory = actor_factory - self.critic1_factory = critic1_factory - self.critic2_factory = critic2_factory - self.optim_factory = optim_factory - - @abstractmethod - def _get_policy_class(self) -> type[TPolicy]: - pass - - def _get_discrete_last_size_use_action_shape(self) -> bool: - return True - - @staticmethod - def _get_critic_use_action(envs: Environments) -> bool: - return envs.get_type().is_continuous() - - @typing.no_type_check - def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: - actor = self.actor_factory.create_module_opt( - envs, - device, - self.optim_factory, - self.params.actor_lr, - ) - use_action_shape = self._get_discrete_last_size_use_action_shape() - critic_use_action = self._get_critic_use_action(envs) - critic1 = self.critic1_factory.create_module_opt( - envs, - device, - critic_use_action, - self.optim_factory, - self.params.critic1_lr, - discrete_last_size_use_action_shape=use_action_shape, - ) - critic2 = self.critic2_factory.create_module_opt( - envs, - device, - critic_use_action, - self.optim_factory, - self.params.critic2_lr, - discrete_last_size_use_action_shape=use_action_shape, - ) - kwargs = self.params.create_kwargs( - ParamTransformerData( - envs=envs, - device=device, - optim_factory=self.optim_factory, - actor=actor, - critic1=critic1, - 
critic2=critic2, - ), - ) - policy_class = self._get_policy_class() - return policy_class( - actor=actor.module, - actor_optim=actor.optim, - critic=critic1.module, - critic_optim=critic1.optim, - critic2=critic2.module, - critic2_optim=critic2.optim, - action_space=envs.get_action_space(), - observation_space=envs.get_observation_space(), - **kwargs, - ) - - -class SACAgentFactory(ActorDualCriticsAgentFactory[SACParams, SACPolicy]): - def _get_policy_class(self) -> type[SACPolicy]: - return SACPolicy - - -class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]): - def _get_policy_class(self) -> type[DiscreteSACPolicy]: - return DiscreteSACPolicy - - -class TD3AgentFactory(ActorDualCriticsAgentFactory[TD3Params, TD3Policy]): - def _get_policy_class(self) -> type[TD3Policy]: - return TD3Policy diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py new file mode 100644 index 000000000..deaa2cd35 --- /dev/null +++ b/tianshou/highlevel/algorithm.py @@ -0,0 +1,723 @@ +import logging +import typing +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar, cast + +import gymnasium +import torch +from sensai.util.string import ToStringMixin + +from tianshou.algorithm import ( + A2C, + DDPG, + DQN, + IQN, + NPG, + PPO, + REDQ, + SAC, + TD3, + TRPO, + Algorithm, + DiscreteSAC, + Reinforce, +) +from tianshou.algorithm.algorithm_base import ( + OffPolicyAlgorithm, + OnPolicyAlgorithm, + Policy, +) +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.modelfree.iqn import IQNPolicy +from tianshou.algorithm.modelfree.redq import REDQPolicy +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy +from tianshou.algorithm.modelfree.sac import SACPolicy +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data.collector import BaseCollector, CollectStats +from tianshou.highlevel.config import ( + OffPolicyTrainingConfig, + OnPolicyTrainingConfig, + TrainingConfig, +) +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.actor import ( + ActorFactory, +) +from tianshou.highlevel.module.core import ( + ModuleFactory, + TDevice, +) +from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory +from tianshou.highlevel.params.algorithm_params import ( + A2CParams, + DDPGParams, + DiscreteSACParams, + DQNParams, + IQNParams, + NPGParams, + Params, + ParamsMixinActorAndDualCritics, + ParamsMixinSingleModel, + ParamTransformerData, + PPOParams, + REDQParams, + ReinforceParams, + SACParams, + TD3Params, + TRPOParams, +) +from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory +from tianshou.highlevel.params.optim import OptimizerFactoryFactory +from tianshou.highlevel.persistence import PolicyPersistence +from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext +from tianshou.highlevel.world import World +from tianshou.trainer import ( + OffPolicyTrainer, + OffPolicyTrainerParams, + OnPolicyTrainer, + OnPolicyTrainerParams, + Trainer, +) +from tianshou.utils.net.discrete import DiscreteActor + +CHECKPOINT_DICT_KEY_MODEL = "model" +CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" +TParams = TypeVar("TParams", bound=Params) +TActorCriticParams = TypeVar( + "TActorCriticParams", + bound=Params | 
ParamsMixinSingleModel, +) +TActorDualCriticsParams = TypeVar( + "TActorDualCriticsParams", + bound=Params | ParamsMixinActorAndDualCritics, +) +TDiscreteCriticOnlyParams = TypeVar( + "TDiscreteCriticOnlyParams", + bound=Params | ParamsMixinSingleModel, +) +TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) +TPolicy = TypeVar("TPolicy", bound=Policy) +TTrainingConfig = TypeVar("TTrainingConfig", bound=TrainingConfig) +log = logging.getLogger(__name__) + + +class AlgorithmFactory(ABC, ToStringMixin, Generic[TTrainingConfig]): + """Factory for the creation of an :class:`Algorithm` instance, its policy, trainer as well as collectors.""" + + def __init__(self, training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory): + self.training_config = training_config + self.optim_factory = optim_factory + self.algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None + self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks() + + def create_train_test_collector( + self, + policy: Algorithm, + envs: Environments, + reset_collectors: bool = True, + ) -> tuple[BaseCollector, BaseCollector]: + """:param policy: + :param envs: + :param reset_collectors: Whether to reset the collectors before returning them. + Setting to True means that the envs will be reset as well. + :return: + """ + buffer_size = self.training_config.buffer_size + train_envs = envs.train_envs + buffer: ReplayBuffer + if len(train_envs) > 1: + buffer = VectorReplayBuffer( + buffer_size, + len(train_envs), + stack_num=self.training_config.replay_buffer_stack_num, + save_only_last_obs=self.training_config.replay_buffer_save_only_last_obs, + ignore_obs_next=self.training_config.replay_buffer_ignore_obs_next, + ) + else: + buffer = ReplayBuffer( + buffer_size, + stack_num=self.training_config.replay_buffer_stack_num, + save_only_last_obs=self.training_config.replay_buffer_save_only_last_obs, + ignore_obs_next=self.training_config.replay_buffer_ignore_obs_next, + ) + train_collector = Collector[CollectStats]( + policy, + train_envs, + buffer, + exploration_noise=True, + ) + test_collector = Collector[CollectStats](policy, envs.test_envs) + if reset_collectors: + train_collector.reset() + test_collector.reset() + return train_collector, test_collector + + def set_policy_wrapper_factory( + self, + policy_wrapper_factory: AlgorithmWrapperFactory | None, + ) -> None: + self.algorithm_wrapper_factory = policy_wrapper_factory + + def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: + self.trainer_callbacks = callbacks + + @staticmethod + def _create_policy_from_args( + constructor: type[TPolicy], params_dict: dict, policy_params: list[str], **kwargs: Any + ) -> TPolicy: + params = {p: params_dict.pop(p) for p in policy_params} + return constructor(**params, **kwargs) + + @abstractmethod + def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: + pass + + def create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: + algorithm = self._create_algorithm(envs, device) + if self.algorithm_wrapper_factory is not None: + algorithm = self.algorithm_wrapper_factory.create_wrapped_algorithm( + algorithm, + envs, + self.optim_factory, + device, + ) + return algorithm + + @abstractmethod + def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> Trainer: + pass + + +class OnPolicyAlgorithmFactory(AlgorithmFactory[OnPolicyTrainingConfig], ABC): + def create_trainer( + self, + world: World, + policy_persistence: PolicyPersistence, + ) -> OnPolicyTrainer: + 
training_config = self.training_config + callbacks = self.trainer_callbacks + context = TrainingContext(world.algorithm, world.envs, world.logger) + train_fn = ( + callbacks.epoch_train_callback.get_trainer_fn(context) + if callbacks.epoch_train_callback + else None + ) + test_fn = ( + callbacks.epoch_test_callback.get_trainer_fn(context) + if callbacks.epoch_test_callback + else None + ) + stop_fn = ( + callbacks.epoch_stop_callback.get_trainer_fn(context) + if callbacks.epoch_stop_callback + else None + ) + algorithm = cast(OnPolicyAlgorithm, world.algorithm) + assert world.train_collector is not None + return algorithm.create_trainer( + OnPolicyTrainerParams( + train_collector=world.train_collector, + test_collector=world.test_collector, + max_epochs=training_config.max_epochs, + epoch_num_steps=training_config.epoch_num_steps, + update_step_num_repetitions=training_config.update_step_num_repetitions, + test_step_num_episodes=training_config.test_step_num_episodes, + batch_size=training_config.batch_size, + collection_step_num_env_steps=training_config.collection_step_num_env_steps, + save_best_fn=policy_persistence.get_save_best_fn(world), + save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world), + logger=world.logger, + test_in_train=training_config.test_in_train, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + verbose=False, + ) + ) + + +class OffPolicyAlgorithmFactory(AlgorithmFactory[OffPolicyTrainingConfig], ABC): + def create_trainer( + self, + world: World, + policy_persistence: PolicyPersistence, + ) -> OffPolicyTrainer: + training_config = self.training_config + callbacks = self.trainer_callbacks + context = TrainingContext(world.algorithm, world.envs, world.logger) + train_fn = ( + callbacks.epoch_train_callback.get_trainer_fn(context) + if callbacks.epoch_train_callback + else None + ) + test_fn = ( + callbacks.epoch_test_callback.get_trainer_fn(context) + if callbacks.epoch_test_callback + else None + ) + stop_fn = ( + callbacks.epoch_stop_callback.get_trainer_fn(context) + if callbacks.epoch_stop_callback + else None + ) + algorithm = cast(OffPolicyAlgorithm, world.algorithm) + assert world.train_collector is not None + return algorithm.create_trainer( + OffPolicyTrainerParams( + train_collector=world.train_collector, + test_collector=world.test_collector, + max_epochs=training_config.max_epochs, + epoch_num_steps=training_config.epoch_num_steps, + collection_step_num_env_steps=training_config.collection_step_num_env_steps, + test_step_num_episodes=training_config.test_step_num_episodes, + batch_size=training_config.batch_size, + save_best_fn=policy_persistence.get_save_best_fn(world), + logger=world.logger, + update_step_num_gradient_steps_per_sample=training_config.update_step_num_gradient_steps_per_sample, + test_in_train=training_config.test_in_train, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + verbose=False, + ) + ) + + +class ReinforceAlgorithmFactory(OnPolicyAlgorithmFactory): + def __init__( + self, + params: ReinforceParams, + training_config: OnPolicyTrainingConfig, + actor_factory: ActorFactory, + optim_factory: OptimizerFactoryFactory, + ): + super().__init__(training_config, optim_factory) + self.params = params + self.actor_factory = actor_factory + self.optim_factory = optim_factory + + def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce: + actor = self.actor_factory.create_module(envs, device) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + 
optim_factory_default=self.optim_factory, + ), + ) + dist_fn = self.actor_factory.create_dist_fn(envs) + assert dist_fn is not None + policy = self._create_policy_from_args( + ProbabilisticActorPolicy, + kwargs, + ["action_scaling", "action_bound_method", "deterministic_eval"], + actor=actor, + dist_fn=dist_fn, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + ) + return Reinforce( + policy=policy, + **kwargs, + ) + + +class ActorCriticOnPolicyAlgorithmFactory( + OnPolicyAlgorithmFactory, + Generic[TActorCriticParams, TAlgorithm], +): + def __init__( + self, + params: TActorCriticParams, + training_config: OnPolicyTrainingConfig, + actor_factory: ActorFactory, + critic_factory: CriticFactory, + optimizer_factory: OptimizerFactoryFactory, + ): + super().__init__(training_config, optim_factory=optimizer_factory) + self.params = params + self.actor_factory = actor_factory + self.critic_factory = critic_factory + self.optim_factory = optimizer_factory + self.critic_use_action = False + + @abstractmethod + def _get_algorithm_class(self) -> type[TAlgorithm]: + pass + + @typing.no_type_check + def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: + actor = self.actor_factory.create_module(envs, device) + critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory_default=self.optim_factory, + ), + ) + kwargs["actor"] = actor + kwargs["critic"] = critic + kwargs["action_space"] = envs.get_action_space() + kwargs["observation_space"] = envs.get_observation_space() + kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs) + return kwargs + + def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: + params = self._create_kwargs(envs, device) + policy = self._create_policy_from_args( + ProbabilisticActorPolicy, + params, + [ + "actor", + "dist_fn", + "action_space", + "deterministic_eval", + "observation_space", + "action_scaling", + "action_bound_method", + ], + ) + algorithm_class = self._get_algorithm_class() + return algorithm_class(policy=policy, **params) + + +class A2CAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[A2CParams, A2C]): + def _get_algorithm_class(self) -> type[A2C]: + return A2C + + +class PPOAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[PPOParams, PPO]): + def _get_algorithm_class(self) -> type[PPO]: + return PPO + + +class NPGAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[NPGParams, NPG]): + def _get_algorithm_class(self) -> type[NPG]: + return NPG + + +class TRPOAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[TRPOParams, TRPO]): + def _get_algorithm_class(self) -> type[TRPO]: + return TRPO + + +class DiscreteCriticOnlyOffPolicyAlgorithmFactory( + OffPolicyAlgorithmFactory, + Generic[TDiscreteCriticOnlyParams, TAlgorithm], +): + def __init__( + self, + params: TDiscreteCriticOnlyParams, + training_config: OffPolicyTrainingConfig, + model_factory: ModuleFactory, + optim_factory: OptimizerFactoryFactory, + ): + super().__init__(training_config, optim_factory) + self.params = params + self.model_factory = model_factory + self.optim_factory = optim_factory + + @abstractmethod + def _get_algorithm_class(self) -> type[TAlgorithm]: + pass + + @abstractmethod + def _create_policy( + self, + model: torch.nn.Module, + params: dict, + action_space: gymnasium.spaces.Discrete, + observation_space: gymnasium.spaces.Space, + ) -> Policy: + 
pass + + @typing.no_type_check + def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: + model = self.model_factory.create_module(envs, device) + params_dict = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory_default=self.optim_factory, + ), + ) + envs.get_type().assert_discrete(self) + action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space()) + policy = self._create_policy(model, params_dict, action_space, envs.get_observation_space()) + algorithm_class = self._get_algorithm_class() + return algorithm_class( + policy=policy, + **params_dict, + ) + + +class DQNAlgorithmFactory(DiscreteCriticOnlyOffPolicyAlgorithmFactory[DQNParams, DQN]): + def _create_policy( + self, + model: torch.nn.Module, + params: dict, + action_space: gymnasium.spaces.Discrete, + observation_space: gymnasium.spaces.Space, + ) -> Policy: + return self._create_policy_from_args( + constructor=DiscreteQLearningPolicy, + params_dict=params, + policy_params=["eps_training", "eps_inference"], + model=model, + action_space=action_space, + observation_space=observation_space, + ) + + def _get_algorithm_class(self) -> type[DQN]: + return DQN + + +class IQNAlgorithmFactory(DiscreteCriticOnlyOffPolicyAlgorithmFactory[IQNParams, IQN]): + def _create_policy( + self, + model: torch.nn.Module, + params: dict, + action_space: gymnasium.spaces.Discrete, + observation_space: gymnasium.spaces.Space, + ) -> Policy: + pass + return self._create_policy_from_args( + IQNPolicy, + params, + [ + "sample_size", + "online_sample_size", + "target_sample_size", + "eps_training", + "eps_inference", + ], + model=model, + action_space=action_space, + observation_space=observation_space, + ) + + def _get_algorithm_class(self) -> type[IQN]: + return IQN + + +class DDPGAlgorithmFactory(OffPolicyAlgorithmFactory): + def __init__( + self, + params: DDPGParams, + training_config: OffPolicyTrainingConfig, + actor_factory: ActorFactory, + critic_factory: CriticFactory, + optim_factory: OptimizerFactoryFactory, + ): + super().__init__(training_config, optim_factory) + self.critic_factory = critic_factory + self.actor_factory = actor_factory + self.params = params + self.optim_factory = optim_factory + + def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: + actor = self.actor_factory.create_module(envs, device) + critic = self.critic_factory.create_module( + envs, + device, + True, + ) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory_default=self.optim_factory, + ), + ) + policy = self._create_policy_from_args( + ContinuousDeterministicPolicy, + kwargs, + ["exploration_noise", "action_scaling", "action_bound_method"], + actor=actor, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + ) + return DDPG( + policy=policy, + critic=critic, + **kwargs, + ) + + +class REDQAlgorithmFactory(OffPolicyAlgorithmFactory): + def __init__( + self, + params: REDQParams, + training_config: OffPolicyTrainingConfig, + actor_factory: ActorFactory, + critic_ensemble_factory: CriticEnsembleFactory, + optim_factory: OptimizerFactoryFactory, + ): + super().__init__(training_config, optim_factory) + self.critic_ensemble_factory = critic_ensemble_factory + self.actor_factory = actor_factory + self.params = params + self.optim_factory = optim_factory + + def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: + envs.get_type().assert_continuous(self) + 
actor = self.actor_factory.create_module( + envs, + device, + ) + critic_ensemble = self.critic_ensemble_factory.create_module( + envs, + device, + self.params.ensemble_size, + True, + ) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory_default=self.optim_factory, + ), + ) + action_space = cast(gymnasium.spaces.Box, envs.get_action_space()) + policy = self._create_policy_from_args( + REDQPolicy, + kwargs, + ["exploration_noise", "deterministic_eval", "action_scaling", "action_bound_method"], + actor=actor, + action_space=action_space, + observation_space=envs.get_observation_space(), + ) + return REDQ( + policy=policy, + critic=critic_ensemble, + **kwargs, + ) + + +class ActorDualCriticsOffPolicyAlgorithmFactory( + OffPolicyAlgorithmFactory, + Generic[TActorDualCriticsParams, TAlgorithm, TPolicy], +): + def __init__( + self, + params: TActorDualCriticsParams, + training_config: OffPolicyTrainingConfig, + actor_factory: ActorFactory, + critic1_factory: CriticFactory, + critic2_factory: CriticFactory, + optim_factory: OptimizerFactoryFactory, + ): + super().__init__(training_config, optim_factory) + self.params = params + self.actor_factory = actor_factory + self.critic1_factory = critic1_factory + self.critic2_factory = critic2_factory + self.optim_factory = optim_factory + + @abstractmethod + def _get_algorithm_class(self) -> type[TAlgorithm]: + pass + + def _get_discrete_last_size_use_action_shape(self) -> bool: + return True + + @staticmethod + def _get_critic_use_action(envs: Environments) -> bool: + return envs.get_type().is_continuous() + + @abstractmethod + def _create_policy( + self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict + ) -> TPolicy: + pass + + @typing.no_type_check + def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: + actor = self.actor_factory.create_module(envs, device) + use_action_shape = self._get_discrete_last_size_use_action_shape() + critic_use_action = self._get_critic_use_action(envs) + critic1 = self.critic1_factory.create_module( + envs, + device, + use_action=critic_use_action, + discrete_last_size_use_action_shape=use_action_shape, + ) + critic2 = self.critic2_factory.create_module( + envs, + device, + use_action=critic_use_action, + discrete_last_size_use_action_shape=use_action_shape, + ) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory_default=self.optim_factory, + ), + ) + policy = self._create_policy(actor, envs, kwargs) + algorithm_class = self._get_algorithm_class() + return algorithm_class( + policy=policy, + critic=critic1, + critic2=critic2, + **kwargs, + ) + + +class SACAlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[SACParams, SAC, SACPolicy]): + def _create_policy( + self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict + ) -> SACPolicy: + return self._create_policy_from_args( + SACPolicy, + params, + ["exploration_noise", "deterministic_eval", "action_scaling"], + actor=actor, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + ) + + def _get_algorithm_class(self) -> type[SAC]: + return SAC + + +class DiscreteSACAlgorithmFactory( + ActorDualCriticsOffPolicyAlgorithmFactory[DiscreteSACParams, DiscreteSAC, DiscreteSACPolicy] +): + def _create_policy( + self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict + ) -> DiscreteSACPolicy: + return self._create_policy_from_args( + 
DiscreteSACPolicy, + params, + ["deterministic_eval"], + actor=actor, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + ) + + def _get_algorithm_class(self) -> type[DiscreteSAC]: + return DiscreteSAC + + +class TD3AlgorithmFactory( + ActorDualCriticsOffPolicyAlgorithmFactory[TD3Params, TD3, ContinuousDeterministicPolicy] +): + def _create_policy( + self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict + ) -> ContinuousDeterministicPolicy: + return self._create_policy_from_args( + ContinuousDeterministicPolicy, + params, + ["exploration_noise", "action_scaling", "action_bound_method"], + actor=actor, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + ) + + def _get_algorithm_class(self) -> type[TD3]: + return TD3 diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index fb58c8a58..caa012a9a 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -7,50 +7,42 @@ log = logging.getLogger(__name__) -@dataclass -class SamplingConfig(ToStringMixin): - """Configuration of sampling, epochs, parallelization, buffers, collectors, and batching.""" +@dataclass(kw_only=True) +class TrainingConfig(ToStringMixin): + """Training configuration.""" - num_epochs: int = 100 + max_epochs: int = 100 """ - the number of epochs to run training for. An epoch is the outermost iteration level and each - epoch consists of a number of training steps and a test step, where each training step + the (maximum) number of epochs to run training for. An **epoch** is the outermost iteration level and each + epoch consists of a number of training steps and one test step, where each training step - * collects environment steps/transitions (collection step), adding them to the (replay) - buffer (see :attr:`step_per_collect`) - * performs one or more gradient updates (see :attr:`update_per_step`), + * [for the online case] collects environment steps/transitions (**collection step**), + adding them to the (replay) buffer (see :attr:`collection_step_num_env_steps` and :attr:`collection_step_num_episodes`) + * performs an **update step** via the RL algorithm being used, which can involve + one or more actual gradient updates, depending on the algorithm and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate agent performance. - The number of training steps in each epoch is indirectly determined by - :attr:`step_per_epoch`: As many training steps will be performed as are required in - order to reach :attr:`step_per_epoch` total steps in the training environments. + Training may be stopped early if the stop criterion is met (see :attr:`stop_fn`). + + For online training, the number of training steps in each epoch is indirectly determined by + :attr:`epoch_num_steps`: As many training steps will be performed as are required in + order to reach :attr:`epoch_num_steps` total steps in the training environments. Specifically, if the number of transitions collected per step is `c` (see - :attr:`step_per_collect`) and :attr:`step_per_epoch` is set to `s`, then the number + :attr:`collection_step_num_env_steps`) and :attr:`epoch_num_steps` is set to `s`, then the number of training steps per epoch is `ceil(s / c)`. - - Therefore, if `num_epochs = e`, the total number of environment steps taken during training + Therefore, if `max_epochs = e`, the total number of environment steps taken during training can be computed as `e * ceil(s / c) * c`. 
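For concreteness, a minimal sketch of the step arithmetic described above, using the defaults shown in this hunk and an assumed collection step size of 2048 transitions (all values purely illustrative):

```python
import math

# Illustrative values only: defaults from TrainingConfig plus an assumed
# per-collection-step size c (cf. collection_step_num_env_steps).
e = 100      # max_epochs
s = 30_000   # epoch_num_steps
c = 2_048    # transitions collected per collection step

training_steps_per_epoch = math.ceil(s / c)         # ceil(30000 / 2048) = 15
total_env_steps = e * training_steps_per_epoch * c  # 100 * 15 * 2048 = 3,072,000
```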
- """ - step_per_epoch: int = 30000 - """ - the total number of environment steps to be made per epoch. See :attr:`num_epochs` for - an explanation of epoch semantics. + For offline training, the number of training steps per epoch is equal to :attr:`epoch_num_steps`. """ - batch_size: int | None = 64 - """for off-policy algorithms, this is the number of environment steps/transitions to sample - from the buffer for a gradient update; for on-policy algorithms, its use is algorithm-specific. - On-policy algorithms use the full buffer that was collected in the preceding collection step - but they may use this parameter to perform the gradient update using mini-batches of this size - (causing the gradient to be less accurate, a form of regularization). - - ``batch_size=None`` means that the full buffer is used for the gradient update. This doesn't - make much sense for off-policy algorithms and is not recommended then. For on-policy or offline algorithms, - this means that the full buffer is used for the gradient update (no mini-batching), and - may make sense in some cases. + epoch_num_steps: int = 30000 + """ + For an online algorithm, this is the total number of environment steps to be collected per epoch, and, + for an offline algorithm, it is the total number of training steps to take per epoch. + See :attr:`max_epochs` for an explanation of epoch semantics. """ num_train_envs: int = -1 @@ -59,7 +51,7 @@ class SamplingConfig(ToStringMixin): num_test_envs: int = 1 """the number of test environments to use""" - num_test_episodes: int = 1 + test_step_num_episodes: int = 1 """the total number of episodes to collect in each test step (across all test environments). """ @@ -67,12 +59,12 @@ class SamplingConfig(ToStringMixin): """the total size of the sample/replay buffer, in which environment steps (transitions) are stored""" - step_per_collect: int | None = 2048 + collection_step_num_env_steps: int | None = 2048 """ the number of environment steps/transitions to collect in each collection step before the network update within each training step. - This is mutually exclusive with :attr:`episode_per_collect`, and one of the two must be set. + This is mutually exclusive with :attr:`collection_step_num_episodes`, and one of the two must be set. Note that the exact number can be reached only if this is a multiple of the number of training environments being used, as each training environment will produce the same @@ -80,40 +72,17 @@ class SamplingConfig(ToStringMixin): Specifically, if this is set to `n` and `m` training environments are used, then the total number of transitions collected per collection step is `ceil(n / m) * m =: c`. - See :attr:`num_epochs` for information on the total number of environment steps being + See :attr:`max_epochs` for information on the total number of environment steps being collected during training. """ - episode_per_collect: int | None = None + collection_step_num_episodes: int | None = None """ the number of episodes to collect in each collection step before the network update within each training step. If this is set, the number of environment steps collected in each collection step is the sum of the lengths of the episodes collected. - This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set. - """ - - repeat_per_collect: int | None = 1 - """ - controls, within one gradient update step of an on-policy algorithm, the number of times an - actual gradient update is applied using the full collected dataset, i.e. 
if the parameter is - 5, then the collected data shall be used five times to update the policy within the same - training step. - - The parameter is ignored and may be set to None for off-policy and offline algorithms. - """ - - update_per_step: float = 1.0 - """ - for off-policy algorithms only: the number of gradient steps to perform per sample - collected (see :attr:`step_per_collect`). - Specifically, if this is set to `u` and the number of samples collected in the preceding - collection step is `n`, then `round(u * n)` gradient steps will be performed. - - Note that for on-policy algorithms, only a single gradient update is usually performed, - because thereafter, the samples no longer reflect the behavior of the updated policy. - To change the number of gradient updates for an on-policy algorithm, use parameter - :attr:`repeat_per_collect` instead. + This is mutually exclusive with :attr:`collection_step_num_env_steps`, and one of the two must be set. """ start_timesteps: int = 0 @@ -166,20 +135,106 @@ def __post_init__(self) -> None: if self.num_train_envs == -1: self.num_train_envs = multiprocessing.cpu_count() - if self.num_test_episodes == 0 and self.num_test_envs != 0: + if self.test_step_num_episodes == 0 and self.num_test_envs != 0: log.warning( f"Number of test episodes is set to 0, " f"but number of test environments is ({self.num_test_envs}). " f"This can cause unnecessary memory usage.", ) - if self.num_test_episodes != 0 and self.num_test_episodes % self.num_test_envs != 0: + if ( + self.test_step_num_episodes != 0 + and self.test_step_num_episodes % self.num_test_envs != 0 + ): log.warning( - f"Number of test episodes ({self.num_test_episodes} " + f"Number of test episodes ({self.test_step_num_episodes} " f"is not divisible by the number of test environments ({self.num_test_envs}). " f"This can cause unnecessary memory usage, it is recommended to adjust this.", ) assert ( - sum([self.step_per_collect is not None, self.episode_per_collect is not None]) == 1 - ), ("Only one of `step_per_collect` and `episode_per_collect` can be set.",) + sum( + [ + self.collection_step_num_env_steps is not None, + self.collection_step_num_episodes is not None, + ] + ) + == 1 + ), ( + "Only one of `collection_step_num_env_steps` and `collection_step_num_episodes` can be set.", + ) + + +@dataclass(kw_only=True) +class OnlineTrainingConfig(TrainingConfig): + collection_step_num_env_steps: int | None = 2048 + """ + the number of environment steps/transitions to collect in each collection step before the + network update within each training step. + + This is mutually exclusive with :attr:`collection_step_num_episodes`, and one of the two must be set. + + Note that the exact number can be reached only if this is a multiple of the number of + training environments being used, as each training environment will produce the same + (non-zero) number of transitions. + Specifically, if this is set to `n` and `m` training environments are used, then the total + number of transitions collected per collection step is `ceil(n / m) * m =: c`. + + See :attr:`max_epochs` for information on the total number of environment steps being + collected during training. + """ + + collection_step_num_episodes: int | None = None + """ + the number of episodes to collect in each collection step before the network update within + each training step. If this is set, the number of environment steps collected in each + collection step is the sum of the lengths of the episodes collected. 
+ + This is mutually exclusive with :attr:`collection_step_num_env_steps`, and one of the two must be set. + """ + + test_in_train: bool = False + """ + Whether to apply a test step within a training step depending on the early stopping criterion + (see :meth:`~tianshou.highlevel.Experiment.with_epoch_stop_callback`) being satisfied based + on the data collected within the training step. + Specifically, after each collect step, we check whether the early stopping criterion + would be satisfied by data we collected (provided that at least one episode was indeed completed, such + that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step + (collecting :attr:`test_step_num_episodes` episodes in order to evaluate performance), and if the early + stopping criterion is also satisfied based on the test data, we stop training early. + """ + + +@dataclass(kw_only=True) +class OnPolicyTrainingConfig(OnlineTrainingConfig): + batch_size: int | None = 64 + """ + Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, + a form of regularization). + Set ``batch_size=None`` for the full buffer that was collected within the training step to be + used for the gradient update (no mini-batching). + """ + + update_step_num_repetitions: int = 1 + """ + controls, within one update step of an on-policy algorithm, the number of times + the full collected data is applied for gradient updates, i.e. if the parameter is + 5, then the collected data shall be used five times to update the policy within the same + update step. + """ + + +@dataclass(kw_only=True) +class OffPolicyTrainingConfig(OnlineTrainingConfig): + batch_size: int = 64 + """ + the number of environment steps/transitions to sample from the buffer for a gradient update. + """ + + update_step_num_gradient_steps_per_sample: float = 1.0 + """ + the number of gradient steps to perform per sample collected (see :attr:`collection_step_num_env_steps`). + Specifically, if this is set to `u` and the number of samples collected in the preceding + collection step is `n`, then `round(u * n)` gradient steps will be performed. 
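As a minimal sketch of the split into dedicated on-policy and off-policy configuration classes introduced here (class and field names as in this hunk, concrete values purely illustrative):

```python
from tianshou.highlevel.config import OffPolicyTrainingConfig, OnPolicyTrainingConfig

# On-policy: the data of a collection step is reused for several update passes,
# optionally split into mini-batches.
on_policy_cfg = OnPolicyTrainingConfig(
    max_epochs=100,
    epoch_num_steps=30_000,
    collection_step_num_env_steps=2_048,
    update_step_num_repetitions=10,
    batch_size=64,
)

# Off-policy: the number of gradient steps scales with the number of samples
# collected in the preceding collection step.
off_policy_cfg = OffPolicyTrainingConfig(
    max_epochs=100,
    epoch_num_steps=10_000,
    collection_step_num_env_steps=10,
    update_step_num_gradient_steps_per_sample=0.1,
    batch_size=64,
)
```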
+ """ diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 4df648fa9..2828f818b 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -25,7 +25,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pprint import pformat -from typing import TYPE_CHECKING, Any, Self, Union, cast +from typing import TYPE_CHECKING, Any, Generic, Self, Union, cast if TYPE_CHECKING: from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher @@ -36,25 +36,30 @@ from sensai.util.logging import datetime_tag from sensai.util.string import ToStringMixin +from tianshou.algorithm import Algorithm from tianshou.data import BaseCollector, Collector, CollectStats, InfoStats from tianshou.env import BaseVectorEnv -from tianshou.highlevel.agent import ( - A2CAgentFactory, - AgentFactory, - DDPGAgentFactory, - DiscreteSACAgentFactory, - DQNAgentFactory, - IQNAgentFactory, - NPGAgentFactory, - PGAgentFactory, - PPOAgentFactory, - RandomActionAgentFactory, - REDQAgentFactory, - SACAgentFactory, - TD3AgentFactory, - TRPOAgentFactory, +from tianshou.highlevel.algorithm import ( + A2CAlgorithmFactory, + AlgorithmFactory, + DDPGAlgorithmFactory, + DiscreteSACAlgorithmFactory, + DQNAlgorithmFactory, + IQNAlgorithmFactory, + NPGAlgorithmFactory, + PPOAlgorithmFactory, + REDQAlgorithmFactory, + ReinforceAlgorithmFactory, + SACAlgorithmFactory, + TD3AlgorithmFactory, + TRPOAlgorithmFactory, + TTrainingConfig, +) +from tianshou.highlevel.config import ( + OffPolicyTrainingConfig, + OnPolicyTrainingConfig, + TrainingConfig, ) -from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import EnvFactory from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger from tianshou.highlevel.module.actor import ( @@ -78,25 +83,25 @@ ) from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory -from tianshou.highlevel.optim import ( - OptimizerFactory, - OptimizerFactoryAdam, -) -from tianshou.highlevel.params.policy_params import ( +from tianshou.highlevel.params.algorithm_params import ( A2CParams, DDPGParams, DiscreteSACParams, DQNParams, IQNParams, NPGParams, - PGParams, PPOParams, REDQParams, + ReinforceParams, SACParams, TD3Params, TRPOParams, ) -from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory +from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory +from tianshou.highlevel.params.optim import ( + OptimizerFactoryFactory, + OptimizerFactoryFactoryAdam, +) from tianshou.highlevel.persistence import ( PersistenceGroup, PolicyPersistence, @@ -108,7 +113,6 @@ TrainerCallbacks, ) from tianshou.highlevel.world import World -from tianshou.policy import BasePolicy from tianshou.utils import LazyLogger from tianshou.utils.net.common import ModuleType from tianshou.utils.print import DataclassPPrintMixin @@ -185,17 +189,17 @@ def __init__( self, config: ExperimentConfig, env_factory: EnvFactory, - agent_factory: AgentFactory, - sampling_config: SamplingConfig, + algorithm_factory: AlgorithmFactory, + training_config: TrainingConfig, name: str, logger_factory: LoggerFactory | None = None, ): if logger_factory is None: logger_factory = LoggerFactoryDefault() self.config = config - self.sampling_config = sampling_config + self.training_config = training_config self.env_factory = env_factory - self.agent_factory = agent_factory + self.algorithm_factory = 
algorithm_factory self.logger_factory = logger_factory self.name = name @@ -289,8 +293,8 @@ def create_experiment_world( # create environments envs = self.env_factory.create_envs( - self.sampling_config.num_train_envs, - self.sampling_config.num_test_envs, + self.training_config.num_train_envs, + self.training_config.num_test_envs, create_watch_env=self.config.watch, seed=self.config.seed, ) @@ -311,9 +315,9 @@ def create_experiment_world( full_config = self._build_config_dict() full_config.update(envs.info()) full_config["experiment_config"] = asdict(self.config) - full_config["sampling_config"] = asdict(self.sampling_config) + full_config["training_config_config"] = asdict(self.training_config) with suppress(AttributeError): - full_config["policy_params"] = asdict(self.agent_factory.params) + full_config["policy_params"] = asdict(self.algorithm_factory.params) logger: TLogger if use_persistence: @@ -328,13 +332,16 @@ def create_experiment_world( # create policy and collectors log.info("Creating policy") - policy = self.agent_factory.create_policy(envs, self.config.device) + policy = self.algorithm_factory.create_algorithm(envs, self.config.device) log.info("Creating collectors") train_collector: BaseCollector | None = None test_collector: BaseCollector | None = None if self.config.train: - train_collector, test_collector = self.agent_factory.create_train_test_collector( + ( + train_collector, + test_collector, + ) = self.algorithm_factory.create_train_test_collector( policy, envs, reset_collectors=reset_collectors, @@ -343,7 +350,7 @@ def create_experiment_world( # create context object with all relevant instances (except trainer; added later) world = World( envs=envs, - policy=policy, + algorithm=policy, train_collector=train_collector, test_collector=test_collector, logger=logger, @@ -360,7 +367,7 @@ def create_experiment_world( ) if self.config.train: - trainer = self.agent_factory.create_trainer(world, policy_persistence) + trainer = self.algorithm_factory.create_trainer(world, policy_persistence) world.trainer = trainer return world @@ -419,14 +426,14 @@ def run( assert world.test_collector is not None # prefilling buffers with either random or current agent's actions - if self.sampling_config.start_timesteps > 0: + if self.training_config.start_timesteps > 0: log.info( - f"Collecting {self.sampling_config.start_timesteps} initial environment " - f"steps before training (random={self.sampling_config.start_timesteps_random})", + f"Collecting {self.training_config.start_timesteps} initial environment " + f"steps before training (random={self.training_config.start_timesteps_random})", ) world.train_collector.collect( - n_step=self.sampling_config.start_timesteps, - random=self.sampling_config.start_timesteps_random, + n_step=self.training_config.start_timesteps, + random=self.training_config.start_timesteps_random, ) log.info("Starting training") @@ -441,7 +448,7 @@ def run( log.info("Watching agent performance") self._watch_agent( self.config.watch_num_episodes, - world.policy, + world.algorithm, world.envs.watch_env, self.config.watch_render, ) @@ -451,7 +458,7 @@ def run( @staticmethod def _watch_agent( num_episodes: int, - policy: BasePolicy, + policy: Algorithm, env: BaseVectorEnv, render: float, ) -> None: @@ -482,7 +489,7 @@ def run( return launcher.launch(experiments=self.experiments) -class ExperimentBuilder(ABC): +class ExperimentBuilder(ABC, Generic[TTrainingConfig]): """A helper class (following the builder pattern) for creating experiments. 
It contains a lot of defaults for the setup which can be adjusted using the @@ -495,28 +502,31 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: TTrainingConfig | None = None, ): """:param env_factory: controls how environments are to be created. :param experiment_config: the configuration for the experiment. If None, will use the default values of `ExperimentConfig`. - :param sampling_config: the sampling configuration to use. If None, will use the default values - of `SamplingConfig`. + :param training_config: the training configuration to use. If None, use default values (not recommended). """ if experiment_config is None: experiment_config = ExperimentConfig() - if sampling_config is None: - sampling_config = SamplingConfig() + if training_config is None: + training_config = self._create_training_config() self._config = experiment_config self._env_factory = env_factory - self._sampling_config = sampling_config + self._training_config = training_config self._logger_factory: LoggerFactory | None = None - self._optim_factory: OptimizerFactory | None = None - self._policy_wrapper_factory: PolicyWrapperFactory | None = None + self._optim_factory: OptimizerFactoryFactory | None = None + self._algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() + @abstractmethod + def _create_training_config(self) -> TTrainingConfig: + pass + def copy(self) -> Self: return deepcopy(self) @@ -529,12 +539,12 @@ def experiment_config(self, experiment_config: ExperimentConfig) -> None: self._config = experiment_config @property - def sampling_config(self) -> SamplingConfig: - return self._sampling_config + def training_config(self) -> TrainingConfig: + return self._training_config - @sampling_config.setter - def sampling_config(self, sampling_config: SamplingConfig) -> None: - self._sampling_config = sampling_config + @training_config.setter + def training_config(self, config: TrainingConfig) -> None: + self._training_config = config def with_logger_factory(self, logger_factory: LoggerFactory) -> Self: """Allows to customize the logger factory to use. @@ -547,19 +557,24 @@ def with_logger_factory(self, logger_factory: LoggerFactory) -> Self: self._logger_factory = logger_factory return self - def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self: - """Allows to define a wrapper around the policy that is created, extending the original policy. + def with_algorithm_wrapper_factory( + self, algorithm_wrapper_factory: AlgorithmWrapperFactory + ) -> Self: + """Allows to define a wrapper around the algorithm that is created, extending the original algorithm. - :param policy_wrapper_factory: the factory for the wrapper + :param algorithm_wrapper_factory: the factory for the wrapper :return: the builder """ - self._policy_wrapper_factory = policy_wrapper_factory + self._algorithm_wrapper_factory = algorithm_wrapper_factory return self - def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: - """Allows to customize the gradient-based optimizer to use. + def with_optim_default(self, optim_factory: OptimizerFactoryFactory) -> Self: + """Allows to customize the default optimizer to use. - By default, :class:`OptimizerFactoryAdam` will be used with default parameters. 
+ The default optimizer applies when optimizer factory factories are set to None + in algorithm parameter objects. + + By default, :class:`OptimizerFactoryFactoryAdam` will be used with default parameters. :param optim_factory: the optimizer factory :return: the builder @@ -567,23 +582,6 @@ def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: self._optim_factory = optim_factory return self - def with_optim_factory_default( - self, - # Keep values in sync with default values in OptimizerFactoryAdam - betas: tuple[float, float] = (0.9, 0.999), - eps: float = 1e-08, - weight_decay: float = 0, - ) -> Self: - """Configures the use of the default optimizer, Adam, with the given parameters. - - :param betas: coefficients used for computing running averages of gradient and its square - :param eps: term added to the denominator to improve numerical stability - :param weight_decay: weight decay (L2 penalty) - :return: the builder - """ - self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay) - return self - def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self: """Allows to define a callback function which is called at the beginning of every epoch during training. @@ -627,13 +625,12 @@ def with_name( return self @abstractmethod - def _create_agent_factory(self) -> AgentFactory: + def _create_algorithm_factory(self) -> AlgorithmFactory: pass - def _get_optim_factory(self) -> OptimizerFactory: + def _get_optim_factory(self) -> OptimizerFactoryFactory: if self._optim_factory is None: - # same mechanism as in `with_optim_factory_default` - return OptimizerFactoryAdam() + return OptimizerFactoryFactoryAdam() else: return self._optim_factory @@ -642,15 +639,15 @@ def build(self) -> Experiment: :return: the experiment """ - agent_factory = self._create_agent_factory() - agent_factory.set_trainer_callbacks(self._trainer_callbacks) - if self._policy_wrapper_factory: - agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) + algorithm_factory = self._create_algorithm_factory() + algorithm_factory.set_trainer_callbacks(self._trainer_callbacks) + if self._algorithm_wrapper_factory: + algorithm_factory.set_policy_wrapper_factory(self._algorithm_wrapper_factory) experiment: Experiment = Experiment( config=self._config, env_factory=self._env_factory, - agent_factory=agent_factory, - sampling_config=self._sampling_config, + algorithm_factory=algorithm_factory, + training_config=self._training_config, name=self._name, logger_factory=self._logger_factory, ) @@ -677,12 +674,42 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: return ExperimentCollection(seeded_experiments) -class RandomActionExperimentBuilder(ExperimentBuilder): - def _create_agent_factory(self) -> RandomActionAgentFactory: - return RandomActionAgentFactory( - sampling_config=self.sampling_config, - optim_factory=self._get_optim_factory(), - ) +class OnPolicyExperimentBuilder(ExperimentBuilder[OnPolicyTrainingConfig], ABC): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, + ): + """ + :param env_factory: controls how environments are to be created. + :param experiment_config: the configuration for the experiment. If None, will use the default values + of :class:`ExperimentConfig`. + :param training_config: the training configuration to use. If None, use default values (not recommended). 
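The fallback described for `with_optim_default` and `_get_optim_factory` (per-algorithm optimizer settings of `None` resolve to the builder-wide default, which itself defaults to Adam) can be sketched as follows; the classes below are hypothetical stand-ins for illustration only.

```python
from dataclasses import dataclass


@dataclass
class OptimizerDefaultSketch:
    """Stand-in for an optimizer 'factory factory' (e.g. an Adam default)."""
    name: str = "adam"


@dataclass
class AlgorithmParamsSketch:
    # None means: use the experiment builder's default optimizer.
    optim: OptimizerDefaultSketch | None = None
    lr: float = 1e-3


def resolve_optimizer(
    params: AlgorithmParamsSketch,
    builder_default: OptimizerDefaultSketch | None,
) -> OptimizerDefaultSketch:
    # Two-level fallback: per-algorithm setting -> builder default -> Adam.
    if params.optim is not None:
        return params.optim
    if builder_default is not None:
        return builder_default
    return OptimizerDefaultSketch()


if __name__ == "__main__":
    print(resolve_optimizer(AlgorithmParamsSketch(), None))
    print(resolve_optimizer(AlgorithmParamsSketch(), OptimizerDefaultSketch("rmsprop")))
```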
+ """ + super().__init__(env_factory, experiment_config, training_config) + + def _create_training_config(self) -> OnPolicyTrainingConfig: + return OnPolicyTrainingConfig() + + +class OffPolicyExperimentBuilder(ExperimentBuilder[OffPolicyTrainingConfig], ABC): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, + ): + """ + :param env_factory: controls how environments are to be created. + :param experiment_config: the configuration for the experiment. If None, will use the default values + of :class:`ExperimentConfig`. + :param training_config: the training configuration to use. If None, use default values (not recommended). + """ + super().__init__(env_factory, experiment_config, training_config) + + def _create_training_config(self) -> OffPolicyTrainingConfig: + return OffPolicyTrainingConfig() class _BuilderMixinActorFactory(ActorFutureProviderProtocol): @@ -1011,36 +1038,36 @@ def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory: return self.critic_ensemble_factory -class PGExperimentBuilder( - ExperimentBuilder, +class ReinforceExperimentBuilder( + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, ): def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) - self._params: PGParams = PGParams() + self._params: ReinforceParams = ReinforceParams() self._env_config = None - def with_pg_params(self, params: PGParams) -> Self: + def with_reinforce_params(self, params: ReinforceParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return PGAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return ReinforceAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_optim_factory(), ) class A2CExperimentBuilder( - ExperimentBuilder, + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1048,9 +1075,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: A2CParams = A2CParams() @@ -1060,10 +1087,10 @@ def with_a2c_params(self, params: A2CParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return A2CAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return A2CAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1071,7 +1098,7 @@ def _create_agent_factory(self) -> AgentFactory: class PPOExperimentBuilder( - ExperimentBuilder, + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, 
_BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1079,9 +1106,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: PPOParams = PPOParams() @@ -1090,10 +1117,10 @@ def with_ppo_params(self, params: PPOParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return PPOAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return PPOAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1101,7 +1128,7 @@ def _create_agent_factory(self) -> AgentFactory: class NPGExperimentBuilder( - ExperimentBuilder, + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1109,9 +1136,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: NPGParams = NPGParams() @@ -1120,10 +1147,10 @@ def with_npg_params(self, params: NPGParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return NPGAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return NPGAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1131,7 +1158,7 @@ def _create_agent_factory(self) -> AgentFactory: class TRPOExperimentBuilder( - ExperimentBuilder, + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1139,9 +1166,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: TRPOParams = TRPOParams() @@ -1150,10 +1177,10 @@ def with_trpo_params(self, params: TRPOParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return TRPOAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return TRPOAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1161,15 +1188,15 @@ def _create_agent_factory(self) -> AgentFactory: class DQNExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, ): def 
__init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) self._params: DQNParams = DQNParams() self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory( ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False), @@ -1208,23 +1235,23 @@ def with_model_factory_default( ) return self - def _create_agent_factory(self) -> AgentFactory: - return DQNAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return DQNAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._model_factory, self._get_optim_factory(), ) -class IQNExperimentBuilder(ExperimentBuilder): +class IQNExperimentBuilder(OffPolicyExperimentBuilder): def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) self._params: IQNParams = IQNParams() self._preprocess_network_factory: IntermediateModuleFactory = ( IntermediateModuleFactoryFromActorFactory( @@ -1240,22 +1267,22 @@ def with_preprocess_network_factory(self, module_factory: IntermediateModuleFact self._preprocess_network_factory = module_factory return self - def _create_agent_factory(self) -> AgentFactory: + def _create_algorithm_factory(self) -> AlgorithmFactory: model_factory = ImplicitQuantileNetworkFactory( self._preprocess_network_factory, hidden_sizes=self._params.hidden_sizes, num_cosines=self._params.num_cosines, ) - return IQNAgentFactory( + return IQNAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, model_factory, self._get_optim_factory(), ) class DDPGExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1263,9 +1290,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: DDPGParams = DDPGParams() @@ -1274,10 +1301,10 @@ def with_ddpg_params(self, params: DDPGParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return DDPGAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return DDPGAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1285,7 +1312,7 @@ def _create_agent_factory(self) -> AgentFactory: class REDQExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinCriticEnsembleFactory, ): @@ -1293,9 +1320,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: 
ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinCriticEnsembleFactory.__init__(self) self._params: REDQParams = REDQParams() @@ -1304,10 +1331,10 @@ def with_redq_params(self, params: REDQParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return REDQAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return REDQAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_ensemble_factory(), self._get_optim_factory(), @@ -1315,7 +1342,7 @@ def _create_agent_factory(self) -> AgentFactory: class SACExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinDualCriticFactory, ): @@ -1323,9 +1350,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self) self._params: SACParams = SACParams() @@ -1334,10 +1361,10 @@ def with_sac_params(self, params: SACParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return SACAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return SACAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_critic_factory(1), @@ -1346,7 +1373,7 @@ def _create_agent_factory(self) -> AgentFactory: class DiscreteSACExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_DiscreteOnly, _BuilderMixinDualCriticFactory, ): @@ -1354,9 +1381,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_DiscreteOnly.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self) self._params: DiscreteSACParams = DiscreteSACParams() @@ -1365,10 +1392,10 @@ def with_sac_params(self, params: DiscreteSACParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return DiscreteSACAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return DiscreteSACAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_critic_factory(1), @@ -1377,7 +1404,7 @@ def _create_agent_factory(self) -> AgentFactory: class TD3ExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinDualCriticFactory, ): @@ -1385,9 +1412,9 @@ def __init__( self, env_factory: EnvFactory, 
experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self) self._params: TD3Params = TD3Params() @@ -1396,10 +1423,10 @@ def with_td3_params(self, params: TD3Params) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return TD3AgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return TD3AlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_critic_factory(1), diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index ceb1262f7..83ae08878 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -8,6 +8,7 @@ from sensai.util.string import ToStringMixin from torch import nn +from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrOrCont from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.module.core import ( ModuleFactory, @@ -18,15 +19,17 @@ IntermediateModule, IntermediateModuleFactory, ) -from tianshou.highlevel.module.module_opt import ModuleOpt -from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.params.dist_fn import ( DistributionFunctionFactoryCategorical, DistributionFunctionFactoryIndependentGaussians, ) -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import BaseActor, ModuleType, Net +from tianshou.utils.net.common import ( + Actor, + ModuleType, + ModuleWithVectorOutput, + Net, +) class ContinuousActorType(Enum): @@ -39,7 +42,7 @@ class ContinuousActorType(Enum): class ActorFuture: """Container, which, in the future, will hold an actor instance.""" - actor: BaseActor | nn.Module | None = None + actor: Actor | nn.Module | None = None class ActorFutureProviderProtocol(Protocol): @@ -49,7 +52,7 @@ def get_actor_future(self) -> ActorFuture: class ActorFactory(ModuleFactory, ToStringMixin, ABC): @abstractmethod - def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> Actor | nn.Module: pass @abstractmethod @@ -60,25 +63,6 @@ def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: if the actor does not output distribution parameters """ - def create_module_opt( - self, - envs: Environments, - device: TDevice, - optim_factory: OptimizerFactory, - lr: float, - ) -> ModuleOpt: - """Creates the actor module along with its optimizer for the given learning rate. - - :param envs: the environments - :param device: the torch device - :param optim_factory: the optimizer factory - :param lr: the learning rate - :return: a container with the actor module and its optimizer - """ - module = self.create_module(envs, device) - optim = optim_factory.create_optimizer(module, lr) - return ModuleOpt(module, optim) - @staticmethod def _init_linear(actor: torch.nn.Module) -> None: """Initializes linear layers of an actor module using default mechanisms. 
@@ -148,7 +132,7 @@ def _create_factory(self, envs: Environments) -> ActorFactory: raise ValueError(f"{env_type} not supported") return factory - def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> Actor | nn.Module: factory = self._create_factory(envs) return factory.create_module(envs, device) @@ -166,18 +150,16 @@ def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU self.hidden_sizes = hidden_sizes self.activation = activation - def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + def create_module(self, envs: Environments, device: TDevice) -> Actor: net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, - device=device, ) - return continuous.Actor( + return continuous.ContinuousActorDeterministic( preprocess_net=net_a, action_shape=envs.get_action_shape(), hidden_sizes=(), - device=device, ).to(device) def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: @@ -204,18 +186,16 @@ def __init__( self.conditioned_sigma = conditioned_sigma self.activation = activation - def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + def create_module(self, envs: Environments, device: TDevice) -> Actor: net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, - device=device, ) - actor = continuous.ActorProb( + actor = continuous.ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=envs.get_action_shape(), unbounded=self.unbounded, - device=device, conditioned_sigma=self.conditioned_sigma, ).to(device) @@ -241,18 +221,16 @@ def __init__( self.softmax_output = softmax_output self.activation = activation - def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + def create_module(self, envs: Environments, device: TDevice) -> Actor: net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, - device=device, ) - return discrete.Actor( - net_a, - envs.get_action_shape(), + return discrete.DiscreteActor( + preprocess_net=net_a, + action_shape=envs.get_action_shape(), hidden_sizes=(), - device=device, softmax_output=self.softmax_output, ).to(device) @@ -281,7 +259,7 @@ def __setstate__(self, state: dict) -> None: def _tostring_excludes(self) -> list[str]: return [*super()._tostring_excludes(), "_actor_future"] - def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> Actor | nn.Module: module = self.actor_factory.create_module(envs, device) self._actor_future.actor = module return module @@ -296,5 +274,7 @@ def __init__(self, actor_factory: ActorFactory): def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: actor = self.actor_factory.create_module(envs, device) - assert isinstance(actor, BaseActor) + assert isinstance( + actor, ModuleWithVectorOutput + ), "Actor factory must produce an actor with known vector output dimension" return IntermediateModule(actor, actor.get_output_dim()) diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 0352fd132..54596be12 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -8,10 +8,10 @@ from tianshou.highlevel.env import Environments, EnvType from 
tianshou.highlevel.module.actor import ActorFuture from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal -from tianshou.highlevel.module.module_opt import ModuleOpt -from tianshou.highlevel.optim import OptimizerFactory -from tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import BaseActor, EnsembleLinear, ModuleType, Net +from tianshou.utils.net import continuous +from tianshou.utils.net.common import Actor, EnsembleLinear, ModuleType, Net +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic class CriticFactory(ToStringMixin, ABC): @@ -34,34 +34,6 @@ def create_module( :return: the module """ - def create_module_opt( - self, - envs: Environments, - device: TDevice, - use_action: bool, - optim_factory: OptimizerFactory, - lr: float, - discrete_last_size_use_action_shape: bool = False, - ) -> ModuleOpt: - """Creates the critic module along with its optimizer for the given learning rate. - - :param envs: the environments - :param device: the torch device - :param use_action: whether to expect the action as an additional input (in addition to the observations) - :param optim_factory: the optimizer factory - :param lr: the learning rate - :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape - :return: - """ - module = self.create_module( - envs, - device, - use_action, - discrete_last_size_use_action_shape=discrete_last_size_use_action_shape, - ) - opt = optim_factory.create_optimizer(module, lr) - return ModuleOpt(module, opt) - class CriticFactoryDefault(CriticFactory): """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic.""" @@ -125,9 +97,8 @@ def create_module( hidden_sizes=self.hidden_sizes, concat=use_action, activation=self.activation, - device=device, ) - critic = continuous.Critic(net_c, device=device).to(device) + critic = continuous.ContinuousCritic(preprocess_net=net_c).to(device) init_linear_orthogonal(critic) return critic @@ -151,12 +122,11 @@ def create_module( hidden_sizes=self.hidden_sizes, concat=use_action, activation=self.activation, - device=device, ) last_size = ( int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 ) - critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device) + critic = DiscreteCritic(preprocess_net=net_c, last_size=last_size).to(device) init_linear_orthogonal(critic) return critic @@ -188,24 +158,22 @@ def create_module( discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: actor = self.actor_future.actor - if not isinstance(actor, BaseActor): + if not isinstance(actor, Actor): raise ValueError( - f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}", + f"Option critic_use_action can only be used if actor is of type {Actor.__class__.__name__}", ) if envs.get_type().is_discrete(): # TODO get rid of this prod pattern here and elsewhere last_size = ( int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 ) - return discrete.Critic( - actor.get_preprocess_net(), - device=device, + return DiscreteCritic( + preprocess_net=actor.get_preprocess_net(), last_size=last_size, ).to(device) elif envs.get_type().is_continuous(): - return continuous.Critic( - actor.get_preprocess_net(), - device=device, + return ContinuousCritic( + preprocess_net=actor.get_preprocess_net(), 
apply_preprocess_net_to_obs_only=True, ).to(device) else: @@ -223,19 +191,6 @@ def create_module( ) -> nn.Module: pass - def create_module_opt( - self, - envs: Environments, - device: TDevice, - ensemble_size: int, - use_action: bool, - optim_factory: OptimizerFactory, - lr: float, - ) -> ModuleOpt: - module = self.create_module(envs, device, ensemble_size, use_action) - opt = optim_factory.create_optimizer(module, lr) - return ModuleOpt(module, opt) - class CriticEnsembleFactoryDefault(CriticEnsembleFactory): """A critic ensemble factory which, depending on the type of environment, creates a suitable MLP-based critic.""" @@ -290,12 +245,10 @@ def linear_layer(x: int, y: int) -> EnsembleLinear: hidden_sizes=self.hidden_sizes, concat=use_action, activation=nn.Tanh, - device=device, linear_layer=linear_layer, ) - critic = continuous.Critic( - net_c, - device=device, + critic = continuous.ContinuousCritic( + preprocess_net=net_c, linear_layer=linear_layer, flatten_input=False, ).to(device) diff --git a/tianshou/highlevel/module/intermediate.py b/tianshou/highlevel/module/intermediate.py index 62bf3843f..08b32641a 100644 --- a/tianshou/highlevel/module/intermediate.py +++ b/tianshou/highlevel/module/intermediate.py @@ -6,6 +6,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import ModuleFactory, TDevice +from tianshou.utils.net.common import ModuleWithVectorOutput @dataclass @@ -15,6 +16,12 @@ class IntermediateModule: module: torch.nn.Module output_dim: int + def get_module_with_vector_output(self) -> ModuleWithVectorOutput: + if isinstance(self.module, ModuleWithVectorOutput): + return self.module + else: + return ModuleWithVectorOutput.from_module(self.module, self.output_dim) + class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC): """Factory for the generation of a module which computes an intermediate representation.""" diff --git a/tianshou/highlevel/module/module_opt.py b/tianshou/highlevel/module/module_opt.py deleted file mode 100644 index 558686aa9..000000000 --- a/tianshou/highlevel/module/module_opt.py +++ /dev/null @@ -1,29 +0,0 @@ -from dataclasses import dataclass - -import torch - -from tianshou.utils.net.common import ActorCritic - - -@dataclass -class ModuleOpt: - """Container for a torch module along with its optimizer.""" - - module: torch.nn.Module - optim: torch.optim.Optimizer - - -@dataclass -class ActorCriticOpt: - """Container for an :class:`ActorCritic` instance along with its optimizer.""" - - actor_critic_module: ActorCritic - optim: torch.optim.Optimizer - - @property - def actor(self) -> torch.nn.Module: - return self.actor_critic_module.actor - - @property - def critic(self) -> torch.nn.Module: - return self.actor_critic_module.critic diff --git a/tianshou/highlevel/module/special.py b/tianshou/highlevel/module/special.py index de572d7a1..b36b26d9e 100644 --- a/tianshou/highlevel/module/special.py +++ b/tianshou/highlevel/module/special.py @@ -22,10 +22,8 @@ def __init__( def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork: preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device) return ImplicitQuantileNetwork( - preprocess_net=preprocess_net.module, + preprocess_net=preprocess_net.get_module_with_vector_output(), action_shape=envs.get_action_shape(), hidden_sizes=self.hidden_sizes, num_cosines=self.num_cosines, - preprocess_net_output_dim=preprocess_net.output_dim, - device=device, ).to(device) diff --git 
a/tianshou/highlevel/params/algorithm_params.py b/tianshou/highlevel/params/algorithm_params.py new file mode 100644 index 000000000..4c1c81dad --- /dev/null +++ b/tianshou/highlevel/params/algorithm_params.py @@ -0,0 +1,845 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import asdict, dataclass +from typing import Any, Literal, Protocol + +from sensai.util.string import ToStringMixin + +from tianshou.exploration import BaseNoise +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.core import TDevice +from tianshou.highlevel.params.alpha import AutoAlphaFactory +from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactory +from tianshou.highlevel.params.noise import NoiseFactory +from tianshou.highlevel.params.optim import OptimizerFactoryFactory + + +@dataclass(kw_only=True) +class ParamTransformerData: + """Holds data that can be used by `ParamTransformer` instances to perform their transformation. + + The representation contains the superset of all data items that are required by different types of agent factories. + An agent factory is expected to set only the attributes that are relevant to its parameters. + """ + + envs: Environments + device: TDevice + optim_factory_default: OptimizerFactoryFactory + + +class ParamTransformer(ABC): + """Base class for parameter transformations from high to low-level API. + + Transforms one or more parameters from the representation used by the high-level API + to the representation required by the (low-level) policy implementation. + It operates directly on a dictionary of keyword arguments, which is initially + generated from the parameter dataclass (subclass of `Params`). + """ + + @abstractmethod + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + pass + + @staticmethod + def get( + d: dict[str, Any], + key: str, + drop: bool = False, + default_factory: Callable[[], Any] | None = None, + ) -> Any: + try: + value = d[key] + except KeyError as e: + raise Exception(f"Key not found: '{key}'; available keys: {list(d.keys())}") from e + if value is None and default_factory is not None: + value = default_factory() + if drop: + del d[key] + return value + + +class ParamTransformerDrop(ParamTransformer): + def __init__(self, *keys: str): + self.keys = keys + + def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: + for k in self.keys: + del kwargs[k] + + +class ParamTransformerRename(ParamTransformer): + def __init__(self, renamed_params: dict[str, str]): + self.renamed_params = renamed_params + + def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: + for old_name, new_name in self.renamed_params.items(): + v = kwargs[old_name] + del kwargs[old_name] + kwargs[new_name] = v + + +class ParamTransformerChangeValue(ParamTransformer): + def __init__(self, key: str): + self.key = key + + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + params[self.key] = self.change_value(params[self.key], data) + + @abstractmethod + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + pass + + +class ParamTransformerOptimFactory(ParamTransformer): + """Transformer for learning rate scheduler params. 
+ + Transforms a key containing a learning rate scheduler factory (removed) into a key containing + a learning rate scheduler (added) for the data member `optim`. + """ + + def __init__( + self, + key_optim_factory_factory: str, + key_lr: str, + key_lr_scheduler_factory_factory: str, + key_optim_output: str, + ): + self.key_optim_factory_factory = key_optim_factory_factory + self.key_lr = key_lr + self.key_scheduler_factory = key_lr_scheduler_factory_factory + self.key_optim_output = key_optim_output + + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + optim_factory_factory: OptimizerFactoryFactory = self.get( + params, + self.key_optim_factory_factory, + drop=True, + default_factory=lambda: data.optim_factory_default, + ) + lr_scheduler_factory_factory: LRSchedulerFactoryFactory | None = self.get( + params, self.key_scheduler_factory, drop=True + ) + lr: float = self.get(params, self.key_lr, drop=True) + optim_factory = optim_factory_factory.create_optimizer_factory(lr) + if lr_scheduler_factory_factory is not None: + optim_factory.with_lr_scheduler_factory( + lr_scheduler_factory_factory.create_lr_scheduler_factory() + ) + params[self.key_optim_output] = optim_factory + + +class ParamTransformerAutoAlpha(ParamTransformer): + def __init__(self, key: str): + self.key = key + + def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: + alpha = self.get(kwargs, self.key) + if isinstance(alpha, AutoAlphaFactory): + kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.device) + + +class ParamTransformerNoiseFactory(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + if isinstance(value, NoiseFactory): + value = value.create_noise(data.envs) + return value + + +class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + if isinstance(value, EnvValueFactory): + value = value.create_value(data.envs) + return value + + +class ParamTransformerActionScaling(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + if value == "default": + return data.envs.get_type().is_continuous() + else: + return value + + +class GetParamTransformersProtocol(Protocol): + def _get_param_transformers(self) -> list[ParamTransformer]: + pass + + +@dataclass(kw_only=True) +class Params(GetParamTransformersProtocol, ToStringMixin): + def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: + params = asdict(self) + for transformer in self._get_param_transformers(): + transformer.transform(params, data) + return params + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [] + + +@dataclass(kw_only=True) +class ParamsMixinSingleModel(GetParamTransformersProtocol): + optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the model's optimizer; if None, use default""" + lr: float = 1e-3 + """the learning rate to use in the gradient-based optimizer""" + lr_scheduler: LRSchedulerFactoryFactory | None = None + """factory for the creation of a learning rate scheduler""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ + ParamTransformerOptimFactory("optim", "lr", "lr_scheduler", "optim"), + ] + + +@dataclass(kw_only=True) +class ParamsMixinActorAndCritic(GetParamTransformersProtocol): + actor_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the actor's 
optimizer; if None, use default""" + critic_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the critic's optimizer; if None, use default""" + actor_lr: float = 1e-3 + """the learning rate to use for the actor network""" + critic_lr: float = 1e-3 + """the learning rate to use for the critic network""" + actor_lr_scheduler: LRSchedulerFactoryFactory | None = None + """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" + critic_lr_scheduler: LRSchedulerFactoryFactory | None = None + """factory for the creation of a learning rate scheduler to use for the critic network (if any)""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ + ParamTransformerOptimFactory( + "actor_optim", "actor_lr", "actor_lr_scheduler", "policy_optim" + ), + ParamTransformerOptimFactory( + "critic_optim", "critic_lr", "critic_lr_scheduler", "critic_optim" + ), + ] + + +@dataclass(kw_only=True) +class ParamsMixinActionScaling(GetParamTransformersProtocol): + action_scaling: bool | Literal["default"] = "default" + """ + flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ParamTransformerActionScaling("action_scaling")] + + +@dataclass(kw_only=True) +class ParamsMixinActionScalingAndBounding(ParamsMixinActionScaling): + action_bound_method: Literal["clip", "tanh"] | None = "clip" + """ + the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. + """ + + +@dataclass(kw_only=True) +class ParamsMixinExplorationNoise(GetParamTransformersProtocol): + exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None + """ + If not None, add noise to actions for exploration. + This is useful when solving "hard exploration" problems. + It can either be a distribution, a factory for the creation of a distribution or "default". 
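To make the documented `action_scaling` / `action_bound_method` semantics concrete, here is a small illustrative sketch of the underlying transformation (assumed behaviour, not the Tianshou implementation):

```python
import numpy as np


def bound_and_scale_action(
    raw_action: np.ndarray,
    low: np.ndarray,
    high: np.ndarray,
    bound_method: str | None = "clip",  # "clip", "tanh" or None
    scaling: bool = True,
) -> np.ndarray:
    act = raw_action.astype(np.float64)
    # 1) Bound the network output to [-1, 1].
    if bound_method == "clip":
        act = np.clip(act, -1.0, 1.0)
    elif bound_method == "tanh":
        act = np.tanh(act)
    # 2) Linearly map [-1, 1] to the environment's [low, high] range.
    if scaling:
        act = low + (high - low) * (act + 1.0) / 2.0
    return act


if __name__ == "__main__":
    low, high = np.array([-2.0]), np.array([2.0])
    print(bound_and_scale_action(np.array([0.5]), low, high))          # -> [1.0]
    print(bound_and_scale_action(np.array([3.0]), low, high, "tanh"))  # close to 2.0
```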
+ When set to "default", use Gaussian noise with standard deviation 0.1. + """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ParamTransformerNoiseFactory("exploration_noise")] + + +@dataclass(kw_only=True) +class ParamsMixinNStepReturnHorizon: + n_step_return_horizon: int = 1 + """ + the number of future steps (> 0) to consider when computing temporal difference (TD) targets. + Controls the balance between TD learning and Monte Carlo methods: + Higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). + A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very + large values approach Monte Carlo-like estimation that uses complete episode returns. + """ + + +@dataclass(kw_only=True) +class ParamsMixinGamma: + gamma: float = 0.99 + """ + the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + """ + + +@dataclass(kw_only=True) +class ParamsMixinTau: + tau: float = 0.005 + """ + the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + """ + + +@dataclass(kw_only=True) +class ParamsMixinDeterministicEval: + deterministic_eval: bool = False + """ + flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. 
+ """ + + +class OnPolicyAlgorithmParams( + Params, + ParamsMixinGamma, + ParamsMixinActionScalingAndBounding, + ParamsMixinSingleModel, + ParamsMixinDeterministicEval, +): + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) + transformers.extend(ParamsMixinSingleModel._get_param_transformers(self)) + return transformers + + +@dataclass(kw_only=True) +class ReinforceParams(OnPolicyAlgorithmParams): + return_standardization: bool = False + """ + whether to standardize episode returns by subtracting the running mean and + dividing by the running standard deviation. + Note that this is known to be detrimental to performance in many cases! + """ + + +@dataclass(kw_only=True) +class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol): + gae_lambda: float = 0.95 + """ + the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. + """ + max_batchsize: int = 256 + """the maximum number of samples to process at once when computing + generalized advantage estimation (GAE) and value function predictions. + Controls memory usage by breaking large batches into smaller chunks processed sequentially. + Higher values may increase speed but require more GPU/CPU memory; lower values + reduce memory requirements but may increase computation time. Should be adjusted + based on available hardware resources and total batch size of your training data.""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [] + + +@dataclass(kw_only=True) +class ActorCriticOnPolicyParams(OnPolicyAlgorithmParams): + return_scaling: bool = False + """ + flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. 
+ """ + + +@dataclass(kw_only=True) +class A2CParams(ActorCriticOnPolicyParams, ParamsMixinGeneralAdvantageEstimation): + vf_coef: float = 0.5 + """ + coefficient that weights the value loss relative to the actor loss in the overall + loss function. + Higher values prioritize accurate value function estimation over policy improvement. + Controls the trade-off between policy optimization and value function fitting. + Typically set between 0.5 and 1.0 for most actor-critic implementations. + """ + ent_coef: float = 0.01 + """ + coefficient that weights the entropy bonus relative to the actor loss. + Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. + Typically set between 0.01 and 0.05 for most actor-critic implementations. + """ + max_grad_norm: float | None = None + """ + the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. This prevents exploding gradients and stabilizes training by + limiting the magnitude of parameter updates. + Set to None to disable gradient clipping. + """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self)) + return transformers + + +@dataclass(kw_only=True) +class PPOParams(A2CParams): + eps_clip: float = 0.2 + """ + determines the range of allowed change in the policy during a policy update: + The ratio of action probabilities indicated by the new and old policy is + constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. + Small values thus force the new policy to stay close to the old policy. + Typical values range between 0.1 and 0.3, the value of 0.2 is recommended + in the original PPO paper. + The optimal value depends on the environment; more stochastic environments may + need larger values. + """ + dual_clip: float | None = None + """ + a clipping parameter (denoted as c in the literature) that prevents + excessive pessimism in policy updates for negative-advantage actions. + Excessive pessimism occurs when the policy update too strongly reduces the probability + of selecting actions that led to negative advantages, potentially eliminating useful + actions based on limited negative experiences. + When enabled (c > 1), the objective for negative advantages becomes: + max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) + is the original single-clipping objective determined by `eps_clip`. + This creates a floor on negative policy gradients, maintaining some probability + of exploring actions despite initial negative outcomes. + Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer + to 1.0 provide less protection against pessimistic updates. + Set to None to disable dual clipping. + """ + value_clip: bool = False + """ + flag indicating whether to enable clipping for value function updates. + When enabled, restricts how much the value function estimate can change from its + previous prediction, using the same clipping range as the policy updates (eps_clip). + This stabilizes training by preventing large fluctuations in value estimates, + particularly useful in environments with high reward variance. 
+ The clipped value loss uses a pessimistic approach, taking the maximum of the + original and clipped value errors: + max((returns - value)², (returns - v_clipped)²) + Setting to True often improves training stability but may slow convergence. + Implementation follows the approach mentioned in arXiv:1811.02553v3 Sec. 4.1. + """ + advantage_normalization: bool = True + """whether to apply per mini-batch advantage normalization.""" + recompute_advantage: bool = False + """ + whether to recompute advantage every update repeat as described in + https://arxiv.org/pdf/2006.05990.pdf, Sec. 3.5. + The original PPO implementation splits the data in each policy iteration + step into individual transitions and then randomly assigns them to minibatches. + This makes it impossible to compute advantages as the temporal structure is broken. + Therefore, the advantages are computed once at the beginning of each policy iteration step and + then used in minibatch policy and value function optimization. + This results in higher diversity of data in each minibatch at the cost of + using slightly stale advantage estimations. + Enabling this option will, as a remedy to this problem, recompute the advantages at the beginning + of each pass over the data instead of just once per iteration. + """ + + +@dataclass(kw_only=True) +class NPGParams(ActorCriticOnPolicyParams, ParamsMixinGeneralAdvantageEstimation): + optim_critic_iters: int = 5 + """ + the number of optimization steps performed on the critic network for each policy (actor) update. + Controls the learning rate balance between critic and actor. + Higher values prioritize critic accuracy by training the value function more + extensively before each policy update, which can improve stability but slow down + training. Lower values maintain a more even learning pace between policy and value + function but may lead to less reliable advantage estimates. + Typically set between 1 and 10, depending on the complexity of the value function. + """ + trust_region_size: float = 0.5 + """ + the parameter delta - a scalar multiplier for policy updates in the natural gradient direction. + The mathematical meaning is the trust region size, which is the maximum KL divergence + allowed between the old and new policy distributions. + Controls how far the policy parameters move in the calculated direction + during each update. Higher values allow for faster learning but may cause instability + or policy deterioration; lower values provide more stable but slower learning. Unlike + regular policy gradients, natural gradients already account for the local geometry of + the parameter space, making this step size more robust to different parameterizations. + Typically set between 0.1 and 1.0 for most reinforcement learning tasks. + """ + advantage_normalization: bool = True + """whether to do per mini-batch advantage normalization.""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self)) + return transformers + + +@dataclass(kw_only=True) +class TRPOParams(NPGParams): + max_kl: float = 0.01 + """ + maximum KL divergence, used to constrain each actor network update. + """ + backtrack_coeff: float = 0.8 + """ + coefficient with which to reduce the step size when constraints are not met. 
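The single- and dual-clipped PPO surrogate described by `eps_clip` and `dual_clip` can be written compactly; a sketch following the formula quoted above (illustrative only):

```python
import torch


def ppo_surrogate_sketch(
    log_prob_new: torch.Tensor,
    log_prob_old: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip: float = 0.2,
    dual_clip: float | None = None,
) -> torch.Tensor:
    """Negative of the (dual-)clipped PPO objective, to be minimized."""
    ratio = (log_prob_new - log_prob_old).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * advantages
    clipped = torch.min(surr1, surr2)
    if dual_clip is not None:
        # For negative advantages, floor the objective at dual_clip * A (dual_clip > 1).
        clipped = torch.where(
            advantages < 0, torch.max(clipped, dual_clip * advantages), clipped
        )
    return -clipped.mean()


if __name__ == "__main__":
    lp_new, lp_old, adv = torch.randn(16), torch.randn(16), torch.randn(16)
    print(ppo_surrogate_sketch(lp_new, lp_old, adv))
    print(ppo_surrogate_sketch(lp_new, lp_old, adv, dual_clip=3.0))
```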
+ """ + max_backtracks: int = 10 + """maximum number of times to backtrack in line search when the constraints are not met.""" + + +@dataclass(kw_only=True) +class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol): + actor_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the actor's optimizer; if None, use default""" + critic1_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the first critic's optimizer; if None, use default""" + critic2_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the second critic's optimizer; if None, use default""" + actor_lr: float = 1e-3 + """the learning rate to use for the actor network""" + critic1_lr: float = 1e-3 + """the learning rate to use for the first critic network""" + critic2_lr: float = 1e-3 + """the learning rate to use for the second critic network""" + actor_lr_scheduler: LRSchedulerFactoryFactory | None = None + """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" + critic1_lr_scheduler: LRSchedulerFactoryFactory | None = None + """factory for the creation of a learning rate scheduler to use for the first critic network (if any)""" + critic2_lr_scheduler: LRSchedulerFactoryFactory | None = None + """factory for the creation of a learning rate scheduler to use for the second critic network (if any)""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ + ParamTransformerOptimFactory( + "actor_optim", "actor_lr", "actor_lr_scheduler", "policy_optim" + ), + ParamTransformerOptimFactory( + "critic1_optim", "critic1_lr", "critic1_lr_scheduler", "critic_optim" + ), + ParamTransformerOptimFactory( + "critic2_optim", "critic2_lr", "critic2_lr_scheduler", "critic2_optim" + ), + ] + + +@dataclass(kw_only=True) +class ParamsMixinAlpha(GetParamTransformersProtocol): + alpha: float | AutoAlphaFactory = 0.2 + """ + the entropy regularization coefficient, which balances exploration and exploitation. + This coefficient controls how much the agent values randomness in its policy versus + pursuing higher rewards. + Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent + for maintaining diverse action choices, even if this means selecting some lower-value actions. + Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become + more focused on the highest-value actions. + A value of 0 would completely remove entropy regularization, potentially leading to + premature convergence to suboptimal deterministic policies. + Can be provided as a fixed float (0.2 is a reasonable default) or via a factory + to support automatic tuning during training. 
+ """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ParamTransformerAutoAlpha("alpha")] + + +@dataclass(kw_only=True) +class _SACParams( + Params, + ParamsMixinGamma, + ParamsMixinActorAndDualCritics, + ParamsMixinNStepReturnHorizon, + ParamsMixinTau, + ParamsMixinDeterministicEval, + ParamsMixinAlpha, +): + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) + transformers.extend(ParamsMixinAlpha._get_param_transformers(self)) + return transformers + + +@dataclass(kw_only=True) +class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScaling): + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) + transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) + return transformers + + +@dataclass(kw_only=True) +class DiscreteSACParams(_SACParams): + pass + + +@dataclass(kw_only=True) +class QLearningOffPolicyParams( + Params, ParamsMixinGamma, ParamsMixinSingleModel, ParamsMixinNStepReturnHorizon +): + target_update_freq: int = 0 + """ + the number of training iterations between each complete update of the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on + environment complexity. + """ + eps_training: float = 0.0 + """ + the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + eps_inference: float = 0.0 + """ + the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinSingleModel._get_param_transformers(self)) + return transformers + + +@dataclass(kw_only=True) +class DQNParams(QLearningOffPolicyParams): + is_double: bool = True + """ + flag indicating whether to use the Double DQN algorithm for target value computation. + If True, the algorithm uses the online network to select actions and the target network to + evaluate their Q-values. This approach helps reduce the overestimation bias in Q-learning + by decoupling action selection from action evaluation. + If False, the algorithm follows the vanilla DQN method that directly takes the maximum Q-value + from the target network. 
+    Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0).
+    """
+    huber_loss_delta: float | None = None
+    """
+    controls whether to use the Huber loss instead of the MSE loss for the TD error and, if so, serves as
+    the threshold (delta) of the Huber loss.
+    If None, the MSE loss is used.
+    If not None, uses the Huber loss as described in the Nature DQN paper (nature14236) with the given delta,
+    which limits the influence of outliers.
+    Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber
+    loss causes the gradients to plateau at a constant value for large errors, providing more stable training.
+    NOTE: The magnitude of delta should depend on the scale of the returns obtained in the environment.
+    """
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        return super()._get_param_transformers()
+
+
+@dataclass(kw_only=True)
+class IQNParams(QLearningOffPolicyParams):
+    sample_size: int = 32
+    """the number of samples for policy evaluation"""
+    online_sample_size: int = 8
+    """the number of samples for the online model during training"""
+    target_sample_size: int = 8
+    """the number of samples for the target model during training."""
+    num_quantiles: int = 200
+    """the number of quantile midpoints in the inverse cumulative distribution function of the value"""
+    hidden_sizes: Sequence[int] = ()
+    """hidden dimensions to use in the IQN network"""
+    num_cosines: int = 64
+    """number of cosines to use in the IQN network"""
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.append(ParamTransformerDrop("hidden_sizes", "num_cosines"))
+        return transformers
+
+
+@dataclass(kw_only=True)
+class DDPGParams(
+    Params,
+    ParamsMixinGamma,
+    ParamsMixinActorAndCritic,
+    ParamsMixinExplorationNoise,
+    ParamsMixinActionScalingAndBounding,
+    ParamsMixinNStepReturnHorizon,
+    ParamsMixinTau,
+):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
+        transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self))
+        transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self))
+        return transformers
+
+
+@dataclass(kw_only=True)
+class REDQParams(DDPGParams, ParamsMixinDeterministicEval, ParamsMixinAlpha):
+    ensemble_size: int = 10
+    """
+    the total number of critic networks in the ensemble.
+    This parameter implements the randomized ensemble approach described in REDQ.
+    The algorithm maintains `ensemble_size` different critic networks that all share the same architecture.
+    During target value computation, a random subset of these networks (determined by `subset_size`) is used.
+    Larger values increase the diversity of the ensemble but require more memory and computation.
+    The original paper recommends a value of 10 for most tasks, balancing performance and computational efficiency.
+    """
+    subset_size: int = 2
+    """
+    the number of critic networks randomly selected from the ensemble for computing target Q-values.
+    During each update, the algorithm samples `subset_size` networks from the ensemble of
+    `ensemble_size` networks without replacement.
+    The target Q-value is then calculated as either the minimum or mean (based on target_mode)
+    of the predictions from this subset.
+    Smaller values increase randomization and sample efficiency but may introduce more variance.
+    Larger values provide more stable estimates but reduce the benefits of randomization.
+    The REDQ paper recommends a value of 2 for optimal sample efficiency.
+    Must satisfy 0 < subset_size <= ensemble_size.
+    """
+    actor_delay: int = 20
+    """
+    the number of critic updates performed before each actor update.
+    The actor network is only updated once for every actor_delay critic updates, implementing
+    a delayed policy update strategy similar to TD3.
+    Larger values stabilize training by allowing critics to become more accurate before policy updates.
+    Smaller values allow the policy to adapt more quickly but may lead to less stable learning.
+    The REDQ paper recommends a value of 20 for most tasks.
+    """
+    target_mode: Literal["mean", "min"] = "min"
+    """
+    the method used to aggregate Q-values from the subset of critic networks.
+    Can be either "min" or "mean".
+    If "min", uses the minimum Q-value across the selected subset of critics for each state-action pair.
+    If "mean", uses the average Q-value across the selected subset of critics.
+    Using "min" helps prevent overestimation bias but may lead to more conservative value estimates.
+    Using "mean" provides more optimistic value estimates but may suffer from overestimation bias.
+    Default is "min" following the conservative value estimation approach common in recent Q-learning
+    algorithms.
+    """
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinAlpha._get_param_transformers(self))
+        return transformers
+
+
+@dataclass(kw_only=True)
+class TD3Params(
+    Params,
+    ParamsMixinGamma,
+    ParamsMixinActorAndDualCritics,
+    ParamsMixinExplorationNoise,
+    ParamsMixinActionScalingAndBounding,
+    ParamsMixinNStepReturnHorizon,
+    ParamsMixinTau,
+):
+    policy_noise: float | FloatEnvValueFactory = 0.2
+    """
+    scaling factor for the Gaussian noise added to target policy actions.
+    This parameter implements target policy smoothing, a regularization technique described in the TD3 paper.
+    The noise is sampled from a normal distribution and multiplied by this value before being added to actions.
+    Higher values increase exploration in the target policy, helping to address function approximation error.
+    The added noise is optionally clipped to a range determined by the noise_clip parameter.
+    Typically set between 0.1 and 0.5 relative to the action scale of the environment.
+    """
+    noise_clip: float | FloatEnvValueFactory = 0.5
+    """
+    defines the maximum absolute value of the noise added to target policy actions, i.e. noise values
+    are clipped to the range [-noise_clip, noise_clip] (after generating and scaling the noise
+    via `policy_noise`).
+    This parameter implements bounded target policy smoothing as described in the TD3 paper.
+    It prevents extreme noise values from causing unrealistic target values during training.
+    Setting it to 0.0 (or a negative value) disables clipping entirely.
+    It is typically set to about twice the `policy_noise` value (e.g. 0.5 when `policy_noise` is 0.2).
+    """
+    update_actor_freq: int = 2
+    """
+    the frequency of actor network updates relative to critic network updates
+    (the actor network is only updated once for every `update_actor_freq` critic updates).
+    This implements the "delayed" policy updates from the TD3 algorithm, where the actor is
+    updated less frequently than the critics.
+    Higher values (e.g., 2-5) help stabilize training by allowing the critic to become more
+    accurate before updating the policy.
+ The default value of 2 follows the original TD3 paper's recommendation of updating the + policy at half the rate of the Q-functions. + """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) + transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) + transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) + transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise")) + transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip")) + return transformers diff --git a/tianshou/highlevel/params/algorithm_wrapper.py b/tianshou/highlevel/params/algorithm_wrapper.py new file mode 100644 index 000000000..a5c287fd4 --- /dev/null +++ b/tianshou/highlevel/params/algorithm_wrapper.py @@ -0,0 +1,92 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Generic, TypeVar + +from sensai.util.string import ToStringMixin + +from tianshou.algorithm import Algorithm, ICMOffPolicyWrapper +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, OnPolicyAlgorithm +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.core import TDevice +from tianshou.highlevel.module.intermediate import IntermediateModuleFactory +from tianshou.highlevel.params.optim import OptimizerFactoryFactory +from tianshou.utils.net.discrete import IntrinsicCuriosityModule + +TAlgorithmOut = TypeVar("TAlgorithmOut", bound=Algorithm) + + +class AlgorithmWrapperFactory(Generic[TAlgorithmOut], ToStringMixin, ABC): + @abstractmethod + def create_wrapped_algorithm( + self, + policy: Algorithm, + envs: Environments, + optim_factory: OptimizerFactoryFactory, + device: TDevice, + ) -> TAlgorithmOut: + pass + + +class AlgorithmWrapperFactoryIntrinsicCuriosity( + AlgorithmWrapperFactory[ICMOffPolicyWrapper | ICMOnPolicyWrapper], +): + def __init__( + self, + *, + feature_net_factory: IntermediateModuleFactory, + hidden_sizes: Sequence[int], + lr: float, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + optim: OptimizerFactoryFactory | None = None, + ): + self.feature_net_factory = feature_net_factory + self.hidden_sizes = hidden_sizes + self.lr = lr + self.lr_scale = lr_scale + self.reward_scale = reward_scale + self.forward_loss_weight = forward_loss_weight + self.optim_factory = optim + + def create_wrapped_algorithm( + self, + algorithm: Algorithm, + envs: Environments, + optim_factory_default: OptimizerFactoryFactory, + device: TDevice, + ) -> ICMOffPolicyWrapper | ICMOnPolicyWrapper: + feature_net = self.feature_net_factory.create_intermediate_module(envs, device) + action_dim = envs.get_action_shape() + if not isinstance(action_dim, int): + raise ValueError(f"Environment action shape must be an integer, got {action_dim}") + feature_dim = feature_net.output_dim + icm_net = IntrinsicCuriosityModule( + feature_net=feature_net.module, + feature_dim=feature_dim, + action_dim=action_dim, + hidden_sizes=self.hidden_sizes, + ) + optim_factory = self.optim_factory or optim_factory_default + icm_optim = optim_factory.create_optimizer_factory(lr=self.lr) + if isinstance(algorithm, OffPolicyAlgorithm): + return ICMOffPolicyWrapper( + wrapped_algorithm=algorithm, + model=icm_net, + optim=icm_optim, + lr_scale=self.lr_scale, + reward_scale=self.reward_scale, + 
forward_loss_weight=self.forward_loss_weight, + ).to(device) + elif isinstance(algorithm, OnPolicyAlgorithm): + return ICMOnPolicyWrapper( + wrapped_algorithm=algorithm, + model=icm_net, + optim=icm_optim, + lr_scale=self.lr_scale, + reward_scale=self.reward_scale, + forward_loss_weight=self.forward_loss_weight, + ).to(device) + else: + raise ValueError(f"{algorithm} is not supported by ICM") diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 1c5d60438..61c86cf24 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod import numpy as np -import torch from sensai.util.string import ToStringMixin +from tianshou.algorithm.modelfree.sac import Alpha, AutoAlpha from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice -from tianshou.highlevel.optim import OptimizerFactory +from tianshou.highlevel.params.optim import OptimizerFactoryFactory class AutoAlphaFactory(ToStringMixin, ABC): @@ -14,14 +14,19 @@ class AutoAlphaFactory(ToStringMixin, ABC): def create_auto_alpha( self, envs: Environments, - optim_factory: OptimizerFactory, device: TDevice, - ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]: + ) -> Alpha: pass class AutoAlphaFactoryDefault(AutoAlphaFactory): - def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0): + def __init__( + self, + lr: float = 3e-4, + target_entropy_coefficient: float = -1.0, + log_alpha: float = 0.0, + optim: OptimizerFactoryFactory | None = None, + ) -> None: """ :param lr: the learning rate for the optimizer of the alpha parameter :param target_entropy_coefficient: the coefficient with which to multiply the target entropy; @@ -30,21 +35,23 @@ def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0): spaces respectively, which gives a reasonable trade-off between exploration and exploitation. For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98); 1.0 would give full entropy exploration. + :param log_alpha: the (initial) value of the log of the entropy regularization coefficient alpha. 
+ :param optim: the optimizer factory to use; if None, use default """ self.lr = lr self.target_entropy_coefficient = target_entropy_coefficient + self.log_alpha = log_alpha + self.optimizer_factory_factory = optim or OptimizerFactoryFactory.default() def create_auto_alpha( self, envs: Environments, - optim_factory: OptimizerFactory, device: TDevice, - ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]: + ) -> AutoAlpha: action_dim = np.prod(envs.get_action_shape()) if envs.get_type().is_continuous(): target_entropy = self.target_entropy_coefficient * float(action_dim) else: target_entropy = self.target_entropy_coefficient * np.log(action_dim) - log_alpha = torch.zeros(1, requires_grad=True, device=device) - alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr) - return target_entropy, log_alpha, alpha_optim + optim_factory = self.optimizer_factory_factory.create_optimizer_factory(lr=self.lr) + return AutoAlpha(target_entropy, self.log_alpha, optim_factory) diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py index 6cb436185..d75096bc3 100644 --- a/tianshou/highlevel/params/dist_fn.py +++ b/tianshou/highlevel/params/dist_fn.py @@ -5,8 +5,8 @@ import torch from sensai.util.string import ToStringMixin +from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrete, TDistFnDiscrOrCont from tianshou.highlevel.env import Environments -from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont class DistributionFunctionFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 09c4c4261..98f90c26c 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -1,35 +1,34 @@ from abc import ABC, abstractmethod -import numpy as np -import torch from sensai.util.string import ToStringMixin -from torch.optim.lr_scheduler import LambdaLR, LRScheduler -from tianshou.highlevel.config import SamplingConfig +from tianshou.algorithm.optim import LRSchedulerFactory, LRSchedulerFactoryLinear +from tianshou.highlevel.config import TrainingConfig -class LRSchedulerFactory(ToStringMixin, ABC): - """Factory for the creation of a learning rate scheduler.""" +class LRSchedulerFactoryFactory(ToStringMixin, ABC): + """Factory for the creation of a learning rate scheduler factory.""" @abstractmethod - def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: + def create_lr_scheduler_factory(self) -> LRSchedulerFactory: pass -class LRSchedulerFactoryLinear(LRSchedulerFactory): - def __init__(self, sampling_config: SamplingConfig): - self.sampling_config = sampling_config +class LRSchedulerFactoryFactoryLinear(LRSchedulerFactoryFactory): + def __init__(self, training_config: TrainingConfig): + self.training_config = training_config - def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: - return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute) - - class _LRLambda: - def __init__(self, sampling_config: SamplingConfig): - assert sampling_config.step_per_collect is not None - self.max_update_num = ( - np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect) - * sampling_config.num_epochs + def create_lr_scheduler_factory(self) -> LRSchedulerFactory: + if ( + self.training_config.epoch_num_steps is None + or self.training_config.collection_step_num_env_steps is None + ): + raise ValueError( + f"{self.__class__.__name__} requires epoch_num_steps 
and collection_step_num_env_steps to be set " + f"in order for the scheduling to be well-defined." ) - - def compute(self, epoch: int) -> float: - return 1.0 - epoch / self.max_update_num + return LRSchedulerFactoryLinear( + max_epochs=self.training_config.max_epochs, + epoch_num_steps=self.training_config.epoch_num_steps, + collection_step_num_env_steps=self.training_config.collection_step_num_env_steps, + ) diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/params/optim.py similarity index 63% rename from tianshou/highlevel/optim.py rename to tianshou/highlevel/params/optim.py index d480978fb..c80a93331 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/params/optim.py @@ -4,7 +4,13 @@ import torch from sensai.util.string import ToStringMixin -from torch.optim import Adam, RMSprop + +from tianshou.algorithm.optim import ( + AdamOptimizerFactory, + OptimizerFactory, + RMSpropOptimizerFactory, + TorchOptimizerFactory, +) TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] @@ -14,20 +20,17 @@ def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Opt pass -class OptimizerFactory(ABC, ToStringMixin): - def create_optimizer( - self, - module: torch.nn.Module, - lr: float, - ) -> torch.optim.Optimizer: - return self.create_optimizer_for_params(module.parameters(), lr) +class OptimizerFactoryFactory(ABC, ToStringMixin): + @staticmethod + def default() -> "OptimizerFactoryFactory": + return OptimizerFactoryFactoryAdam() @abstractmethod - def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: + def create_optimizer_factory(self, lr: float) -> OptimizerFactory: pass -class OptimizerFactoryTorch(OptimizerFactory): +class OptimizerFactoryFactoryTorch(OptimizerFactoryFactory): def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any): """Factory for torch optimizers. 
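# ----------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the patch): how the renamed
# "factory-factory" API above is intended to be used. Only the names introduced
# in this diff (OptimizerFactoryFactoryAdam, create_optimizer_factory,
# AdamOptimizerFactory, module path tianshou.highlevel.params.optim) are taken
# from the patch; the surrounding usage and the chosen hyperparameters are
# assumptions for illustration only.
from tianshou.highlevel.params.optim import OptimizerFactoryFactoryAdam

# The high-level API holds a factory-factory, which does not yet know the
# learning rate ...
optim_ff = OptimizerFactoryFactoryAdam(betas=(0.9, 0.999))

# ... and only creates the low-level OptimizerFactory (defined in
# tianshou.algorithm.optim) once the learning rate is known:
adam_factory = optim_ff.create_optimizer_factory(lr=3e-4)  # an AdamOptimizerFactory

# The resulting factory would then be handed on to algorithm-level code, which
# builds the actual torch optimizer once the parameters to optimize are known.
# ----------------------------------------------------------------------------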
@@ -39,13 +42,11 @@ def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any self.optim_class = optim_class self.kwargs = kwargs - def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: - return self.optim_class(params, lr=lr, **self.kwargs) + def create_optimizer_factory(self, lr: float) -> OptimizerFactory: + return TorchOptimizerFactory(optim_class=self.optim_class, lr=lr) -class OptimizerFactoryAdam(OptimizerFactory): - # Note: currently used as default optimizer - # values should be kept in sync with `ExperimentBuilder.with_optim_factory_default` +class OptimizerFactoryFactoryAdam(OptimizerFactoryFactory): def __init__( self, betas: tuple[float, float] = (0.9, 0.999), @@ -56,9 +57,8 @@ def __init__( self.eps = eps self.betas = betas - def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: - return Adam( - params, + def create_optimizer_factory(self, lr: float) -> AdamOptimizerFactory: + return AdamOptimizerFactory( lr=lr, betas=self.betas, eps=self.eps, @@ -66,7 +66,7 @@ def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim ) -class OptimizerFactoryRMSprop(OptimizerFactory): +class OptimizerFactoryFactoryRMSprop(OptimizerFactoryFactory): def __init__( self, alpha: float = 0.99, @@ -81,9 +81,8 @@ def __init__( self.weight_decay = weight_decay self.eps = eps - def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: - return RMSprop( - params, + def create_optimizer_factory(self, lr: float) -> RMSpropOptimizerFactory: + return RMSpropOptimizerFactory( lr=lr, alpha=self.alpha, eps=self.eps, diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py deleted file mode 100644 index d20bbe44b..000000000 --- a/tianshou/highlevel/params/policy_params.py +++ /dev/null @@ -1,636 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Sequence -from dataclasses import asdict, dataclass -from typing import Any, Literal, Protocol - -import torch -from sensai.util.pickle import setstate -from sensai.util.string import ToStringMixin -from torch.optim.lr_scheduler import LRScheduler - -from tianshou.exploration import BaseNoise -from tianshou.highlevel.env import Environments -from tianshou.highlevel.module.core import TDevice -from tianshou.highlevel.module.module_opt import ModuleOpt -from tianshou.highlevel.optim import OptimizerFactory -from tianshou.highlevel.params.alpha import AutoAlphaFactory -from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory -from tianshou.highlevel.params.noise import NoiseFactory -from tianshou.utils import MultipleLRSchedulers - - -@dataclass(kw_only=True) -class ParamTransformerData: - """Holds data that can be used by `ParamTransformer` instances to perform their transformation. - - The representation contains the superset of all data items that are required by different types of agent factories. - An agent factory is expected to set only the attributes that are relevant to its parameters. 
- """ - - envs: Environments - device: TDevice - optim_factory: OptimizerFactory - optim: torch.optim.Optimizer | None = None - """the single optimizer for the case where there is just one""" - actor: ModuleOpt | None = None - critic1: ModuleOpt | None = None - critic2: ModuleOpt | None = None - - -class ParamTransformer(ABC): - """Base class for parameter transformations from high to low-level API. - - Transforms one or more parameters from the representation used by the high-level API - to the representation required by the (low-level) policy implementation. - It operates directly on a dictionary of keyword arguments, which is initially - generated from the parameter dataclass (subclass of `Params`). - """ - - @abstractmethod - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - pass - - @staticmethod - def get(d: dict[str, Any], key: str, drop: bool = False) -> Any: - value = d[key] - if drop: - del d[key] - return value - - -class ParamTransformerDrop(ParamTransformer): - def __init__(self, *keys: str): - self.keys = keys - - def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: - for k in self.keys: - del kwargs[k] - - -class ParamTransformerChangeValue(ParamTransformer): - def __init__(self, key: str): - self.key = key - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - params[self.key] = self.change_value(params[self.key], data) - - @abstractmethod - def change_value(self, value: Any, data: ParamTransformerData) -> Any: - pass - - -class ParamTransformerLRScheduler(ParamTransformer): - """Transformer for learning rate scheduler params. - - Transforms a key containing a learning rate scheduler factory (removed) into a key containing - a learning rate scheduler (added) for the data member `optim`. - """ - - def __init__(self, key_scheduler_factory: str, key_scheduler: str): - self.key_scheduler_factory = key_scheduler_factory - self.key_scheduler = key_scheduler - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - assert data.optim is not None - factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True) - params[self.key_scheduler] = ( - factory.create_scheduler(data.optim) if factory is not None else None - ) - - -class ParamTransformerMultiLRScheduler(ParamTransformer): - def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str): - """Transforms several scheduler factories into a single scheduler. - - The result may be a `MultipleLRSchedulers` instance if more than one factory is indeed given. 
- - :param optim_key_list: a list of tuples (optimizer, key of learning rate factory) - :param key_scheduler: the key under which to store the resulting learning rate scheduler - """ - self.optim_key_list = optim_key_list - self.key_scheduler = key_scheduler - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - lr_schedulers = [] - for optim, lr_scheduler_factory_key in self.optim_key_list: - lr_scheduler_factory: LRSchedulerFactory | None = self.get( - params, - lr_scheduler_factory_key, - drop=True, - ) - if lr_scheduler_factory is not None: - lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim)) - lr_scheduler: LRScheduler | MultipleLRSchedulers | None - match len(lr_schedulers): - case 0: - lr_scheduler = None - case 1: - lr_scheduler = lr_schedulers[0] - case _: - lr_scheduler = MultipleLRSchedulers(*lr_schedulers) - params[self.key_scheduler] = lr_scheduler - - -class ParamTransformerActorAndCriticLRScheduler(ParamTransformer): - def __init__( - self, - key_scheduler_factory_actor: str, - key_scheduler_factory_critic: str, - key_scheduler: str, - ): - self.key_factory_actor = key_scheduler_factory_actor - self.key_factory_critic = key_scheduler_factory_critic - self.key_scheduler = key_scheduler - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - assert data.actor is not None and data.critic1 is not None - transformer = ParamTransformerMultiLRScheduler( - [ - (data.actor.optim, self.key_factory_actor), - (data.critic1.optim, self.key_factory_critic), - ], - self.key_scheduler, - ) - transformer.transform(params, data) - - -class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer): - def __init__( - self, - key_scheduler_factory_actor: str, - key_scheduler_factory_critic1: str, - key_scheduler_factory_critic2: str, - key_scheduler: str, - ): - self.key_factory_actor = key_scheduler_factory_actor - self.key_factory_critic1 = key_scheduler_factory_critic1 - self.key_factory_critic2 = key_scheduler_factory_critic2 - self.key_scheduler = key_scheduler - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - assert data.actor is not None and data.critic1 is not None and data.critic2 is not None - transformer = ParamTransformerMultiLRScheduler( - [ - (data.actor.optim, self.key_factory_actor), - (data.critic1.optim, self.key_factory_critic1), - (data.critic2.optim, self.key_factory_critic2), - ], - self.key_scheduler, - ) - transformer.transform(params, data) - - -class ParamTransformerAutoAlpha(ParamTransformer): - def __init__(self, key: str): - self.key = key - - def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: - alpha = self.get(kwargs, self.key) - if isinstance(alpha, AutoAlphaFactory): - kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device) - - -class ParamTransformerNoiseFactory(ParamTransformerChangeValue): - def change_value(self, value: Any, data: ParamTransformerData) -> Any: - if isinstance(value, NoiseFactory): - value = value.create_noise(data.envs) - return value - - -class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue): - def change_value(self, value: Any, data: ParamTransformerData) -> Any: - if isinstance(value, EnvValueFactory): - value = value.create_value(data.envs) - return value - - -class ParamTransformerActionScaling(ParamTransformerChangeValue): - def change_value(self, value: Any, data: ParamTransformerData) -> Any: - if value == "default": - return 
data.envs.get_type().is_continuous() - else: - return value - - -class GetParamTransformersProtocol(Protocol): - def _get_param_transformers(self) -> list[ParamTransformer]: - pass - - -@dataclass -class Params(GetParamTransformersProtocol, ToStringMixin): - def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: - params = asdict(self) - for transformer in self._get_param_transformers(): - transformer.transform(params, data) - return params - - def _get_param_transformers(self) -> list[ParamTransformer]: - return [] - - -@dataclass -class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol): - lr: float = 1e-3 - """the learning rate to use in the gradient-based optimizer""" - lr_scheduler_factory: LRSchedulerFactory | None = None - """factory for the creation of a learning rate scheduler""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - return [ - ParamTransformerDrop("lr"), - ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"), - ] - - -@dataclass -class ParamsMixinActorAndCritic(GetParamTransformersProtocol): - actor_lr: float = 1e-3 - """the learning rate to use for the actor network""" - critic_lr: float = 1e-3 - """the learning rate to use for the critic network""" - actor_lr_scheduler_factory: LRSchedulerFactory | None = None - """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" - critic_lr_scheduler_factory: LRSchedulerFactory | None = None - """factory for the creation of a learning rate scheduler to use for the critic network (if any)""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - return [ - ParamTransformerDrop("actor_lr", "critic_lr"), - ParamTransformerActorAndCriticLRScheduler( - "actor_lr_scheduler_factory", - "critic_lr_scheduler_factory", - "lr_scheduler", - ), - ] - - -@dataclass -class ParamsMixinActionScaling(GetParamTransformersProtocol): - action_scaling: bool | Literal["default"] = "default" - """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces""" - action_bound_method: Literal["clip", "tanh"] | None = "clip" - """ - method to bound action to range [-1, 1]. Only used if the action_space is continuous. - """ - - def _get_param_transformers(self) -> list[ParamTransformer]: - return [ParamTransformerActionScaling("action_scaling")] - - -@dataclass -class ParamsMixinExplorationNoise(GetParamTransformersProtocol): - exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None - """ - If not None, add noise to actions for exploration. - This is useful when solving "hard exploration" problems. - It can either be a distribution, a factory for the creation of a distribution or "default". - When set to "default", use Gaussian noise with standard deviation 0.1. - """ - - def _get_param_transformers(self) -> list[ParamTransformer]: - return [ParamTransformerNoiseFactory("exploration_noise")] - - -@dataclass -class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithScheduler): - discount_factor: float = 0.99 - """ - discount factor (gamma) for future rewards; must be in [0, 1] - """ - reward_normalization: bool = False - """ - if True, will normalize the returns by subtracting the running mean and dividing by the running - standard deviation. - """ - deterministic_eval: bool = False - """ - whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. - Does not affect training. 
- """ - - def __setstate__(self, state: dict[str, Any]) -> None: - setstate(PGParams, self, state, removed_properties=["dist_fn"]) - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) - transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) - return transformers - - -@dataclass -class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol): - gae_lambda: float = 0.95 - """ - determines the blend between Monte Carlo and one-step temporal difference (TD) estimates of the advantage - function in general advantage estimation (GAE). - A value of 0 gives a fully TD-based estimate; lambda=1 gives a fully Monte Carlo estimate. - """ - max_batchsize: int = 256 - """the maximum size of the batch when computing general advantage estimation (GAE)""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - return [] - - -@dataclass -class A2CParams(PGParams, ParamsMixinGeneralAdvantageEstimation): - vf_coef: float = 0.5 - """weight (coefficient) of the value loss in the loss function""" - ent_coef: float = 0.01 - """weight (coefficient) of the entropy loss in the loss function""" - max_grad_norm: float | None = None - """maximum norm for clipping gradients in backpropagation""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self)) - return transformers - - -@dataclass -class PPOParams(A2CParams): - eps_clip: float = 0.2 - """ - determines the range of allowed change in the policy during a policy update: - The ratio between the probabilities indicated by the new and old policy is - constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. - Small values thus force the new policy to stay close to the old policy. - Typical values range between 0.1 and 0.3. - The optimal epsilon depends on the environment; more stochastic environments may need larger epsilons. - """ - dual_clip: float | None = None - """ - determines the lower bound clipping for the probability ratio - (corresponds to parameter c in arXiv:1912.09729, Equation 5). - If set to None, dual clipping is not used and the bounds described in parameter eps_clip apply. - If set to a float value c, the lower bound is changed from 1 - eps_clip to c, - where c < 1 - eps_clip. - Setting c > 0 reduces policy oscillation and further stabilizes training. - Typical values are between 0 and 0.5. Smaller values provide more stability. - Setting c = 0 yields PPO with only the upper bound. - """ - value_clip: bool = False - """ - whether to apply clipping of the predicted value function during policy learning. - Value clipping discourages large changes in value predictions between updates. - Inaccurate value predictions can lead to bad policy updates, which can cause training instability. - Clipping values prevents sporadic large errors from skewing policy updates too much. - """ - advantage_normalization: bool = True - """whether to apply per mini-batch advantage normalization.""" - recompute_advantage: bool = False - """ - whether to recompute advantage every update repeat as described in - https://arxiv.org/pdf/2006.05990.pdf, Sec. 3.5. - The original PPO implementation splits the data in each policy iteration - step into individual transitions and then randomly assigns them to minibatches. 
- This makes it impossible to compute advantages as the temporal structure is broken. - Therefore, the advantages are computed once at the beginning of each policy iteration step and - then used in minibatch policy and value function optimization. - This results in higher diversity of data in each minibatch at the cost of - using slightly stale advantage estimations. - Enabling this option will, as a remedy to this problem, recompute the advantages at the beginning - of each pass over the data instead of just once per iteration. - """ - - -@dataclass -class NPGParams(PGParams, ParamsMixinGeneralAdvantageEstimation): - optim_critic_iters: int = 5 - """number of times to optimize critic network per update.""" - actor_step_size: float = 0.5 - """step size for actor update in natural gradient direction""" - advantage_normalization: bool = True - """whether to do per mini-batch advantage normalization.""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self)) - return transformers - - -@dataclass -class TRPOParams(NPGParams): - max_kl: float = 0.01 - """ - maximum KL divergence, used to constrain each actor network update. - """ - backtrack_coeff: float = 0.8 - """ - coefficient with which to reduce the step size when constraints are not met. - """ - max_backtracks: int = 10 - """maximum number of times to backtrack in line search when the constraints are not met.""" - - -@dataclass -class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol): - actor_lr: float = 1e-3 - """the learning rate to use for the actor network""" - critic1_lr: float = 1e-3 - """the learning rate to use for the first critic network""" - critic2_lr: float = 1e-3 - """the learning rate to use for the second critic network""" - actor_lr_scheduler_factory: LRSchedulerFactory | None = None - """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" - critic1_lr_scheduler_factory: LRSchedulerFactory | None = None - """factory for the creation of a learning rate scheduler to use for the first critic network (if any)""" - critic2_lr_scheduler_factory: LRSchedulerFactory | None = None - """factory for the creation of a learning rate scheduler to use for the second critic network (if any)""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - return [ - ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"), - ParamTransformerActorDualCriticsLRScheduler( - "actor_lr_scheduler_factory", - "critic1_lr_scheduler_factory", - "critic2_lr_scheduler_factory", - "lr_scheduler", - ), - ] - - -@dataclass -class _SACParams(Params, ParamsMixinActorAndDualCritics): - tau: float = 0.005 - """controls the contribution of the entropy term in the overall optimization objective, - i.e. the desired amount of randomness in the optimal policy. - Higher values mean greater target entropy and therefore more randomness in the policy. - Lower values mean lower target entropy and therefore a more deterministic policy. - """ - gamma: float = 0.99 - """discount factor (gamma) for future rewards; must be in [0, 1]""" - alpha: float | AutoAlphaFactory = 0.2 - """ - controls the relative importance (coefficient) of the entropy term in the loss function. 
- This can be a constant or a factory for the creation of a representation that allows the - parameter to be automatically tuned; - use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard - auto-adjusted alpha. - """ - estimation_step: int = 1 - """the number of steps to look ahead""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) - transformers.append(ParamTransformerAutoAlpha("alpha")) - return transformers - - -@dataclass -class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScaling): - deterministic_eval: bool = True - """ - whether to use deterministic action (mean of Gaussian policy) in evaluation mode instead of stochastic - action sampled by the policy. Does not affect training.""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) - transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) - return transformers - - -@dataclass -class DiscreteSACParams(_SACParams): - pass - - -@dataclass -class DQNParams(Params, ParamsMixinLearningRateWithScheduler): - discount_factor: float = 0.99 - """ - discount factor (gamma) for future rewards; must be in [0, 1] - """ - estimation_step: int = 1 - """the number of steps to look ahead""" - target_update_freq: int = 0 - """the target network update frequency (0 if no target network is to be used)""" - reward_normalization: bool = False - """whether to normalize the returns to Normal(0, 1)""" - is_double: bool = True - """whether to use double Q learning""" - clip_loss_grad: bool = False - """whether to clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber - loss instead of the MSE loss.""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) - return transformers - - -@dataclass -class IQNParams(DQNParams): - sample_size: int = 32 - """the number of samples for policy evaluation""" - online_sample_size: int = 8 - """the number of samples for online model in training""" - target_sample_size: int = 8 - """the number of samples for target model in training.""" - num_quantiles: int = 200 - """the number of quantile midpoints in the inverse cumulative distribution function of the value""" - hidden_sizes: Sequence[int] = () - """hidden dimensions to use in the IQN network""" - num_cosines: int = 64 - """number of cosines to use in the IQN network""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.append(ParamTransformerDrop("hidden_sizes", "num_cosines")) - return transformers - - -@dataclass -class DDPGParams( - Params, - ParamsMixinActorAndCritic, - ParamsMixinExplorationNoise, - ParamsMixinActionScaling, -): - tau: float = 0.005 - """ - controls the soft update of the target network. - It determines how slowly the target networks track the main networks. - Smaller tau means slower tracking and more stable learning. 
- """ - gamma: float = 0.99 - """discount factor (gamma) for future rewards; must be in [0, 1]""" - estimation_step: int = 1 - """the number of steps to look ahead.""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self)) - transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) - transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) - return transformers - - -@dataclass -class REDQParams(DDPGParams): - ensemble_size: int = 10 - """the number of sub-networks in the critic ensemble""" - subset_size: int = 2 - """the number of networks in the subset""" - alpha: float | AutoAlphaFactory = 0.2 - """ - controls the relative importance (coefficient) of the entropy term in the loss function. - This can be a constant or a factory for the creation of a representation that allows the - parameter to be automatically tuned; - use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard - auto-adjusted alpha. - """ - estimation_step: int = 1 - """the number of steps to look ahead""" - actor_delay: int = 20 - """the number of critic updates before an actor update""" - deterministic_eval: bool = True - """ - whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. - Does not affect training. - """ - target_mode: Literal["mean", "min"] = "min" - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.append(ParamTransformerAutoAlpha("alpha")) - return transformers - - -@dataclass -class TD3Params( - Params, - ParamsMixinActorAndDualCritics, - ParamsMixinExplorationNoise, - ParamsMixinActionScaling, -): - tau: float = 0.005 - """ - controls the soft update of the target network. - It determines how slowly the target networks track the main networks. - Smaller tau means slower tracking and more stable learning. 
- """ - gamma: float = 0.99 - """discount factor (gamma) for future rewards; must be in [0, 1]""" - policy_noise: float | FloatEnvValueFactory = 0.2 - """the scale of the the noise used in updating policy network""" - noise_clip: float | FloatEnvValueFactory = 0.5 - """determines the clipping range of the noise used in updating the policy network as [-noise_clip, noise_clip]""" - update_actor_freq: int = 2 - """the update frequency of actor network""" - estimation_step: int = 1 - """the number of steps to look ahead.""" - - def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) - transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) - transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) - transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise")) - transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip")) - return transformers diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py deleted file mode 100644 index 43cbfed1e..000000000 --- a/tianshou/highlevel/params/policy_wrapper.py +++ /dev/null @@ -1,77 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Sequence -from typing import Generic, TypeVar - -from sensai.util.string import ToStringMixin - -from tianshou.highlevel.env import Environments -from tianshou.highlevel.module.core import TDevice -from tianshou.highlevel.module.intermediate import IntermediateModuleFactory -from tianshou.highlevel.optim import OptimizerFactory -from tianshou.policy import BasePolicy, ICMPolicy -from tianshou.utils.net.discrete import IntrinsicCuriosityModule - -TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy) - - -class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC): - @abstractmethod - def create_wrapped_policy( - self, - policy: BasePolicy, - envs: Environments, - optim_factory: OptimizerFactory, - device: TDevice, - ) -> TPolicyOut: - pass - - -class PolicyWrapperFactoryIntrinsicCuriosity( - PolicyWrapperFactory[ICMPolicy], -): - def __init__( - self, - *, - feature_net_factory: IntermediateModuleFactory, - hidden_sizes: Sequence[int], - lr: float, - lr_scale: float, - reward_scale: float, - forward_loss_weight: float, - ): - self.feature_net_factory = feature_net_factory - self.hidden_sizes = hidden_sizes - self.lr = lr - self.lr_scale = lr_scale - self.reward_scale = reward_scale - self.forward_loss_weight = forward_loss_weight - - def create_wrapped_policy( - self, - policy: BasePolicy, - envs: Environments, - optim_factory: OptimizerFactory, - device: TDevice, - ) -> ICMPolicy: - feature_net = self.feature_net_factory.create_intermediate_module(envs, device) - action_dim = envs.get_action_shape() - if not isinstance(action_dim, int): - raise ValueError(f"Environment action shape must be an integer, got {action_dim}") - feature_dim = feature_net.output_dim - icm_net = IntrinsicCuriosityModule( - feature_net.module, - feature_dim, - action_dim, - hidden_sizes=self.hidden_sizes, - device=device, - ) - icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr) - return ICMPolicy( - policy=policy, - model=icm_net, - optim=icm_optim, - action_space=envs.get_action_space(), - lr_scale=self.lr_scale, - reward_scale=self.reward_scale, - forward_loss_weight=self.forward_loss_weight, - ).to(device) diff --git a/tianshou/highlevel/persistence.py 
b/tianshou/highlevel/persistence.py index 2758f5066..1c38d3602 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -141,10 +141,10 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: match self.mode: case self.Mode.POLICY_STATE_DICT: log.info(f"Saving policy state dictionary in {path}") - torch.save(world.policy.state_dict(), path) + torch.save(world.algorithm.state_dict(), path) case self.Mode.POLICY: log.info(f"Saving policy object in {path}") - torch.save(world.policy, path) + torch.save(world.algorithm, path) case _: raise NotImplementedError if self.additional_persistence is not None: diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 498cc3173..452232d08 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -6,17 +6,18 @@ from sensai.util.string import ToStringMixin +from tianshou.algorithm import DQN, Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger -from tianshou.policy import BasePolicy, DQNPolicy -TPolicy = TypeVar("TPolicy", bound=BasePolicy) +TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) log = logging.getLogger(__name__) class TrainingContext: - def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger): - self.policy = policy + def __init__(self, algorithm: TAlgorithm, envs: Environments, logger: TLogger): + self.algorithm = algorithm self.envs = envs self.logger = logger @@ -86,12 +87,13 @@ class EpochTrainCallbackDQNSetEps(EpochTrainCallback): stage in each epoch. """ - def __init__(self, eps_test: float): - self.eps_test = eps_test + def __init__(self, eps: float): + self.eps = eps def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy = cast(DQNPolicy, context.policy) - policy.set_eps(self.eps_test) + algorithm = cast(DQN, context.algorithm) + policy: DiscreteQLearningPolicy = algorithm.policy + policy.set_eps_training(self.eps) class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback): @@ -105,7 +107,8 @@ def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = self.decay_steps = decay_steps def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy = cast(DQNPolicy, context.policy) + algorithm = cast(DQN, context.algorithm) + policy: DiscreteQLearningPolicy = algorithm.policy logger = context.logger if env_step <= self.decay_steps: eps = self.eps_train - env_step / self.decay_steps * ( @@ -113,7 +116,7 @@ def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: ) else: eps = self.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) logger.write("train/env_step", env_step, {"train/eps": eps}) @@ -122,12 +125,13 @@ class EpochTestCallbackDQNSetEps(EpochTestCallback): stage in each epoch. 
""" - def __init__(self, eps_test: float): - self.eps_test = eps_test + def __init__(self, eps: float): + self.eps = eps def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: - policy = cast(DQNPolicy, context.policy) - policy.set_eps(self.eps_test) + algorithm = cast(DQN, context.algorithm) + policy: DiscreteQLearningPolicy = algorithm.policy + policy.set_eps_inference(self.eps) class EpochStopCallbackRewardThreshold(EpochStopCallback): diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index 6db216b15..98ac46dea 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -3,11 +3,11 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: + from tianshou.algorithm import Algorithm from tianshou.data import BaseCollector from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger - from tianshou.policy import BasePolicy - from tianshou.trainer import BaseTrainer + from tianshou.trainer import Trainer @dataclass(kw_only=True) @@ -15,13 +15,13 @@ class World: """Container for instances and configuration items that are relevant to an experiment.""" envs: "Environments" - policy: "BasePolicy" + algorithm: "Algorithm" train_collector: Optional["BaseCollector"] = None test_collector: Optional["BaseCollector"] = None logger: "TLogger" persist_directory: str restore_directory: str | None - trainer: Optional["BaseTrainer"] = None + trainer: Optional["Trainer"] = None def persist_path(self, filename: str) -> str: return os.path.abspath(os.path.join(self.persist_directory, filename)) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index a9b944da8..e69de29bb 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,67 +0,0 @@ -"""Policy package.""" -# isort:skip_file - -from tianshou.policy.base import BasePolicy, TrainingStats -from tianshou.policy.random import MARLRandomPolicy -from tianshou.policy.modelfree.dqn import DQNPolicy -from tianshou.policy.modelfree.bdq import BranchingDQNPolicy -from tianshou.policy.modelfree.c51 import C51Policy -from tianshou.policy.modelfree.rainbow import RainbowPolicy -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.policy.modelfree.iqn import IQNPolicy -from tianshou.policy.modelfree.fqf import FQFPolicy -from tianshou.policy.modelfree.pg import PGPolicy -from tianshou.policy.modelfree.a2c import A2CPolicy -from tianshou.policy.modelfree.npg import NPGPolicy -from tianshou.policy.modelfree.ddpg import DDPGPolicy -from tianshou.policy.modelfree.ppo import PPOPolicy -from tianshou.policy.modelfree.trpo import TRPOPolicy -from tianshou.policy.modelfree.td3 import TD3Policy -from tianshou.policy.modelfree.sac import SACPolicy -from tianshou.policy.modelfree.redq import REDQPolicy -from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy -from tianshou.policy.imitation.base import ImitationPolicy -from tianshou.policy.imitation.bcq import BCQPolicy -from tianshou.policy.imitation.cql import CQLPolicy -from tianshou.policy.imitation.td3_bc import TD3BCPolicy -from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy -from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy -from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy -from tianshou.policy.imitation.gail import GAILPolicy -from tianshou.policy.modelbased.psrl import PSRLPolicy -from tianshou.policy.modelbased.icm import ICMPolicy -from tianshou.policy.multiagent.mapolicy import 
MultiAgentPolicyManager - -__all__ = [ - "BasePolicy", - "MARLRandomPolicy", - "DQNPolicy", - "BranchingDQNPolicy", - "C51Policy", - "RainbowPolicy", - "QRDQNPolicy", - "IQNPolicy", - "FQFPolicy", - "PGPolicy", - "A2CPolicy", - "NPGPolicy", - "DDPGPolicy", - "PPOPolicy", - "TRPOPolicy", - "TD3Policy", - "SACPolicy", - "REDQPolicy", - "DiscreteSACPolicy", - "ImitationPolicy", - "BCQPolicy", - "CQLPolicy", - "TD3BCPolicy", - "DiscreteBCQPolicy", - "DiscreteCQLPolicy", - "DiscreteCRRPolicy", - "GAILPolicy", - "PSRLPolicy", - "ICMPolicy", - "MultiAgentPolicyManager", - "TrainingStats", -] diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py deleted file mode 100644 index 6e21016d9..000000000 --- a/tianshou/policy/imitation/base.py +++ /dev/null @@ -1,115 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F - -from tianshou.data import Batch, to_torch -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ( - ModelOutputBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats - -# Dimension Naming Convention -# B - Batch Size -# A - Action -# D - Dist input (usually 2, loc and scale) -# H - Dimension of hidden, can be None - - -@dataclass(kw_only=True) -class ImitationTrainingStats(TrainingStats): - loss: float = 0.0 - - -TImitationTrainingStats = TypeVar("TImitationTrainingStats", bound=ImitationTrainingStats) - - -class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTrainingStats]): - """Implementation of vanilla imitation learning. - - :param actor: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> a) - :param optim: for optimizing the model. - :param action_space: Env's action_space. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. 
- """ - - def __init__( - self, - *, - actor: torch.nn.Module, - optim: torch.optim.Optimizer, - action_space: gym.Space, - observation_space: gym.Space | None = None, - action_scaling: bool = False, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - self.actor = actor - self.optim = optim - - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> ModelOutputBatchProtocol: - # TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced - if self.action_type == "discrete": - # If it's discrete, the "actor" is usually a critic that maps obs to action_values - # which then could be turned into logits or a Categorigal - action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) - act_B = action_values_BA.argmax(dim=1) - result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) - elif self.action_type == "continuous": - # If it's continuous, the actor would usually deliver something like loc, scale determining a - # Gaussian dist - dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) - result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH) - else: - raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!") - return cast(ModelOutputBatchProtocol, result) - - def learn( - self, - batch: RolloutBatchProtocol, - *ags: Any, - **kwargs: Any, - ) -> TImitationTrainingStats: - self.optim.zero_grad() - if self.action_type == "continuous": # regression - act = self(batch).act - act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) - loss = F.mse_loss(act, act_target) - elif self.action_type == "discrete": # classification - act = F.log_softmax(self(batch).logits, dim=-1) - act_target = to_torch(batch.act, dtype=torch.long, device=act.device) - loss = F.nll_loss(act, act_target) - loss.backward() - self.optim.step() - - return ImitationTrainingStats(loss=loss.item()) # type: ignore diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py deleted file mode 100644 index 991c4aace..000000000 --- a/tianshou/policy/imitation/bcq.py +++ /dev/null @@ -1,235 +0,0 @@ -import copy -from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F - -from tianshou.data import Batch, to_torch -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.utils.net.continuous import VAE -from tianshou.utils.optim import clone_optimizer - - -@dataclass(kw_only=True) -class BCQTrainingStats(TrainingStats): - actor_loss: float - critic1_loss: float - critic2_loss: float - vae_loss: float - - -TBCQTrainingStats = TypeVar("TBCQTrainingStats", bound=BCQTrainingStats) - - -class BCQPolicy(BasePolicy[TBCQTrainingStats], Generic[TBCQTrainingStats]): - """Implementation of BCQ algorithm. arXiv:1812.02900. - - :param actor_perturbation: the actor perturbation. 
`(s, a -> perturbed a)` - :param actor_perturbation_optim: the optimizer for actor network. - :param critic: the first critic network. - :param critic_optim: the optimizer for the first critic network. - :param critic2: the second critic network. - :param critic2_optim: the optimizer for the second critic network. - :param vae: the VAE network, generating actions similar to those in batch. - :param vae_optim: the optimizer for the VAE network. - :param device: which device to create this model on. - :param gamma: discount factor, in [0, 1]. - :param tau: param for soft update of the target network. - :param lmbda: param for Clipped Double Q-learning. - :param forward_sampled_times: the number of sampled actions in forward function. - The policy samples many actions and takes the action with the max value. - :param num_sampled_action: the number of sampled actions in calculating target Q. - The algorithm samples several actions using VAE, and perturbs each action to get the target Q. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. - """ - - def __init__( - self, - *, - actor_perturbation: torch.nn.Module, - actor_perturbation_optim: torch.optim.Optimizer, - critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, - action_space: gym.Space, - vae: VAE, - vae_optim: torch.optim.Optimizer, - critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, - # TODO: remove? Many policies don't use this - device: str | torch.device = "cpu", - gamma: float = 0.99, - tau: float = 0.005, - lmbda: float = 0.75, - forward_sampled_times: int = 100, - num_sampled_action: int = 10, - observation_space: gym.Space | None = None, - action_scaling: bool = False, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - # actor is Perturbation! 
- super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - self.actor_perturbation = actor_perturbation - self.actor_perturbation_target = copy.deepcopy(self.actor_perturbation) - self.actor_perturbation_optim = actor_perturbation_optim - - self.critic = critic - self.critic_target = copy.deepcopy(self.critic) - self.critic_optim = critic_optim - - critic2 = critic2 or copy.deepcopy(critic) - critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) - self.critic2 = critic2 - self.critic2_target = copy.deepcopy(self.critic2) - self.critic2_optim = critic2_optim - - self.vae = vae - self.vae_optim = vae_optim - - self.gamma = gamma - self.tau = tau - self.lmbda = lmbda - self.device = device - self.forward_sampled_times = forward_sampled_times - self.num_sampled_action = num_sampled_action - - def train(self, mode: bool = True) -> Self: - """Set the module in training mode, except for the target network.""" - self.training = mode - self.actor_perturbation.train(mode) - self.critic.train(mode) - self.critic2.train(mode) - return self - - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> ActBatchProtocol: - """Compute action over the given batch data.""" - # There is "obs" in the Batch - # obs_group: several groups. Each group has a state. - obs_group: torch.Tensor = to_torch(batch.obs, device=self.device) - act_group = [] - for obs_orig in obs_group: - # now obs is (state_dim) - obs = (obs_orig.reshape(1, -1)).repeat(self.forward_sampled_times, 1) - # now obs is (forward_sampled_times, state_dim) - - # decode(obs) generates action and actor perturbs it - act = self.actor_perturbation(obs, self.vae.decode(obs)) - # now action is (forward_sampled_times, action_dim) - q1 = self.critic(obs, act) - # q1 is (forward_sampled_times, 1) - max_indice = q1.argmax(0) - act_group.append(act[max_indice].cpu().data.numpy().flatten()) - act_group = np.array(act_group) - return cast(ActBatchProtocol, Batch(act=act_group)) - - def sync_weight(self) -> None: - """Soft-update the weight for the target network.""" - self.soft_update(self.critic_target, self.critic, self.tau) - self.soft_update(self.critic2_target, self.critic2, self.tau) - self.soft_update(self.actor_perturbation_target, self.actor_perturbation, self.tau) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBCQTrainingStats: - # batch: obs, act, rew, done, obs_next. (numpy array) - # (batch_size, state_dim) - batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) - obs, act = batch.obs, batch.act - batch_size = obs.shape[0] - - # mean, std: (state.shape[0], latent_dim) - recon, mean, std = self.vae(obs, act) - recon_loss = F.mse_loss(act, recon) - # (....) 
is D_KL( N(mu, sigma) || N(0,1) ) - KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean() - vae_loss = recon_loss + KL_loss / 2 - - self.vae_optim.zero_grad() - vae_loss.backward() - self.vae_optim.step() - - # critic training: - with torch.no_grad(): - # repeat num_sampled_action times - obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0) - # now obs_next: (num_sampled_action * batch_size, state_dim) - - # perturbed action generated by VAE - act_next = self.vae.decode(obs_next) - # now obs_next: (num_sampled_action * batch_size, action_dim) - target_Q1 = self.critic_target(obs_next, act_next) - target_Q2 = self.critic2_target(obs_next, act_next) - - # Clipped Double Q-learning - target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1 - self.lmbda) * torch.max( - target_Q1, - target_Q2, - ) - # now target_Q: (num_sampled_action * batch_size, 1) - - # the max value of Q - target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1) - # now target_Q: (batch_size, 1) - - target_Q = ( - batch.rew.reshape(-1, 1) - + torch.logical_not(batch.done).reshape(-1, 1) * self.gamma * target_Q - ) - target_Q = target_Q.float() - - current_Q1 = self.critic(obs, act) - current_Q2 = self.critic2(obs, act) - - critic1_loss = F.mse_loss(current_Q1, target_Q) - critic2_loss = F.mse_loss(current_Q2, target_Q) - - self.critic_optim.zero_grad() - self.critic2_optim.zero_grad() - critic1_loss.backward() - critic2_loss.backward() - self.critic_optim.step() - self.critic2_optim.step() - - sampled_act = self.vae.decode(obs) - perturbed_act = self.actor_perturbation(obs, sampled_act) - - # max - actor_loss = -self.critic(obs, perturbed_act).mean() - - self.actor_perturbation_optim.zero_grad() - actor_loss.backward() - self.actor_perturbation_optim.step() - - # update target network - self.sync_weight() - - return BCQTrainingStats( # type: ignore - actor_loss=actor_loss.item(), - critic1_loss=critic1_loss.item(), - critic2_loss=critic2_loss.item(), - vae_loss=vae_loss.item(), - ) diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py deleted file mode 100644 index 66438c758..000000000 --- a/tianshou/policy/imitation/cql.py +++ /dev/null @@ -1,402 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Literal, Self, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F -from overrides import override -from torch.nn.utils import clip_grad_norm_ - -from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.buffer.base import TBuffer -from tianshou.data.types import RolloutBatchProtocol -from tianshou.exploration import BaseNoise -from tianshou.policy import SACPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.sac import SACTrainingStats -from tianshou.utils.conversion import to_optional_float -from tianshou.utils.net.continuous import ActorProb - - -@dataclass(kw_only=True) -class CQLTrainingStats(SACTrainingStats): - """A data structure for storing loss statistics of the CQL learn step.""" - - cql_alpha: float | None = None - cql_alpha_loss: float | None = None - - -TCQLTrainingStats = TypeVar("TCQLTrainingStats", bound=CQLTrainingStats) - - -class CQLPolicy(SACPolicy[TCQLTrainingStats]): - """Implementation of CQL algorithm. arXiv:2006.04779. - - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> a) - :param actor_optim: The optimizer for actor network. 
- :param critic: The first critic network. - :param critic_optim: The optimizer for the first critic network. - :param action_space: Env's action space. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param cql_alpha_lr: The learning rate of cql_log_alpha. - :param cql_weight: - :param tau: Parameter for soft update of the target network. - :param gamma: Discount factor, in [0, 1]. - :param alpha: Entropy regularization coefficient or a tuple - (target_entropy, log_alpha, alpha_optim) for automatic tuning. - :param temperature: - :param with_lagrange: Whether to use Lagrange. - TODO: extend documentation - what does this mean? - :param lagrange_threshold: The value of tau in CQL(Lagrange). - :param min_action: The minimum value of each dimension of action. - :param max_action: The maximum value of each dimension of action. - :param num_repeat_actions: The number of times the action is repeated when calculating log-sum-exp. - :param alpha_min: Lower bound for clipping cql_alpha. - :param alpha_max: Upper bound for clipping cql_alpha. - :param clip_grad: Clip_grad for updating critic network. - :param calibrated: calibrate Q-values as in CalQL paper `arXiv:2303.05479`. - Useful for offline pre-training followed by online training, - and also was observed to achieve better results than vanilla cql. - :param device: Which device to create this model on. - :param estimation_step: Estimation steps. - :param exploration_noise: Type of exploration noise. - :param deterministic_eval: Flag for deterministic evaluation. - :param action_scaling: Flag for action scaling. - :param action_bound_method: Method for action bounding. Only used if the - action_space is continuous. - :param observation_space: Env's Observation space. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in - optimizer in each policy.update(). - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - actor: ActorProb, - actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, - action_space: gym.spaces.Box, - critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, - cql_alpha_lr: float = 1e-4, - cql_weight: float = 1.0, - tau: float = 0.005, - gamma: float = 0.99, - alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, - temperature: float = 1.0, - with_lagrange: bool = True, - lagrange_threshold: float = 10.0, - min_action: float = -1.0, - max_action: float = 1.0, - num_repeat_actions: int = 10, - alpha_min: float = 0.0, - alpha_max: float = 1e6, - clip_grad: float = 1.0, - calibrated: bool = True, - # TODO: why does this one have device? 
Almost no other policies have it - device: str | torch.device = "cpu", - estimation_step: int = 1, - exploration_noise: BaseNoise | Literal["default"] | None = None, - deterministic_eval: bool = True, - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - actor=actor, - actor_optim=actor_optim, - critic=critic, - critic_optim=critic_optim, - action_space=action_space, - critic2=critic2, - critic2_optim=critic2_optim, - tau=tau, - gamma=gamma, - deterministic_eval=deterministic_eval, - alpha=alpha, - exploration_noise=exploration_noise, - estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - # There are _target_entropy, _log_alpha, _alpha_optim in SACPolicy. - self.device = device - self.temperature = temperature - self.with_lagrange = with_lagrange - self.lagrange_threshold = lagrange_threshold - - self.cql_weight = cql_weight - - self.cql_log_alpha = torch.tensor([0.0], requires_grad=True) - self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr) - self.cql_log_alpha = self.cql_log_alpha.to(device) - - self.min_action = min_action - self.max_action = max_action - - self.num_repeat_actions = num_repeat_actions - - self.alpha_min = alpha_min - self.alpha_max = alpha_max - self.clip_grad = clip_grad - - self.calibrated = calibrated - - def train(self, mode: bool = True) -> Self: - """Set the module in training mode, except for the target network.""" - self.training = mode - self.actor.train(mode) - self.critic.train(mode) - self.critic2.train(mode) - return self - - def sync_weight(self) -> None: - """Soft-update the weight for the target network.""" - self.soft_update(self.critic_old, self.critic, self.tau) - self.soft_update(self.critic2_old, self.critic2, self.tau) - - def actor_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - batch = Batch(obs=obs, info=[None] * len(obs)) - obs_result = self(batch) - return obs_result.act, obs_result.log_prob - - def calc_actor_loss(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - act_pred, log_pi = self.actor_pred(obs) - q1 = self.critic(obs, act_pred) - q2 = self.critic2(obs, act_pred) - min_Q = torch.min(q1, q2) - # self.alpha: float | torch.Tensor - actor_loss = (self.alpha * log_pi - min_Q).mean() - # actor_loss.shape: (), log_pi.shape: (batch_size, 1) - return actor_loss, log_pi - - def calc_pi_values( - self, - obs_pi: torch.Tensor, - obs_to_pred: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - act_pred, log_pi = self.actor_pred(obs_pi) - - q1 = self.critic(obs_to_pred, act_pred) - q2 = self.critic2(obs_to_pred, act_pred) - - return q1 - log_pi.detach(), q2 - log_pi.detach() - - def calc_random_values( - self, - obs: torch.Tensor, - act: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - random_value1 = self.critic(obs, act) - random_log_prob1 = np.log(0.5 ** act.shape[-1]) - - random_value2 = self.critic2(obs, act) - random_log_prob2 = np.log(0.5 ** act.shape[-1]) - - return random_value1 - random_log_prob1, random_value2 - random_log_prob2 - - @override - def process_buffer(self, buffer: TBuffer) -> TBuffer: - """If `self.calibrated = True`, adds `calibration_returns` to buffer._meta. 
- - :param buffer: - :return: - """ - if self.calibrated: - # otherwise _meta hack cannot work - assert isinstance(buffer, ReplayBuffer) - batch, indices = buffer.sample(0) - returns, _ = self.compute_episodic_return( - batch=batch, - buffer=buffer, - indices=indices, - gamma=self.gamma, - gae_lambda=1.0, - ) - # TODO: don't access _meta directly - buffer._meta = cast( - RolloutBatchProtocol, - Batch(**buffer._meta.__dict__, calibration_returns=returns), - ) - return buffer - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> RolloutBatchProtocol: - # TODO: mypy rightly complains here b/c the design violates - # Liskov Substitution Principle - # DDPGPolicy.process_fn() results in a batch with returns but - # CQLPolicy.process_fn() doesn't add the returns. - # Should probably be fixed! - return batch - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLTrainingStats: # type: ignore - batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) - obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next - batch_size = obs.shape[0] - - # compute actor loss and update actor - actor_loss, log_pi = self.calc_actor_loss(obs) - self.actor_optim.zero_grad() - actor_loss.backward() - self.actor_optim.step() - - alpha_loss = None - # compute alpha loss - if self.is_auto_alpha: - log_pi = log_pi + self.target_entropy - alpha_loss = -(self.log_alpha * log_pi.detach()).mean() - self.alpha_optim.zero_grad() - # update log_alpha - alpha_loss.backward() - self.alpha_optim.step() - # update alpha - # TODO: it's probably a bad idea to track both alpha and log_alpha in different fields - self.alpha = self.log_alpha.detach().exp() - - # compute target_Q - with torch.no_grad(): - act_next, new_log_pi = self.actor_pred(obs_next) - - target_Q1 = self.critic_old(obs_next, act_next) - target_Q2 = self.critic2_old(obs_next, act_next) - - target_Q = torch.min(target_Q1, target_Q2) - self.alpha * new_log_pi - - target_Q = rew + torch.logical_not(batch.done) * self.gamma * target_Q.flatten() - target_Q = target_Q.float() - # shape: (batch_size) - - # compute critic loss - current_Q1 = self.critic(obs, act).flatten() - current_Q2 = self.critic2(obs, act).flatten() - # shape: (batch_size) - - critic1_loss = F.mse_loss(current_Q1, target_Q) - critic2_loss = F.mse_loss(current_Q2, target_Q) - - # CQL - random_actions = ( - torch.FloatTensor(batch_size * self.num_repeat_actions, act.shape[-1]) - .uniform_(-self.min_action, self.max_action) - .to(self.device) - ) - - obs_len = len(obs.shape) - repeat_size = [1, self.num_repeat_actions] + [1] * (obs_len - 1) - view_size = [batch_size * self.num_repeat_actions, *list(obs.shape[1:])] - tmp_obs = obs.unsqueeze(1).repeat(*repeat_size).view(*view_size) - tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size) - # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim) - - current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs) - next_pi_value1, next_pi_value2 = self.calc_pi_values(tmp_obs_next, tmp_obs) - - random_value1, random_value2 = self.calc_random_values(tmp_obs, random_actions) - - for value in [ - current_pi_value1, - current_pi_value2, - next_pi_value1, - next_pi_value2, - random_value1, - random_value2, - ]: - value.reshape(batch_size, self.num_repeat_actions, 1) - - if self.calibrated: - returns = ( - batch.calibration_returns.unsqueeze(1) - .repeat( - (1, self.num_repeat_actions), - ) - .view(-1, 1) - ) - 
random_value1 = torch.max(random_value1, returns) - random_value2 = torch.max(random_value2, returns) - - current_pi_value1 = torch.max(current_pi_value1, returns) - current_pi_value2 = torch.max(current_pi_value2, returns) - - next_pi_value1 = torch.max(next_pi_value1, returns) - next_pi_value2 = torch.max(next_pi_value2, returns) - - # cat q values - cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1) - cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1) - # shape: (batch_size, 3 * num_repeat, 1) - - cql1_scaled_loss = ( - torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() - * self.cql_weight - * self.temperature - - current_Q1.mean() * self.cql_weight - ) - cql2_scaled_loss = ( - torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() - * self.cql_weight - * self.temperature - - current_Q2.mean() * self.cql_weight - ) - # shape: (1) - - cql_alpha_loss = None - cql_alpha = None - if self.with_lagrange: - cql_alpha = torch.clamp( - self.cql_log_alpha.exp(), - self.alpha_min, - self.alpha_max, - ) - cql1_scaled_loss = cql_alpha * (cql1_scaled_loss - self.lagrange_threshold) - cql2_scaled_loss = cql_alpha * (cql2_scaled_loss - self.lagrange_threshold) - - self.cql_alpha_optim.zero_grad() - cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5 - cql_alpha_loss.backward(retain_graph=True) - self.cql_alpha_optim.step() - - critic1_loss = critic1_loss + cql1_scaled_loss - critic2_loss = critic2_loss + cql2_scaled_loss - - # update critic - self.critic_optim.zero_grad() - critic1_loss.backward(retain_graph=True) - # clip grad, prevent the vanishing gradient problem - # It doesn't seem necessary - clip_grad_norm_(self.critic.parameters(), self.clip_grad) - self.critic_optim.step() - - self.critic2_optim.zero_grad() - critic2_loss.backward() - clip_grad_norm_(self.critic2.parameters(), self.clip_grad) - self.critic2_optim.step() - - self.sync_weight() - - return CQLTrainingStats( # type: ignore[return-value] - actor_loss=to_optional_float(actor_loss), - critic1_loss=to_optional_float(critic1_loss), - critic2_loss=to_optional_float(critic2_loss), - alpha=to_optional_float(self.alpha), - alpha_loss=to_optional_float(alpha_loss), - cql_alpha_loss=to_optional_float(cql_alpha_loss), - cql_alpha=to_optional_float(cql_alpha), - ) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py deleted file mode 100644 index b5258c141..000000000 --- a/tianshou/policy/imitation/discrete_bcq.py +++ /dev/null @@ -1,179 +0,0 @@ -import math -from dataclasses import dataclass -from typing import Any, Self, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F - -from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.types import ( - ImitationBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.policy import DQNPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNTrainingStats - -float_info = torch.finfo(torch.float32) -INF = float_info.max - - -@dataclass(kw_only=True) -class DiscreteBCQTrainingStats(DQNTrainingStats): - q_loss: float - i_loss: float - reg_loss: float - - -TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteBCQTrainingStats) - - -class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]): - """Implementation of discrete BCQ algorithm. arXiv:1910.01708. 
- - :param model: a model following the rules (s_B -> action_values_BA) - :param imitator: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead - :param target_update_freq: the target network update frequency. - :param eval_eps: the epsilon-greedy noise added in evaluation. - :param unlikely_action_threshold: the threshold (tau) for unlikely - actions, as shown in Equ. (17) in the paper. - :param imitation_logits_penalty: regularization weight for imitation - logits. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - model: torch.nn.Module, - imitator: torch.nn.Module, - optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - estimation_step: int = 1, - target_update_freq: int = 8000, - eval_eps: float = 1e-3, - unlikely_action_threshold: float = 0.3, - imitation_logits_penalty: float = 1e-2, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - model=model, - optim=optim, - action_space=action_space, - discount_factor=discount_factor, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - assert ( - target_update_freq > 0 - ), f"BCQ needs target_update_freq>0 but got: {target_update_freq}." 
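For reviewers tracing the epsilon-greedy changes in the `tianshou/highlevel/trainer.py` hunk earlier in this diff: below is a minimal sketch of a custom epoch callback under the new Algorithm/Policy split. It assumes only the interfaces visible in that hunk (`TrainingContext.algorithm`, `DQN.policy`, `DiscreteQLearningPolicy.set_eps_training`, `context.logger.write`); the decay schedule and the class itself are hypothetical and not part of this PR.

```python
# Sketch only: a hypothetical per-epoch epsilon schedule built on the new
# two-epsilon interface shown in the trainer.py hunk above.
from typing import cast

from tianshou.algorithm import DQN
from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy
from tianshou.highlevel.trainer import EpochTrainCallback, TrainingContext


class EpochExponentialEpsDecay(EpochTrainCallback):
    """Decay the training-time epsilon by a constant factor at each epoch (illustrative)."""

    def __init__(self, eps_start: float, decay: float, eps_min: float = 0.05):
        self.eps_start = eps_start
        self.decay = decay
        self.eps_min = eps_min

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        # reach the policy through the algorithm, as the shipped callbacks now do
        algorithm = cast(DQN, context.algorithm)
        policy: DiscreteQLearningPolicy = algorithm.policy
        # assumed epoch indexing; only the exponential form matters for the sketch
        eps = max(self.eps_min, self.eps_start * self.decay**epoch)
        policy.set_eps_training(eps)
        context.logger.write("train/env_step", env_step, {"train/eps": eps})
```

Test-time exploration would be set analogously via `set_eps_inference` from an `EpochTestCallback`, mirroring `EpochTestCallbackDQNSetEps` in the hunk above.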
- self.imitator = imitator - assert ( - 0.0 <= unlikely_action_threshold < 1.0 - ), f"unlikely_action_threshold should be in [0, 1) but got: {unlikely_action_threshold}" - if unlikely_action_threshold > 0: - self._log_tau = math.log(unlikely_action_threshold) - else: - self._log_tau = -np.inf - assert 0.0 <= eval_eps < 1.0 - self.eps = eval_eps - self._weight_reg = imitation_logits_penalty - - def train(self, mode: bool = True) -> Self: - self.training = mode - self.model.train(mode) - self.imitator.train(mode) - return self - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs_next: s_{t+n} - next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) - # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - act = self(next_obs_batch).act - target_q, _ = self.model_old(batch.obs_next) - return target_q[np.arange(len(act)), act] - - def forward( # type: ignore - self, - batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, - **kwargs: Any, - ) -> ImitationBatchProtocol: - # TODO: Liskov substitution principle is violated here, the superclass - # produces a batch with the field logits, but this one doesn't. - # Should be fixed in the future! - q_value, state = self.model(batch.obs, state=state, info=batch.info) - if self.max_action_num is None: - self.max_action_num = q_value.shape[1] - imitation_logits, _ = self.imitator(batch.obs, state=state, info=batch.info) - - # mask actions for argmax - ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values - mask = (ratio < self._log_tau).float() - act = (q_value - INF * mask).argmax(dim=-1) - - result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits) - return cast(ImitationBatchProtocol, result) - - def learn( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> TDiscreteBCQTrainingStats: - if self._iter % self.freq == 0: - self.sync_weight() - self._iter += 1 - - target_q = batch.returns.flatten() - result = self(batch) - imitation_logits = result.imitation_logits - current_q = result.q_value[np.arange(len(target_q)), batch.act] - act = to_torch(batch.act, dtype=torch.long, device=target_q.device) - q_loss = F.smooth_l1_loss(current_q, target_q) - i_loss = F.nll_loss(F.log_softmax(imitation_logits, dim=-1), act) - reg_loss = imitation_logits.pow(2).mean() - loss = q_loss + i_loss + self._weight_reg * reg_loss - - self.optim.zero_grad() - loss.backward() - self.optim.step() - - return DiscreteBCQTrainingStats( # type: ignore[return-value] - loss=loss.item(), - q_loss=q_loss.item(), - i_loss=i_loss.item(), - reg_loss=reg_loss.item(), - ) diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py deleted file mode 100644 index b63f83e11..000000000 --- a/tianshou/policy/imitation/discrete_cql.py +++ /dev/null @@ -1,124 +0,0 @@ -from dataclasses import dataclass -from typing import Any, TypeVar - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F - -from tianshou.data import to_torch -from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import QRDQNPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats - - -@dataclass(kw_only=True) -class DiscreteCQLTrainingStats(QRDQNTrainingStats): - cql_loss: float - qr_loss: float - - -TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", 
bound=DiscreteCQLTrainingStats) - - -class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]): - """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. - - :param model: a model following the rules (s_B -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param action_space: Env's action space. - :param min_q_weight: the weight for the cql loss. - :param discount_factor: in [0, 1]. - :param num_quantiles: the number of quantile midpoints in the inverse - cumulative distribution function of the value. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - model: torch.nn.Module, - optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - min_q_weight: float = 10.0, - discount_factor: float = 0.99, - num_quantiles: int = 200, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - model=model, - optim=optim, - action_space=action_space, - discount_factor=discount_factor, - num_quantiles=num_quantiles, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - self.min_q_weight = min_q_weight - - def learn( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> TDiscreteCQLTrainingStats: - if self._target and self._iter % self.freq == 0: - self.sync_weight() - self.optim.zero_grad() - weight = batch.pop("weight", 1.0) - all_dist = self(batch).logits - act = to_torch(batch.act, dtype=torch.long, device=all_dist.device) - curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) - target_dist = batch.returns.unsqueeze(1) - # calculate each element's difference between curr_dist and target_dist - dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = ( - (dist_diff * (self.tau_hat - (target_dist - curr_dist).detach().le(0.0).float()).abs()) - .sum(-1) - .mean(1) - ) - qr_loss = (huber_loss * weight).mean() - # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ - # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 - batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer - # add CQL loss - q = self.compute_q_value(all_dist, None) - dataset_expec = q.gather(1, act.unsqueeze(1)).mean() - negative_sampling = q.logsumexp(1).mean() - min_q_loss = negative_sampling - dataset_expec - loss = qr_loss + min_q_loss * self.min_q_weight - loss.backward() - self.optim.step() - self._iter += 1 - - return DiscreteCQLTrainingStats( # type: ignore[return-value] - 
loss=loss.item(), - qr_loss=qr_loss.item(), - cql_loss=min_q_loss.item(), - ) diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py deleted file mode 100644 index 9c54129da..000000000 --- a/tianshou/policy/imitation/discrete_crr.py +++ /dev/null @@ -1,153 +0,0 @@ -from copy import deepcopy -from dataclasses import dataclass -from typing import Any, Literal, TypeVar - -import gymnasium as gym -import torch -import torch.nn.functional as F -from torch.distributions import Categorical - -from tianshou.data import to_torch, to_torch_as -from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats -from tianshou.utils.net.discrete import Actor, Critic - - -@dataclass -class DiscreteCRRTrainingStats(PGTrainingStats): - actor_loss: float - critic_loss: float - cql_loss: float - - -TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteCRRTrainingStats) - - -class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): - r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param critic: the action-value critic (i.e., Q function) - network. (s -> Q(s, \*)) - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param str policy_improvement_mode: type of the weight function f. Possible - values: "binary"/"exp"/"all". - :param ratio_upper_bound: when policy_improvement_mode is "exp", the value - of the exp function is upper-bounded by this parameter. - :param beta: when policy_improvement_mode is "exp", this is the denominator - of the exp function. - :param min_q_weight: weight for CQL loss/regularizer. Default to 10. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: if True, will normalize the *returns* - by subtracting the running mean and dividing by the running standard deviation. - Can be detrimental to performance! See TODO in process_fn. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed - explanation. 
- """ - - def __init__( - self, - *, - actor: torch.nn.Module | Actor, - critic: torch.nn.Module | Critic, - optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - policy_improvement_mode: Literal["exp", "binary", "all"] = "exp", - ratio_upper_bound: float = 20.0, - beta: float = 1.0, - min_q_weight: float = 10.0, - target_update_freq: int = 0, - reward_normalization: bool = False, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - actor=actor, - optim=optim, - action_space=action_space, - dist_fn=lambda x: Categorical(logits=x), - discount_factor=discount_factor, - reward_normalization=reward_normalization, - observation_space=observation_space, - action_scaling=False, - action_bound_method=None, - lr_scheduler=lr_scheduler, - ) - self.critic = critic - self._target = target_update_freq > 0 - self._freq = target_update_freq - self._iter = 0 - if self._target: - self.actor_old = deepcopy(self.actor) - self.actor_old.eval() - self.critic_old = deepcopy(self.critic) - self.critic_old.eval() - else: - self.actor_old = self.actor - self.critic_old = self.critic - self._policy_improvement_mode = policy_improvement_mode - self._ratio_upper_bound = ratio_upper_bound - self._beta = beta - self._min_q_weight = min_q_weight - - def sync_weight(self) -> None: - self.actor_old.load_state_dict(self.actor.state_dict()) - self.critic_old.load_state_dict(self.critic.state_dict()) - - def learn( # type: ignore - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> TDiscreteCRRTrainingStats: - if self._target and self._iter % self._freq == 0: - self.sync_weight() - self.optim.zero_grad() - q_t = self.critic(batch.obs) - act = to_torch(batch.act, dtype=torch.long, device=q_t.device) - qa_t = q_t.gather(1, act.unsqueeze(1)) - # Critic loss - with torch.no_grad(): - target_a_t, _ = self.actor_old(batch.obs_next) - target_m = Categorical(logits=target_a_t) - q_t_target = self.critic_old(batch.obs_next) - rew = to_torch_as(batch.rew, q_t_target) - expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) - expected_target_q[batch.done > 0] = 0.0 - target = rew.unsqueeze(1) + self.gamma * expected_target_q - critic_loss = 0.5 * F.mse_loss(qa_t, target) - # Actor loss - act_target, _ = self.actor(batch.obs) - dist = Categorical(logits=act_target) - expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True) - advantage = qa_t - expected_policy_q - if self._policy_improvement_mode == "binary": - actor_loss_coef = (advantage > 0).float() - elif self._policy_improvement_mode == "exp": - actor_loss_coef = (advantage / self._beta).exp().clamp(0, self._ratio_upper_bound) - else: - actor_loss_coef = 1.0 # effectively behavior cloning - actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean() - # CQL loss/regularizer - min_q_loss = (q_t.logsumexp(1) - qa_t).mean() - loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss - loss.backward() - self.optim.step() - self._iter += 1 - - return DiscreteCRRTrainingStats( # type: ignore[return-value] - loss=loss.item(), - actor_loss=actor_loss.item(), - critic_loss=critic_loss.item(), - cql_loss=min_q_loss.item(), - ) diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py deleted file mode 100644 index 524f04001..000000000 --- a/tianshou/policy/imitation/gail.py +++ /dev/null @@ -1,198 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Literal, TypeVar - 
-import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F - -from tianshou.data import ( - ReplayBuffer, - SequenceSummaryStats, - to_numpy, - to_torch, -) -from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol -from tianshou.policy import PPOPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont -from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor -from tianshou.utils.net.discrete import Critic as DiscreteCritic - - -@dataclass(kw_only=True) -class GailTrainingStats(PPOTrainingStats): - disc_loss: SequenceSummaryStats - acc_pi: SequenceSummaryStats - acc_exp: SequenceSummaryStats - - -TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats) - - -class GAILPolicy(PPOPolicy[TGailTrainingStats]): - r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space - :param expert_buffer: the replay buffer containing expert experience. - :param disc_net: the discriminator network with input dim equals - state dim plus action dim and output dim equals 1. - :param disc_optim: the optimizer for the discriminator network. - :param disc_update_num: the number of discriminator grad steps per model grad step. - :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original - paper. - :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, - where c > 1 is a constant indicating the lower bound. Set to None - to disable dual-clip PPO. - :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. - :param advantage_normalization: whether to do per mini-batch advantage - normalization. - :param recompute_advantage: whether to recompute advantage every update - repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. - :param vf_coef: weight for value loss. - :param ent_coef: weight for entropy loss. - :param max_grad_norm: clipping gradients in back propagation. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. - :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.PPOPolicy` for more detailed - explanation. 
- """ - - def __init__( - self, - *, - actor: torch.nn.Module | ActorProb | DiscreteActor, - critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, - expert_buffer: ReplayBuffer, - disc_net: torch.nn.Module, - disc_optim: torch.optim.Optimizer, - disc_update_num: int = 4, - eps_clip: float = 0.2, - dual_clip: float | None = None, - value_clip: bool = False, - advantage_normalization: bool = True, - recompute_advantage: bool = False, - vf_coef: float = 0.5, - ent_coef: float = 0.01, - max_grad_norm: float | None = None, - gae_lambda: float = 0.95, - max_batchsize: int = 256, - discount_factor: float = 0.99, - # TODO: rename to return_normalization? - reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - actor=actor, - critic=critic, - optim=optim, - dist_fn=dist_fn, - action_space=action_space, - eps_clip=eps_clip, - dual_clip=dual_clip, - value_clip=value_clip, - advantage_normalization=advantage_normalization, - recompute_advantage=recompute_advantage, - vf_coef=vf_coef, - ent_coef=ent_coef, - max_grad_norm=max_grad_norm, - gae_lambda=gae_lambda, - max_batchsize=max_batchsize, - discount_factor=discount_factor, - reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - self.disc_net = disc_net - self.disc_optim = disc_optim - self.disc_update_num = disc_update_num - self.expert_buffer = expert_buffer - self.action_dim = actor.output_dim - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> LogpOldProtocol: - """Pre-process the data from the provided replay buffer. - - Used in :meth:`update`. Check out :ref:`process_fn` for more information. 
- """ - # update reward - with torch.no_grad(): - batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten()) - return super().process_fn(batch, buffer, indices) - - def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: - obs = to_torch(batch.obs, device=self.disc_net.device) - act = to_torch(batch.act, device=self.disc_net.device) - return self.disc_net(torch.cat([obs, act], dim=1)) - - def learn( # type: ignore - self, - batch: RolloutBatchProtocol, - batch_size: int | None, - repeat: int, - **kwargs: Any, - ) -> TGailTrainingStats: - # update discriminator - losses = [] - acc_pis = [] - acc_exps = [] - bsz = len(batch) // self.disc_update_num - for b in batch.split(bsz, merge_last=True): - logits_pi = self.disc(b) - exp_b = self.expert_buffer.sample(bsz)[0] - logits_exp = self.disc(exp_b) - loss_pi = -F.logsigmoid(-logits_pi).mean() - loss_exp = -F.logsigmoid(logits_exp).mean() - loss_disc = loss_pi + loss_exp - self.disc_optim.zero_grad() - loss_disc.backward() - self.disc_optim.step() - losses.append(loss_disc.item()) - acc_pis.append((logits_pi < 0).float().mean().item()) - acc_exps.append((logits_exp > 0).float().mean().item()) - # update policy - ppo_loss_stat = super().learn(batch, batch_size, repeat, **kwargs) - - disc_losses_summary = SequenceSummaryStats.from_sequence(losses) - acc_pi_summary = SequenceSummaryStats.from_sequence(acc_pis) - acc_exps_summary = SequenceSummaryStats.from_sequence(acc_exps) - - return GailTrainingStats( # type: ignore[return-value] - **ppo_loss_stat.__dict__, - disc_loss=disc_losses_summary, - acc_pi=acc_pi_summary, - acc_exp=acc_exps_summary, - ) diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py deleted file mode 100644 index f4b2bfe91..000000000 --- a/tianshou/policy/imitation/td3_bc.py +++ /dev/null @@ -1,130 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Literal, TypeVar - -import gymnasium as gym -import torch -import torch.nn.functional as F - -from tianshou.data import to_torch_as -from tianshou.data.types import RolloutBatchProtocol -from tianshou.exploration import BaseNoise, GaussianNoise -from tianshou.policy import TD3Policy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.td3 import TD3TrainingStats - - -@dataclass(kw_only=True) -class TD3BCTrainingStats(TD3TrainingStats): - pass - - -TTD3BCTrainingStats = TypeVar("TTD3BCTrainingStats", bound=TD3BCTrainingStats) - - -class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]): - """Implementation of TD3+BC. arXiv:2106.06860. - - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> actions) - :param actor_optim: the optimizer for actor network. - :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. - :param exploration_noise: add noise to action for exploration. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). - :param policy_noise: the noise used in updating policy network. 
- :param update_actor_freq: the update frequency of actor network. - :param noise_clip: the clipping range used in updating policy network. - :param alpha: the value of alpha, which controls the weight for TD3 learning - relative to behavior cloning. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - actor: torch.nn.Module, - actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, - action_space: gym.Space, - critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, - tau: float = 0.005, - gamma: float = 0.99, - exploration_noise: BaseNoise | None = GaussianNoise(sigma=0.1), - policy_noise: float = 0.2, - update_actor_freq: int = 2, - noise_clip: float = 0.5, - # TODO: same name as alpha in SAC and REDQ, which also inherit from DDPGPolicy. Rename? - alpha: float = 2.5, - estimation_step: int = 1, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - actor=actor, - actor_optim=actor_optim, - critic=critic, - critic_optim=critic_optim, - action_space=action_space, - critic2=critic2, - critic2_optim=critic2_optim, - tau=tau, - gamma=gamma, - exploration_noise=exploration_noise, - policy_noise=policy_noise, - noise_clip=noise_clip, - update_actor_freq=update_actor_freq, - estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - self.alpha = alpha - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3BCTrainingStats: # type: ignore - # critic 1&2 - td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) - td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) - batch.weight = (td1 + td2) / 2.0 # prio-buffer - - # actor - if self._cnt % self.update_actor_freq == 0: - act = self(batch, eps=0.0).act - q_value = self.critic(batch.obs, act) - lmbda = self.alpha / q_value.abs().mean().detach() - actor_loss = -lmbda * q_value.mean() + F.mse_loss(act, to_torch_as(batch.act, act)) - self.actor_optim.zero_grad() - actor_loss.backward() - self._last = actor_loss.item() - self.actor_optim.step() - self.sync_weight() - self._cnt += 1 - - return TD3BCTrainingStats( # type: ignore[return-value] - actor_loss=self._last, - critic1_loss=critic1_loss.item(), - critic2_loss=critic2_loss.item(), - ) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py deleted file mode 100644 index 9a603b7de..000000000 --- a/tianshou/policy/modelbased/icm.py +++ /dev/null @@ -1,176 +0,0 @@ -from typing import Any, Literal, Self, TypeVar - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F - -from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch -from tianshou.data.batch import 
BatchProtocol -from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy -from tianshou.policy.base import ( - TLearningRateScheduler, - TrainingStats, - TrainingStatsWrapper, - TTrainingStats, -) -from tianshou.utils.net.discrete import IntrinsicCuriosityModule - - -class ICMTrainingStats(TrainingStatsWrapper): - def __init__( - self, - wrapped_stats: TrainingStats, - *, - icm_loss: float, - icm_forward_loss: float, - icm_inverse_loss: float, - ) -> None: - self.icm_loss = icm_loss - self.icm_forward_loss = icm_forward_loss - self.icm_inverse_loss = icm_inverse_loss - super().__init__(wrapped_stats) - - -class ICMPolicy(BasePolicy[ICMTrainingStats]): - """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. - - :param policy: a base policy to add ICM to. - :param model: the ICM model. - :param optim: a torch.optim for optimizing the model. - :param lr_scale: the scaling factor for ICM learning. - :param forward_loss_weight: the weight for forward model loss. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - policy: BasePolicy[TTrainingStats], - model: IntrinsicCuriosityModule, - optim: torch.optim.Optimizer, - lr_scale: float, - reward_scale: float, - forward_loss_weight: float, - action_space: gym.Space, - observation_space: gym.Space | None = None, - action_scaling: bool = False, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - self.policy = policy - self.model = model - self.optim = optim - self.lr_scale = lr_scale - self.reward_scale = reward_scale - self.forward_loss_weight = forward_loss_weight - - def train(self, mode: bool = True) -> Self: - """Set the module in training mode.""" - self.policy.train(mode) - self.training = mode - self.model.train(mode) - return self - - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> ActBatchProtocol: - """Compute action over the given batch data by inner policy. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. 
- """ - return self.policy.forward(batch, state, **kwargs) - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - return self.policy.exploration_noise(act, batch) - - def set_eps(self, eps: float) -> None: - """Set the eps for epsilon-greedy exploration.""" - if hasattr(self.policy, "set_eps"): - self.policy.set_eps(eps) - else: - raise NotImplementedError - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> RolloutBatchProtocol: - """Pre-process the data from the provided replay buffer. - - Used in :meth:`update`. Check out :ref:`process_fn` for more information. - """ - mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next) - batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss) - batch.rew += to_numpy(mse_loss * self.reward_scale) - return self.policy.process_fn(batch, buffer, indices) - - def post_process_fn( - self, - batch: BatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> None: - """Post-process the data from the provided replay buffer. - - Typical usage is to update the sampling weight in prioritized - experience replay. Used in :meth:`update`. - """ - self.policy.post_process_fn(batch, buffer, indices) - batch.rew = batch.policy.orig_rew # restore original reward - - def learn( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> ICMTrainingStats: - training_stat = self.policy.learn(batch, **kwargs) - self.optim.zero_grad() - act_hat = batch.policy.act_hat - act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) - inverse_loss = F.cross_entropy(act_hat, act).mean() - forward_loss = batch.policy.mse_loss.mean() - loss = ( - (1 - self.forward_loss_weight) * inverse_loss + self.forward_loss_weight * forward_loss - ) * self.lr_scale - loss.backward() - self.optim.step() - - return ICMTrainingStats( - training_stat, - icm_loss=loss.item(), - icm_forward_loss=forward_loss.item(), - icm_inverse_loss=inverse_loss.item(), - ) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py deleted file mode 100644 index d41ccb463..000000000 --- a/tianshou/policy/modelfree/a2c.py +++ /dev/null @@ -1,206 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - -from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as -from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.policy import PGPolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont -from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor -from tianshou.utils.net.discrete import Critic as DiscreteCritic - - -@dataclass(kw_only=True) -class A2CTrainingStats(TrainingStats): - loss: SequenceSummaryStats - actor_loss: SequenceSummaryStats - vf_loss: SequenceSummaryStats - ent_loss: SequenceSummaryStats - - -TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats) - - -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. 
-class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] - """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space - :param vf_coef: weight for value loss. - :param ent_coef: weight for entropy loss. - :param max_grad_norm: clipping gradients in back propagation. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. - :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - actor: torch.nn.Module | ActorProb | DiscreteActor, - critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, - vf_coef: float = 0.5, - ent_coef: float = 0.01, - max_grad_norm: float | None = None, - gae_lambda: float = 0.95, - max_batchsize: int = 256, - discount_factor: float = 0.99, - # TODO: rename to return_normalization? 
- reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - actor=actor, - optim=optim, - dist_fn=dist_fn, - action_space=action_space, - discount_factor=discount_factor, - reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - self.critic = critic - assert 0.0 <= gae_lambda <= 1.0, f"GAE lambda should be in [0, 1] but got: {gae_lambda}" - self.gae_lambda = gae_lambda - self.vf_coef = vf_coef - self.ent_coef = ent_coef - self.max_grad_norm = max_grad_norm - self.max_batchsize = max_batchsize - self._actor_critic = ActorCritic(self.actor, self.critic) - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> BatchWithAdvantagesProtocol: - batch = self._compute_returns(batch, buffer, indices) - batch.act = to_torch_as(batch.act, batch.v_s) - return batch - - def _compute_returns( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> BatchWithAdvantagesProtocol: - v_s, v_s_ = [], [] - with torch.no_grad(): - for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): - v_s.append(self.critic(minibatch.obs)) - v_s_.append(self.critic(minibatch.obs_next)) - batch.v_s = torch.cat(v_s, dim=0).flatten() # old value - v_s = batch.v_s.cpu().numpy() - v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy() - # when normalizing values, we do not minus self.ret_rms.mean to be numerically - # consistent with OPENAI baselines' value normalization pipeline. Empirical - # study also shows that "minus mean" will harm performances a tiny little bit - # due to unknown reasons (on Mujoco envs, not confident, though). - # TODO: see todo in PGPolicy.process_fn - if self.rew_norm: # unnormalize v_s & v_s_ - v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) - v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) - unnormalized_returns, advantages = self.compute_episodic_return( - batch, - buffer, - indices, - v_s_, - v_s, - gamma=self.gamma, - gae_lambda=self.gae_lambda, - ) - if self.rew_norm: - batch.returns = unnormalized_returns / np.sqrt(self.ret_rms.var + self._eps) - self.ret_rms.update(unnormalized_returns) - else: - batch.returns = unnormalized_returns - batch.returns = to_torch_as(batch.returns, batch.v_s) - batch.adv = to_torch_as(advantages, batch.v_s) - return cast(BatchWithAdvantagesProtocol, batch) - - # TODO: mypy complains b/c signature is different from superclass, although - # it's compatible. Can this be fixed? 
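# --- Illustrative sketch only, not part of the deleted a2c.py shown above ---
# What _compute_returns() above does conceptually via compute_episodic_return:
# generalized advantage estimation from TD residuals, written out in plain
# numpy with hypothetical (T,)-shaped inputs. Done/termination handling in the
# real code differs in detail; this is the textbook recursion.
import numpy as np

def gae_advantages(rew, v_s, v_s_next, done, gamma=0.99, gae_lambda=0.95):
    """Returns (advantages, lambda-returns), both of shape (T,)."""
    adv = np.zeros_like(rew, dtype=np.float64)
    last_gae = 0.0
    for t in reversed(range(len(rew))):
        nonterminal = 1.0 - done[t]
        delta = rew[t] + gamma * v_s_next[t] * nonterminal - v_s[t]
        last_gae = delta + gamma * gae_lambda * nonterminal * last_gae
        adv[t] = last_gae
    return adv, adv + v_s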
- def learn( # type: ignore - self, - batch: RolloutBatchProtocol, - batch_size: int | None, - repeat: int, - *args: Any, - **kwargs: Any, - ) -> TA2CTrainingStats: - losses, actor_losses, vf_losses, ent_losses = [], [], [], [] - split_batch_size = batch_size or -1 - for _ in range(repeat): - for minibatch in batch.split(split_batch_size, merge_last=True): - # calculate loss for actor - dist = self(minibatch).dist - log_prob = dist.log_prob(minibatch.act) - log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1) - actor_loss = -(log_prob * minibatch.adv).mean() - # calculate loss for critic - value = self.critic(minibatch.obs).flatten() - vf_loss = F.mse_loss(minibatch.returns, value) - # calculate regularization and overall loss - ent_loss = dist.entropy().mean() - loss = actor_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss - self.optim.zero_grad() - loss.backward() - if self.max_grad_norm: # clip large gradient - nn.utils.clip_grad_norm_( - self._actor_critic.parameters(), - max_norm=self.max_grad_norm, - ) - self.optim.step() - actor_losses.append(actor_loss.item()) - vf_losses.append(vf_loss.item()) - ent_losses.append(ent_loss.item()) - losses.append(loss.item()) - - loss_summary_stat = SequenceSummaryStats.from_sequence(losses) - actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) - vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) - ent_loss_summary_stat = SequenceSummaryStats.from_sequence(ent_losses) - - return A2CTrainingStats( # type: ignore[return-value] - loss=loss_summary_stat, - actor_loss=actor_loss_summary_stat, - vf_loss=vf_loss_summary_stat, - ent_loss=ent_loss_summary_stat, - ) diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py deleted file mode 100644 index d7196a92b..000000000 --- a/tianshou/policy/modelfree/bdq.py +++ /dev/null @@ -1,204 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Literal, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch - -from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ( - ActBatchProtocol, - BatchWithReturnsProtocol, - ModelOutputBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.policy import DQNPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNTrainingStats -from tianshou.utils.net.common import BranchingNet - - -@dataclass(kw_only=True) -class BDQNTrainingStats(DQNTrainingStats): - pass - - -TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) - - -class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): - """Implementation of the Branching dual Q network arXiv:1711.08946. - - :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. 
- :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - model: BranchingNet, - optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - assert ( - estimation_step == 1 - ), f"N-step bigger than one is not supported by BDQ but got: {estimation_step}" - super().__init__( - model=model, - optim=optim, - action_space=action_space, - discount_factor=discount_factor, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - self.model = cast(BranchingNet, self.model) - - # TODO: this used to be a public property called max_action_num, - # but it collides with an attr of the same name in base class - @property - def _action_per_branch(self) -> int: - return self.model.action_per_branch - - @property - def num_branches(self) -> int: - return self.model.num_branches - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - result = self(obs_next_batch) - if self._target: - # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - target_q = self(obs_next_batch, model="model_old").logits - else: - target_q = result.logits - if self.is_double: - act = np.expand_dims(self(obs_next_batch).act, -1) - act = to_torch(act, dtype=torch.long, device=target_q.device) - else: - act = target_q.max(-1).indices.unsqueeze(-1) - return torch.gather(target_q, -1, act).squeeze() - - def _compute_return( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indice: np.ndarray, - gamma: float = 0.99, - ) -> BatchWithReturnsProtocol: - rew = batch.rew - with torch.no_grad(): - target_q_torch = self._target_q(buffer, indice) # (bsz, ?) 
- target_q = to_numpy(target_q_torch) - end_flag = buffer.done.copy() - end_flag[buffer.unfinished_index()] = True - end_flag = end_flag[indice] - mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q - _target_q = rew + gamma * mean_target_q * (1 - end_flag) - target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1) - target_q = np.repeat(target_q[..., None], self._action_per_branch, axis=-1) - - batch.returns = to_torch_as(target_q, target_q_torch) - if hasattr(batch, "weight"): # prio buffer update - batch.weight = to_torch_as(batch.weight, target_q_torch) - return cast(BatchWithReturnsProtocol, batch) - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> BatchWithReturnsProtocol: - """Compute the 1-step return for BDQ targets.""" - return self._compute_return(batch, buffer, indices) - - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - model: Literal["model", "model_old"] = "model", - **kwargs: Any, - ) -> ModelOutputBatchProtocol: - model = getattr(self, model) - obs = batch.obs - # TODO: this is very contrived, see also iqn.py - obs_next_BO = obs.obs if hasattr(obs, "obs") else obs - action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info) - act_B = to_numpy(action_values_BA.argmax(dim=-1)) - result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) - return cast(ModelOutputBatchProtocol, result) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: - if self._target and self._iter % self.freq == 0: - self.sync_weight() - self.optim.zero_grad() - weight = batch.pop("weight", 1.0) - act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) - q = self(batch).logits - act_mask = torch.zeros_like(q) - act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1) - act_q = q * act_mask - returns = batch.returns - returns = returns * act_mask - td_error = returns - act_q - loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean() - batch.weight = td_error.sum(-1).sum(-1) # prio-buffer - loss.backward() - self.optim.step() - self._iter += 1 - - return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): - bsz = len(act) - rand_mask = np.random.rand(bsz) < self.eps - rand_act = np.random.randint( - low=0, - high=self._action_per_branch, - size=(bsz, act.shape[-1]), - ) - if hasattr(batch.obs, "mask"): - rand_act += batch.obs.mask - act[rand_mask] = rand_act[rand_mask] - return act diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py deleted file mode 100644 index 5bfdba0c1..000000000 --- a/tianshou/policy/modelfree/c51.py +++ /dev/null @@ -1,137 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Generic, TypeVar - -import gymnasium as gym -import numpy as np -import torch - -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import DQNPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNTrainingStats - - -@dataclass(kw_only=True) -class C51TrainingStats(DQNTrainingStats): - pass - - -TC51TrainingStats = 
TypeVar("TC51TrainingStats", bound=C51TrainingStats) - - -class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]): - """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. - - :param model: a model following the rules (s_B -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param num_atoms: the number of atoms in the support set of the - value distribution. Default to 51. - :param v_min: the value of the smallest atom in the support set. - Default to -10.0. - :param v_max: the value of the largest atom in the support set. - Default to 10.0. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - model: torch.nn.Module, - optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - num_atoms: int = 51, - v_min: float = -10.0, - v_max: float = 10.0, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - assert num_atoms > 1, f"num_atoms should be greater than 1 but got: {num_atoms}" - assert v_min < v_max, f"v_max should be larger than v_min, but got {v_min=} and {v_max=}" - super().__init__( - model=model, - optim=optim, - action_space=action_space, - discount_factor=discount_factor, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - self._num_atoms = num_atoms - self._v_min = v_min - self._v_max = v_max - self.support = torch.nn.Parameter( - torch.linspace(self._v_min, self._v_max, self._num_atoms), - requires_grad=False, - ) - self.delta_z = (v_max - v_min) / (num_atoms - 1) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - return self.support.repeat(len(indices), 1) # shape: [bsz, num_atoms] - - def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: - return super().compute_q_value((logits * self.support).sum(2), mask) - - def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: - obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) - if self._target: - act = self(obs_next_batch).act - next_dist = self(obs_next_batch, model="model_old").logits - else: - next_batch = self(obs_next_batch) - act = next_batch.act - next_dist = next_batch.logits - next_dist = next_dist[np.arange(len(act)), act, :] - target_support = batch.returns.clamp(self._v_min, self._v_max) - # An amazing trick for calculating the projection gracefully. 
- # ref: https://github.com/ShangtongZhang/DeepRL - target_dist = ( - 1 - (target_support.unsqueeze(1) - self.support.view(1, -1, 1)).abs() / self.delta_z - ).clamp(0, 1) * next_dist.unsqueeze(1) - return target_dist.sum(-1) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats: - if self._target and self._iter % self.freq == 0: - self.sync_weight() - self.optim.zero_grad() - with torch.no_grad(): - target_dist = self._target_dist(batch) - weight = batch.pop("weight", 1.0) - curr_dist = self(batch).logits - act = batch.act - curr_dist = curr_dist[np.arange(len(act)), act, :] - cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1) - loss = (cross_entropy * weight).mean() - # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100 - batch.weight = cross_entropy.detach() # prio-buffer - loss.backward() - self.optim.step() - self._iter += 1 - - return C51TrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py deleted file mode 100644 index f21744f72..000000000 --- a/tianshou/policy/modelfree/ddpg.py +++ /dev/null @@ -1,224 +0,0 @@ -import warnings -from copy import deepcopy -from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch - -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ( - ActBatchProtocol, - ActStateBatchProtocol, - BatchWithReturnsProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.exploration import BaseNoise, GaussianNoise -from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.utils.net.continuous import Actor, Critic - - -@dataclass(kw_only=True) -class DDPGTrainingStats(TrainingStats): - actor_loss: float - critic_loss: float - - -TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats) - - -class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): - """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. - - :param actor: The actor network following the rules (s -> actions) - :param actor_optim: The optimizer for actor network. - :param critic: The critic network. (s, a -> Q(s, a)) - :param critic_optim: The optimizer for critic network. - :param action_space: Env's action space. - :param tau: Param for soft update of the target network. - :param gamma: Discount factor, in [0, 1]. - :param exploration_noise: The exploration noise, added to the action. Defaults - to ``GaussianNoise(sigma=0.1)``. - :param estimation_step: The number of steps to look ahead. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. 
- """ - - def __init__( - self, - *, - actor: torch.nn.Module | Actor, - actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module | Critic, - critic_optim: torch.optim.Optimizer, - action_space: gym.Space, - tau: float = 0.005, - gamma: float = 0.99, - exploration_noise: BaseNoise | Literal["default"] | None = "default", - estimation_step: int = 1, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - # tanh not supported, see assert below - action_bound_method: Literal["clip"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - assert 0.0 <= tau <= 1.0, f"tau should be in [0, 1] but got: {tau}" - assert 0.0 <= gamma <= 1.0, f"gamma should be in [0, 1] but got: {gamma}" - assert action_bound_method != "tanh", ( # type: ignore[comparison-overlap] - "tanh mapping is not supported" - "in policies where action is used as input of critic , because" - "raw action in range (-inf, inf) will cause instability in training" - ) - super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - if action_scaling and not np.isclose(actor.max_action, 1.0): - warnings.warn( - "action_scaling and action_bound_method are only intended to deal" - "with unbounded model action space, but find actor model bound" - f"action space with max_action={actor.max_action}." - "Consider using unbounded=True option of the actor model," - "or set action_scaling to False and action_bound_method to None.", - ) - self.actor = actor - self.actor_old = deepcopy(actor) - self.actor_old.eval() - self.actor_optim = actor_optim - self.critic = critic - self.critic_old = deepcopy(critic) - self.critic_old.eval() - self.critic_optim = critic_optim - self.tau = tau - self.gamma = gamma - if exploration_noise == "default": - exploration_noise = GaussianNoise(sigma=0.1) - # TODO: IMPORTANT - can't call this "exploration_noise" because confusingly, - # there is already a method called exploration_noise() in the base class - # Now this method doesn't apply any noise and is also not overridden. 
See TODO there - self._exploration_noise = exploration_noise - # it is only a little difference to use GaussianNoise - # self.noise = OUNoise() - self.estimation_step = estimation_step - - def set_exp_noise(self, noise: BaseNoise | None) -> None: - """Set the exploration noise.""" - self._exploration_noise = noise - - def train(self, mode: bool = True) -> Self: - """Set the module in training mode, except for the target network.""" - self.training = mode - self.actor.train(mode) - self.critic.train(mode) - return self - - def sync_weight(self) -> None: - """Soft-update the weight for the target network.""" - self.soft_update(self.actor_old, self.actor, self.tau) - self.soft_update(self.critic_old, self.critic, self.tau) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - return self.critic_old(obs_next_batch.obs, self(obs_next_batch, model="actor_old").act) - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> RolloutBatchProtocol | BatchWithReturnsProtocol: - return self.compute_nstep_return( - batch=batch, - buffer=buffer, - indices=indices, - target_q_fn=self._target_q, - gamma=self.gamma, - n_step=self.estimation_step, - ) - - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - model: Literal["actor", "actor_old"] = "actor", - **kwargs: Any, - ) -> ActStateBatchProtocol: - """Compute action over the given batch data. - - :return: A :class:`~tianshou.data.Batch` which has 2 keys: - - * ``act`` the action. - * ``state`` the hidden state. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. 
- """ - model = getattr(self, model) - actions, hidden = model(batch.obs, state=state, info=batch.info) - return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden)) - - @staticmethod - def _mse_optimizer( - batch: RolloutBatchProtocol, - critic: torch.nn.Module, - optimizer: torch.optim.Optimizer, - ) -> tuple[torch.Tensor, torch.Tensor]: - """A simple wrapper script for updating critic network.""" - weight = getattr(batch, "weight", 1.0) - current_q = critic(batch.obs, batch.act).flatten() - target_q = batch.returns.flatten() - td = current_q - target_q - # critic_loss = F.mse_loss(current_q1, target_q) - critic_loss = (td.pow(2) * weight).mean() - optimizer.zero_grad() - critic_loss.backward() - optimizer.step() - return td, critic_loss - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPGTrainingStats: # type: ignore - # critic - td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) - batch.weight = td # prio-buffer - # actor - actor_loss = -self.critic(batch.obs, self(batch).act).mean() - self.actor_optim.zero_grad() - actor_loss.backward() - self.actor_optim.step() - self.sync_weight() - - return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - if self._exploration_noise is None: - return act - if isinstance(act, np.ndarray): - return act + self._exploration_noise(act.shape) - warnings.warn("Cannot add exploration noise to non-numpy_array action.") - return act diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py deleted file mode 100644 index d1ce28da9..000000000 --- a/tianshou/policy/modelfree/discrete_sac.py +++ /dev/null @@ -1,194 +0,0 @@ -from dataclasses import dataclass -from typing import Any, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -from overrides import override -from torch.distributions import Categorical - -from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import SACPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.sac import SACTrainingStats -from tianshou.utils.net.discrete import Actor, Critic - - -@dataclass -class DiscreteSACTrainingStats(SACTrainingStats): - pass - - -TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteSACTrainingStats) - - -class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): - """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. - - :param actor: the actor network following the rules (s_B -> dist_input_BD) - :param actor_optim: the optimizer for actor network. - :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. 
- :param alpha: entropy regularization coefficient. - If a tuple (target_entropy, log_alpha, alpha_optim) is provided, - then alpha is automatically tuned. - :param estimation_step: the number of steps to look ahead for calculating - :param observation_space: Env's observation space. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - actor: torch.nn.Module | Actor, - actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module | Critic, - critic_optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - critic2: torch.nn.Module | Critic | None = None, - critic2_optim: torch.optim.Optimizer | None = None, - tau: float = 0.005, - gamma: float = 0.99, - alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, - estimation_step: int = 1, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - actor=actor, - actor_optim=actor_optim, - critic=critic, - critic_optim=critic_optim, - action_space=action_space, - critic2=critic2, - critic2_optim=critic2_optim, - tau=tau, - gamma=gamma, - alpha=alpha, - estimation_step=estimation_step, - # Note: inheriting from continuous sac reduces code duplication, - # but continuous stuff has to be disabled - exploration_noise=None, - action_scaling=False, - action_bound_method=None, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - - # TODO: violates Liskov substitution principle, incompatible action space with SAC - # Not too urgent, but still.. - @override - def _check_field_validity(self) -> None: - if not isinstance(self.action_space, gym.spaces.Discrete): - raise ValueError( - f"DiscreteSACPolicy only supports gym.spaces.Discrete, but got {self.action_space=}." 
- f"Please use SACPolicy for continuous action spaces.", - ) - - def forward( # type: ignore - self, - batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, - **kwargs: Any, - ) -> Batch: - logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) - dist = Categorical(logits=logits_BA) - act_B = ( - dist.mode - if self.deterministic_eval and not self.is_within_training_step - else dist.sample() - ) - return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - obs_next_result = self(obs_next_batch) - dist = obs_next_result.dist - target_q = dist.probs * torch.min( - self.critic_old(obs_next_batch.obs), - self.critic2_old(obs_next_batch.obs), - ) - return target_q.sum(dim=-1) + self.alpha * dist.entropy() - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore - weight = batch.pop("weight", 1.0) - target_q = batch.returns.flatten() - act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) - - # critic 1 - current_q1 = self.critic(batch.obs).gather(1, act).flatten() - td1 = current_q1 - target_q - critic1_loss = (td1.pow(2) * weight).mean() - - self.critic_optim.zero_grad() - critic1_loss.backward() - self.critic_optim.step() - - # critic 2 - current_q2 = self.critic2(batch.obs).gather(1, act).flatten() - td2 = current_q2 - target_q - critic2_loss = (td2.pow(2) * weight).mean() - - self.critic2_optim.zero_grad() - critic2_loss.backward() - self.critic2_optim.step() - batch.weight = (td1 + td2) / 2.0 # prio-buffer - - # actor - dist = self(batch).dist - entropy = dist.entropy() - with torch.no_grad(): - current_q1a = self.critic(batch.obs) - current_q2a = self.critic2(batch.obs) - q = torch.min(current_q1a, current_q2a) - actor_loss = -(self.alpha * entropy + (dist.probs * q).sum(dim=-1)).mean() - self.actor_optim.zero_grad() - actor_loss.backward() - self.actor_optim.step() - - if self.is_auto_alpha: - log_prob = -entropy.detach() + self.target_entropy - alpha_loss = -(self.log_alpha * log_prob).mean() - self.alpha_optim.zero_grad() - alpha_loss.backward() - self.alpha_optim.step() - self.alpha = self.log_alpha.detach().exp() - - self.sync_weight() - - if self.is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) - - return DiscreteSACTrainingStats( # type: ignore[return-value] - actor_loss=actor_loss.item(), - critic1_loss=critic1_loss.item(), - critic2_loss=critic2_loss.item(), - alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, - alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(), - ) - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py deleted file mode 100644 index e0ada0733..000000000 --- a/tianshou/policy/modelfree/dqn.py +++ /dev/null @@ -1,254 +0,0 @@ -from copy import deepcopy -from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch - -from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as -from tianshou.data.batch import BatchProtocol -from 
tianshou.data.types import ( - ActBatchProtocol, - BatchWithReturnsProtocol, - ModelOutputBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.utils.net.common import Net - - -@dataclass(kw_only=True) -class DQNTrainingStats(TrainingStats): - loss: float - - -TDQNTrainingStats = TypeVar("TDQNTrainingStats", bound=DQNTrainingStats) - - -class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): - """Implementation of Deep Q Network. arXiv:1312.5602. - - Implementation of Double Q-Learning. arXiv:1509.06461. - - Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is - implemented in the network side, not here). - - :param model: a model following the rules (s -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - model: torch.nn.Module | Net, - optim: torch.optim.Optimizer, - # TODO: type violates Liskov substitution principle - action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=False, - action_bound_method=None, - lr_scheduler=lr_scheduler, - ) - self.model = model - self.optim = optim - self.eps = 0.0 - assert ( - 0.0 <= discount_factor <= 1.0 - ), f"discount factor should be in [0, 1] but got: {discount_factor}" - self.gamma = discount_factor - assert ( - estimation_step > 0 - ), f"estimation_step should be greater than 0 but got: {estimation_step}" - self.n_step = estimation_step - self._target = target_update_freq > 0 - self.freq = target_update_freq - self._iter = 0 - if self._target: - self.model_old = deepcopy(self.model) - self.model_old.eval() - self.rew_norm = reward_normalization - self.is_double = is_double - self.clip_loss_grad = clip_loss_grad - - # TODO: set in forward, fix this! 
- self.max_action_num: int | None = None - - def set_eps(self, eps: float) -> None: - """Set the eps for epsilon-greedy exploration.""" - self.eps = eps - - def train(self, mode: bool = True) -> Self: - """Set the module in training mode, except for the target network.""" - self.training = mode - self.model.train(mode) - return self - - def sync_weight(self) -> None: - """Synchronize the weight for the target network.""" - self.model_old.load_state_dict(self.model.state_dict()) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - result = self(obs_next_batch) - if self._target: - # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - target_q = self(obs_next_batch, model="model_old").logits - else: - target_q = result.logits - if self.is_double: - return target_q[np.arange(len(result.act)), result.act] - # Nature DQN, over estimate - return target_q.max(dim=1)[0] - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> BatchWithReturnsProtocol: - """Compute the n-step return for Q-learning targets. - - More details can be found at - :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. - """ - return self.compute_nstep_return( - batch=batch, - buffer=buffer, - indices=indices, - target_q_fn=self._target_q, - gamma=self.gamma, - n_step=self.n_step, - rew_norm=self.rew_norm, - ) - - def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: - """Compute the q value based on the network's raw output and action mask.""" - if mask is not None: - # the masked q value should be smaller than logits.min() - min_value = logits.min() - logits.max() - 1.0 - logits = logits + to_torch_as(1 - mask, logits) * min_value - return logits - - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - model: Literal["model", "model_old"] = "model", - **kwargs: Any, - ) -> ModelOutputBatchProtocol: - """Compute action over the given batch data. - - If you need to mask the action, please add a "mask" into batch.obs, for - example, if we have an environment that has "0/1/2" three actions: - :: - - batch == Batch( - obs=Batch( - obs="original obs, with batch_size=1 for demonstration", - mask=np.array([[False, True, False]]), - # action 1 is available - # action 0 and 2 are unavailable - ), - ... - ) - - :return: A :class:`~tianshou.data.Batch` which has 3 keys: - - * ``act`` the action. - * ``logits`` the network's raw output. - * ``state`` the hidden state. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. - """ - model = getattr(self, model) - obs = batch.obs - # TODO: this is convoluted! See also other places where this is done. 
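# --- Illustrative sketch only, not part of the deleted dqn.py shown here ---
# The target selection performed by _target_q() above, pulled out for clarity:
# double DQN evaluates the online argmax action with the target network,
# Q_old(s', argmax_a Q_new(s', a)), while Nature DQN takes the plain max
# (which tends to over-estimate). Inputs are hypothetical (B, A) tensors.
import torch

def dqn_bootstrap_target(q_online_next, q_target_next, is_double=True):
    """Returns the (B,) bootstrap values used for the n-step return."""
    if is_double:
        act = q_online_next.argmax(dim=1)                          # action picked by the online net
        return q_target_next.gather(1, act.unsqueeze(1)).squeeze(1)
    return q_target_next.max(dim=1).values                         # Nature DQN target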
- obs_next = obs.obs if hasattr(obs, "obs") else obs - action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info) - q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None)) - if self.max_action_num is None: - self.max_action_num = q.shape[1] - act_B = to_numpy(q.argmax(dim=1)) - result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) - return cast(ModelOutputBatchProtocol, result) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: - if self._target and self._iter % self.freq == 0: - self.sync_weight() - self.optim.zero_grad() - weight = batch.pop("weight", 1.0) - q = self(batch).logits - q = q[np.arange(len(q)), batch.act] - returns = to_torch_as(batch.returns.flatten(), q) - td_error = returns - q - - if self.clip_loss_grad: - y = q.reshape(-1, 1) - t = returns.reshape(-1, 1) - loss = torch.nn.functional.huber_loss(y, t, reduction="mean") - else: - loss = (td_error.pow(2) * weight).mean() - - batch.weight = td_error # prio-buffer - loss.backward() - self.optim.step() - self._iter += 1 - - return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): - bsz = len(act) - rand_mask = np.random.rand(bsz) < self.eps - assert ( - self.max_action_num is not None - ), "Can't call this method before max_action_num was set in first forward" - q = np.random.rand(bsz, self.max_action_num) # [0, 1] - if hasattr(batch.obs, "mask"): - q += batch.obs.mask - rand_act = q.argmax(axis=1) - act[rand_mask] = rand_act[rand_mask] - return act diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py deleted file mode 100644 index 9c87f9cac..000000000 --- a/tianshou/policy/modelfree/fqf.py +++ /dev/null @@ -1,219 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Literal, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F - -from tianshou.data import Batch, ReplayBuffer, to_numpy -from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import DQNPolicy, QRDQNPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats -from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction - - -@dataclass(kw_only=True) -class FQFTrainingStats(QRDQNTrainingStats): - quantile_loss: float - fraction_loss: float - entropy_loss: float - - -TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats) - - -class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]): - """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. - - :param model: a model following the rules (s_B -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param fraction_model: a FractionProposalNetwork for - proposing fractions/quantiles given state. - :param fraction_optim: a torch.optim for optimizing - the fraction model above. - :param action_space: Env's action space. - :param discount_factor: in [0, 1]. - :param num_fractions: the number of fractions to use. - :param ent_coef: the coefficient for entropy loss. - :param estimation_step: the number of steps to look ahead. 
- :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - model: FullQuantileFunction, - optim: torch.optim.Optimizer, - fraction_model: FractionProposalNetwork, - fraction_optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - # TODO: used as num_quantiles in QRDQNPolicy, but num_fractions in FQFPolicy. - # Rename? Or at least explain what happens here. - num_fractions: int = 32, - ent_coef: float = 0.0, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - model=model, - optim=optim, - action_space=action_space, - discount_factor=discount_factor, - num_quantiles=num_fractions, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - self.fraction_model = fraction_model - self.ent_coef = ent_coef - self.fraction_optim = fraction_optim - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - if self._target: - result = self(obs_next_batch) - act, fractions = result.act, result.fractions - next_dist = self(obs_next_batch, model="model_old", fractions=fractions).logits - else: - next_batch = self(obs_next_batch) - act = next_batch.act - next_dist = next_batch.logits - return next_dist[np.arange(len(act)), act, :] - - # TODO: fix Liskov substitution principle violation - def forward( # type: ignore - self, - batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, - model: Literal["model", "model_old"] = "model", - fractions: Batch | None = None, - **kwargs: Any, - ) -> FQFBatchProtocol: - model = getattr(self, model) - obs = batch.obs - # TODO: this is convoluted! See also other places where this is done - obs_next = obs.obs if hasattr(obs, "obs") else obs - if fractions is None: - (logits, fractions, quantiles_tau), hidden = model( - obs_next, - propose_model=self.fraction_model, - state=state, - info=batch.info, - ) - else: - (logits, _, quantiles_tau), hidden = model( - obs_next, - propose_model=self.fraction_model, - fractions=fractions, - state=state, - info=batch.info, - ) - weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits - q = DQNPolicy.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None)) - if self.max_action_num is None: # type: ignore - # TODO: see same thing in DQNPolicy! Also reduce code duplication. 
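# --- Illustrative sketch only, not part of the deleted fqf.py shown here ---
# How forward() above turns proposed fractions and per-fraction quantile values
# into a scalar Q-value: a weighted average with weights tau_{i+1} - tau_i.
# Shapes are hypothetical but mirror the weighted_logits computation above.
import torch

def expected_q_from_fractions(quantile_values, taus):
    """quantile_values: (B, A, N); taus: (B, N + 1) with taus[:, 0] = 0 and taus[:, -1] = 1."""
    widths = (taus[:, 1:] - taus[:, :-1]).unsqueeze(1)   # (B, 1, N) fraction widths
    return (widths * quantile_values).sum(dim=2)         # (B, A) expected action values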
- self.max_action_num = q.shape[1] - act = to_numpy(q.max(dim=1)[1]) - result = Batch( - logits=logits, - act=act, - state=hidden, - fractions=fractions, - quantiles_tau=quantiles_tau, - ) - return cast(FQFBatchProtocol, result) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats: - if self._target and self._iter % self.freq == 0: - self.sync_weight() - weight = batch.pop("weight", 1.0) - out = self(batch) - curr_dist_orig = out.logits - taus, tau_hats = out.fractions.taus, out.fractions.tau_hats - act = batch.act - curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2) - target_dist = batch.returns.unsqueeze(1) - # calculate each element's difference between curr_dist and target_dist - dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = ( - ( - dist_diff - * (tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.0).float()).abs() - ) - .sum(-1) - .mean(1) - ) - quantile_loss = (huber_loss * weight).mean() - # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ - # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 - batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer - # calculate fraction loss - with torch.no_grad(): - sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :] - sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :] - # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ - # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169 - values_1 = sa_quantiles - sa_quantile_hats[:, :-1] - signs_1 = sa_quantiles > torch.cat( - [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], - dim=1, - ) - - values_2 = sa_quantiles - sa_quantile_hats[:, 1:] - signs_2 = sa_quantiles < torch.cat( - [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], - dim=1, - ) - - gradient_of_taus = torch.where(signs_1, values_1, -values_1) + torch.where( - signs_2, - values_2, - -values_2, - ) - fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean() - # calculate entropy loss - entropy_loss = out.fractions.entropies.mean() - fraction_entropy_loss = fraction_loss - self.ent_coef * entropy_loss - self.fraction_optim.zero_grad() - fraction_entropy_loss.backward(retain_graph=True) - self.fraction_optim.step() - self.optim.zero_grad() - quantile_loss.backward() - self.optim.step() - self._iter += 1 - - return FQFTrainingStats( # type: ignore[return-value] - loss=quantile_loss.item() + fraction_entropy_loss.item(), - quantile_loss=quantile_loss.item(), - fraction_loss=fraction_loss.item(), - entropy_loss=entropy_loss.item(), - ) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py deleted file mode 100644 index 75d76a2dd..000000000 --- a/tianshou/policy/modelfree/iqn.py +++ /dev/null @@ -1,161 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Literal, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F - -from tianshou.data import Batch, to_numpy -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ( - ObsBatchProtocol, - QuantileRegressionBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.policy import QRDQNPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats - - -@dataclass(kw_only=True) -class IQNTrainingStats(QRDQNTrainingStats): - pass - - -TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats) - - -class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]): - 
"""Implementation of Implicit Quantile Network. arXiv:1806.06923. - - :param model: a model following the rules (s_B -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param sample_size: the number of samples for policy evaluation. - :param online_sample_size: the number of samples for online model - in training. - :param target_sample_size: the number of samples for target model - in training. - :param num_quantiles: the number of quantile midpoints in the inverse - cumulative distribution function of the value. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - model: torch.nn.Module, - optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - sample_size: int = 32, - online_sample_size: int = 8, - target_sample_size: int = 8, - num_quantiles: int = 200, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - assert sample_size > 1, f"sample_size should be greater than 1 but got: {sample_size}" - assert ( - online_sample_size > 1 - ), f"online_sample_size should be greater than 1 but got: {online_sample_size}" - assert ( - target_sample_size > 1 - ), f"target_sample_size should be greater than 1 but got: {target_sample_size}" - super().__init__( - model=model, - optim=optim, - action_space=action_space, - discount_factor=discount_factor, - num_quantiles=num_quantiles, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - self.sample_size = sample_size # for policy eval - self.online_sample_size = online_sample_size - self.target_sample_size = target_sample_size - - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - model: Literal["model", "model_old"] = "model", - **kwargs: Any, - ) -> QuantileRegressionBatchProtocol: - if model == "model_old": - sample_size = self.target_sample_size - elif self.training: - sample_size = self.online_sample_size - else: - sample_size = self.sample_size - model = getattr(self, model) - obs = batch.obs - # TODO: this seems very contrived! - obs_next = obs.obs if hasattr(obs, "obs") else obs - (logits, taus), hidden = model( - obs_next, - sample_size=sample_size, - state=state, - info=batch.info, - ) - q = self.compute_q_value(logits, getattr(obs, "mask", None)) - if self.max_action_num is None: # type: ignore - # TODO: see same thing in DQNPolicy! 
- self.max_action_num = q.shape[1] - act = to_numpy(q.max(dim=1)[1]) - result = Batch(logits=logits, act=act, state=hidden, taus=taus) - return cast(QuantileRegressionBatchProtocol, result) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats: - if self._target and self._iter % self.freq == 0: - self.sync_weight() - self.optim.zero_grad() - weight = batch.pop("weight", 1.0) - action_batch = self(batch) - curr_dist, taus = action_batch.logits, action_batch.taus - act = batch.act - curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) - target_dist = batch.returns.unsqueeze(1) - # calculate each element's difference between curr_dist and target_dist - dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = ( - ( - dist_diff - * (taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.0).float()).abs() - ) - .sum(-1) - .mean(1) - ) - loss = (huber_loss * weight).mean() - # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ - # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 - batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer - loss.backward() - self.optim.step() - self._iter += 1 - - return IQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py deleted file mode 100644 index 005454396..000000000 --- a/tianshou/policy/modelfree/npg.py +++ /dev/null @@ -1,228 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -from torch.distributions import kl_divergence - -from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats -from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.policy import A2CPolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor -from tianshou.utils.net.discrete import Critic as DiscreteCritic - - -@dataclass(kw_only=True) -class NPGTrainingStats(TrainingStats): - actor_loss: SequenceSummaryStats - vf_loss: SequenceSummaryStats - kl: SequenceSummaryStats - - -TNPGTrainingStats = TypeVar("TNPGTrainingStats", bound=NPGTrainingStats) - - -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # type: ignore[type-var] - """Implementation of Natural Policy Gradient. - - https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). - :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for the critic network only. The actor network - is optimized via natural gradients internally. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space - :param optim_critic_iters: Number of times to optimize critic network per update. - :param actor_step_size: step size for actor update in natural gradient direction. 
- :param advantage_normalization: whether to do per mini-batch advantage - normalization. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. - :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - :param lr_scheduler: if not None, will be called in `policy.update()`. - """ - - def __init__( - self, - *, - actor: torch.nn.Module | ActorProb | DiscreteActor, - critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, - optim_critic_iters: int = 5, - actor_step_size: float = 0.5, - advantage_normalization: bool = True, - gae_lambda: float = 0.95, - max_batchsize: int = 256, - discount_factor: float = 0.99, - # TODO: rename to return_normalization? - reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - actor=actor, - critic=critic, - optim=optim, - dist_fn=dist_fn, - action_space=action_space, - # TODO: violates Liskov substitution principle, see the del statement below - vf_coef=None, # type: ignore - ent_coef=None, # type: ignore - max_grad_norm=None, - gae_lambda=gae_lambda, - max_batchsize=max_batchsize, - discount_factor=discount_factor, - reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - # TODO: see above, it ain't pretty... 
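# A compact sketch of the damped KL Hessian-vector product that learn() and _MVP()
# further below combine with conjugate gradients to obtain the natural-gradient
# direction. The tiny Gaussian actor here is an illustrative assumption only.
import torch
from torch.distributions import Normal, kl_divergence

actor = torch.nn.Linear(4, 2)
params = list(actor.parameters())
obs = torch.randn(8, 4)

def flat_grad(y: torch.Tensor, **kwargs) -> torch.Tensor:
    return torch.cat([g.reshape(-1) for g in torch.autograd.grad(y, params, **kwargs)])

dist = Normal(actor(obs), 1.0)
old_dist = Normal(actor(obs).detach(), 1.0)
kl = kl_divergence(old_dist, dist).mean()
flat_kl_grad = flat_grad(kl, create_graph=True)

def mvp(v: torch.Tensor, damping: float = 0.1) -> torch.Tensor:
    # Hv = d/dtheta (grad_kl . v); damping keeps the conjugate-gradient solve stable
    return flat_grad((flat_kl_grad * v).sum(), retain_graph=True).detach() + v * damping

v = torch.randn(sum(p.numel() for p in params))
hv = mvp(v)  # what _conjugate_gradients repeatedly evaluates to solve F x = g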
- del self.vf_coef, self.ent_coef, self.max_grad_norm - self.norm_adv = advantage_normalization - self.optim_critic_iters = optim_critic_iters - self.actor_step_size = actor_step_size - # adjusts Hessian-vector product calculation for numerical stability - self._damping = 0.1 - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> BatchWithAdvantagesProtocol: - batch = super().process_fn(batch, buffer, indices) - old_log_prob = [] - with torch.no_grad(): - for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): - old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act)) - batch.logp_old = torch.cat(old_log_prob, dim=0) - if self.norm_adv: - batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() - return batch - - def learn( # type: ignore - self, - batch: Batch, - batch_size: int | None, - repeat: int, - **kwargs: Any, - ) -> TNPGTrainingStats: - actor_losses, vf_losses, kls = [], [], [] - split_batch_size = batch_size or -1 - for _ in range(repeat): - for minibatch in batch.split(split_batch_size, merge_last=True): - # optimize actor - # direction: calculate villia gradient - dist = self(minibatch).dist - log_prob = dist.log_prob(minibatch.act) - log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1) - actor_loss = -(log_prob * minibatch.adv).mean() - flat_grads = self._get_flat_grad(actor_loss, self.actor, retain_graph=True).detach() - - # direction: calculate natural gradient - with torch.no_grad(): - old_dist = self(minibatch).dist - - kl = kl_divergence(old_dist, dist).mean() - # calculate first order gradient of kl with respect to theta - flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) - search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) - - # step - with torch.no_grad(): - flat_params = torch.cat( - [param.data.view(-1) for param in self.actor.parameters()], - ) - new_flat_params = flat_params + self.actor_step_size * search_direction - self._set_from_flat_params(self.actor, new_flat_params) - new_dist = self(minibatch).dist - kl = kl_divergence(old_dist, new_dist).mean() - - # optimize critic - for _ in range(self.optim_critic_iters): - value = self.critic(minibatch.obs).flatten() - vf_loss = F.mse_loss(minibatch.returns, value) - self.optim.zero_grad() - vf_loss.backward() - self.optim.step() - - actor_losses.append(actor_loss.item()) - vf_losses.append(vf_loss.item()) - kls.append(kl.item()) - - actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) - vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) - kl_summary_stat = SequenceSummaryStats.from_sequence(kls) - - return NPGTrainingStats( # type: ignore[return-value] - actor_loss=actor_loss_summary_stat, - vf_loss=vf_loss_summary_stat, - kl=kl_summary_stat, - ) - - def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: - """Matrix vector product.""" - # caculate second order gradient of kl with respect to theta - kl_v = (flat_kl_grad * v).sum() - flat_kl_grad_grad = self._get_flat_grad(kl_v, self.actor, retain_graph=True).detach() - return flat_kl_grad_grad + v * self._damping - - def _conjugate_gradients( - self, - minibatch: torch.Tensor, - flat_kl_grad: torch.Tensor, - nsteps: int = 10, - residual_tol: float = 1e-10, - ) -> torch.Tensor: - x = torch.zeros_like(minibatch) - r, p = minibatch.clone(), minibatch.clone() - # Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0. 
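# The initialisation note here refers to the general conjugate-gradient setup: with a
# warm start x0 != 0 the residual must be r = b - A @ x0, which reduces to r = b when
# x0 = 0 as above. A tiny dense illustration (the matrix is a made-up SPD stand-in):
import torch

A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])   # stand-in for the damped Fisher operator
b = torch.tensor([1.0, -2.0])                # stand-in for the flat policy gradient
x0 = torch.tensor([0.5, 0.0])                # hypothetical warm start
r = b - A @ x0                               # residual for the warm start
p = r.clone()                                # first search direction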
- # Change if doing warm start. - rdotr = r.dot(r) - for _ in range(nsteps): - z = self._MVP(p, flat_kl_grad) - alpha = rdotr / p.dot(z) - x += alpha * p - r -= alpha * z - new_rdotr = r.dot(r) - if new_rdotr < residual_tol: - break - p = r + new_rdotr / rdotr * p - rdotr = new_rdotr - return x - - def _get_flat_grad(self, y: torch.Tensor, model: nn.Module, **kwargs: Any) -> torch.Tensor: - grads = torch.autograd.grad(y, model.parameters(), **kwargs) # type: ignore - return torch.cat([grad.reshape(-1) for grad in grads]) - - def _set_from_flat_params(self, model: nn.Module, flat_params: torch.Tensor) -> nn.Module: - prev_ind = 0 - for param in model.parameters(): - flat_size = int(np.prod(list(param.size()))) - param.data.copy_(flat_params[prev_ind : prev_ind + flat_size].view(param.size())) - prev_ind += flat_size - return model diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py deleted file mode 100644 index 01f059df8..000000000 --- a/tianshou/policy/modelfree/pg.py +++ /dev/null @@ -1,238 +0,0 @@ -import logging -import warnings -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch - -from tianshou.data import ( - Batch, - ReplayBuffer, - SequenceSummaryStats, - to_torch, - to_torch_as, -) -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ( - BatchWithReturnsProtocol, - DistBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.utils import RunningMeanStd -from tianshou.utils.net.continuous import ActorProb -from tianshou.utils.net.discrete import Actor - -log = logging.getLogger(__name__) - - -# Dimension Naming Convention -# B - Batch Size -# A - Action -# D - Dist input (usually 2, loc and scale) -# H - Dimension of hidden, can be None - -TDistFnContinuous = Callable[ - [tuple[torch.Tensor, torch.Tensor]], - torch.distributions.Distribution, -] -TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical] - -TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete - - -@dataclass(kw_only=True) -class PGTrainingStats(TrainingStats): - loss: SequenceSummaryStats - - -TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats) - - -class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): - """Implementation of REINFORCE algorithm. - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param optim: optimizer for actor network. - :param dist_fn: distribution class for computing the action. - Maps model_output -> distribution. Typically a Gaussian distribution - taking `model_output=mean,std` as input for continuous action spaces, - or a categorical distribution taking `model_output=logits` - for discrete action spaces. Note that as user, you are responsible - for ensuring that the distribution is compatible with the action space. - :param action_space: env's action space. - :param discount_factor: in [0, 1]. - :param reward_normalization: if True, will normalize the *returns* - by subtracting the running mean and dividing by the running standard deviation. - Can be detrimental to performance! See TODO in process_fn. 
- :param deterministic_eval: if True, will use deterministic action (the dist's mode) - instead of stochastic one during evaluation. Does not affect training. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. - """ - - def __init__( - self, - *, - actor: torch.nn.Module | ActorProb | Actor, - optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, - discount_factor: float = 0.99, - # TODO: rename to return_normalization? - reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - # TODO: why change the default from the base? - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - if action_scaling and not np.isclose(actor.max_action, 1.0): - warnings.warn( - "action_scaling and action_bound_method are only intended" - "to deal with unbounded model action space, but find actor model" - f"bound action space with max_action={actor.max_action}." - "Consider using unbounded=True option of the actor model," - "or set action_scaling to False and action_bound_method to None.", - ) - self.actor = actor - self.optim = optim - self.dist_fn = dist_fn - assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" - self.gamma = discount_factor - self.rew_norm = reward_normalization - self.ret_rms = RunningMeanStd() - self._eps = 1e-8 - self.deterministic_eval = deterministic_eval - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> BatchWithReturnsProtocol: - r"""Compute the discounted returns (Monte Carlo estimates) for each transition. - - They are added to the batch under the field `returns`. - Note: this function will modify the input batch! - - .. math:: - G_t = \sum_{i=t}^T \gamma^{i-t}r_i - - where :math:`T` is the terminal time step, :math:`\gamma` is the - discount factor, :math:`\gamma \in [0, 1]`. - - :param batch: a data batch which contains several episodes of data in - sequential order. Mind that the end of each finished episode of batch - should be marked by done flag, unfinished (or collecting) episodes will be - recognized by buffer.unfinished_index(). - :param buffer: the corresponding replay buffer. - :param numpy.ndarray indices: tell batch's location in buffer, batch is equal - to buffer[indices]. - """ - v_s_ = np.full(indices.shape, self.ret_rms.mean) - # gae_lambda = 1.0 means we use Monte Carlo estimate - unnormalized_returns, _ = self.compute_episodic_return( - batch, - buffer, - indices, - v_s_=v_s_, - gamma=self.gamma, - gae_lambda=1.0, - ) - # TODO: overridden in A2C, where mean is not subtracted. Subtracting mean - # can be very detrimental! It also has no theoretical grounding. - # This should be addressed soon! 
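# A minimal, self-contained version of the Monte Carlo return G_t = sum_i gamma^(i-t) r_i
# that compute_episodic_return produces above when gae_lambda=1.0. The reward sequence is
# invented for illustration and episode boundaries are ignored for brevity.
import numpy as np

def discounted_returns(rewards: np.ndarray, gamma: float = 0.99) -> np.ndarray:
    returns = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

print(discounted_returns(np.array([1.0, 0.0, 1.0]), gamma=0.9))  # -> [1.81, 0.9, 1.0]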
- if self.rew_norm: - batch.returns = (unnormalized_returns - self.ret_rms.mean) / np.sqrt( - self.ret_rms.var + self._eps, - ) - self.ret_rms.update(unnormalized_returns) - else: - batch.returns = unnormalized_returns - batch: BatchWithReturnsProtocol - return batch - - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> DistBatchProtocol: - """Compute action over the given batch data by applying the actor. - - Will sample from the dist_fn, if appropriate. - Returns a new object representing the processed batch data - (contrary to other methods that modify the input batch inplace). - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. - """ - # TODO - ALGO: marked for algorithm refactoring - action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) - # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A - # therefore action_dist_input_BD is equivalent to logits_BA - # If discrete, dist_fn will typically map loc, scale to a distribution (usually a Gaussian) - # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked - dist = self.dist_fn(action_dist_input_BD) - - act_B = ( - dist.mode - if self.deterministic_eval and not self.is_within_training_step - else dist.sample() - ) - # act is of dimension BA in continuous case and of dimension B in discrete - result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) - return cast(DistBatchProtocol, result) - - # TODO: why does mypy complain? - def learn( # type: ignore - self, - batch: BatchWithReturnsProtocol, - batch_size: int | None, - repeat: int, - *args: Any, - **kwargs: Any, - ) -> TPGTrainingStats: - losses = [] - split_batch_size = batch_size or -1 - for _ in range(repeat): - for minibatch in batch.split(split_batch_size, merge_last=True): - self.optim.zero_grad() - result = self(minibatch) - dist = result.dist - act = to_torch_as(minibatch.act, result.act) - ret = to_torch(minibatch.returns, torch.float, result.act.device) - log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1) - loss = -(log_prob * ret).mean() - loss.backward() - self.optim.step() - losses.append(loss.item()) - - loss_summary_stat = SequenceSummaryStats.from_sequence(losses) - return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py deleted file mode 100644 index a4694b57b..000000000 --- a/tianshou/policy/modelfree/ppo.py +++ /dev/null @@ -1,235 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar - -import gymnasium as gym -import numpy as np -import torch -from torch import nn - -from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as -from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol -from tianshou.policy import A2CPolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont -from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor -from tianshou.utils.net.discrete import Critic as DiscreteCritic - - -@dataclass(kw_only=True) -class PPOTrainingStats(TrainingStats): - loss: 
SequenceSummaryStats - clip_loss: SequenceSummaryStats - vf_loss: SequenceSummaryStats - ent_loss: SequenceSummaryStats - gradient_steps: int = 0 - - @classmethod - def from_sequences( - cls, - *, - losses: Sequence[float], - clip_losses: Sequence[float], - vf_losses: Sequence[float], - ent_losses: Sequence[float], - gradient_steps: int = 0, - ) -> Self: - return cls( - loss=SequenceSummaryStats.from_sequence(losses), - clip_loss=SequenceSummaryStats.from_sequence(clip_losses), - vf_loss=SequenceSummaryStats.from_sequence(vf_losses), - ent_loss=SequenceSummaryStats.from_sequence(ent_losses), - gradient_steps=gradient_steps, - ) - - -TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats) - - -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] - r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). - :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space - :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original - paper. - :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, - where c > 1 is a constant indicating the lower bound. Set to None - to disable dual-clip PPO. - :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. - :param advantage_normalization: whether to do per mini-batch advantage - normalization. - :param recompute_advantage: whether to recompute advantage every update - repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. - :param vf_coef: weight for value loss. - :param ent_coef: weight for entropy loss. - :param max_grad_norm: clipping gradients in back propagation. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. - :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - actor: torch.nn.Module | ActorProb | DiscreteActor, - critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, - eps_clip: float = 0.2, - dual_clip: float | None = None, - value_clip: bool = False, - advantage_normalization: bool = True, - recompute_advantage: bool = False, - vf_coef: float = 0.5, - ent_coef: float = 0.01, - max_grad_norm: float | None = None, - gae_lambda: float = 0.95, - max_batchsize: int = 256, - discount_factor: float = 0.99, - # TODO: rename to return_normalization? 
- reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - assert ( - dual_clip is None or dual_clip > 1.0 - ), f"Dual-clip PPO parameter should greater than 1.0 but got {dual_clip}" - - super().__init__( - actor=actor, - critic=critic, - optim=optim, - dist_fn=dist_fn, - action_space=action_space, - vf_coef=vf_coef, - ent_coef=ent_coef, - max_grad_norm=max_grad_norm, - gae_lambda=gae_lambda, - max_batchsize=max_batchsize, - discount_factor=discount_factor, - reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - self.eps_clip = eps_clip - self.dual_clip = dual_clip - self.value_clip = value_clip - self.norm_adv = advantage_normalization - self.recompute_adv = recompute_advantage - self._actor_critic: ActorCritic - - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> LogpOldProtocol: - if self.recompute_adv: - # buffer input `buffer` and `indices` to be used in `learn()`. - self._buffer, self._indices = buffer, indices - batch = self._compute_returns(batch, buffer, indices) - batch.act = to_torch_as(batch.act, batch.v_s) - logp_old = [] - with torch.no_grad(): - for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): - logp_old.append(self(minibatch).dist.log_prob(minibatch.act)) - batch.logp_old = torch.cat(logp_old, dim=0).flatten() - batch: LogpOldProtocol - return batch - - # TODO: why does mypy complain? 
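# A standalone sketch of the clipped surrogate objective that learn() below minimises;
# all tensors are invented stand-ins of shape (B,) with advantages already normalised.
import torch

logp, logp_old = torch.randn(4), torch.randn(4)
advantages = torch.randn(4)
eps_clip = 0.2

ratios = (logp - logp_old).exp()
surr1 = ratios * advantages
surr2 = ratios.clamp(1.0 - eps_clip, 1.0 + eps_clip) * advantages
clip_loss = -torch.min(surr1, surr2).mean()   # negated because we maximise the surrogate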
- def learn( # type: ignore - self, - batch: RolloutBatchProtocol, - batch_size: int | None, - repeat: int, - *args: Any, - **kwargs: Any, - ) -> TPPOTrainingStats: - losses, clip_losses, vf_losses, ent_losses = [], [], [], [] - gradient_steps = 0 - split_batch_size = batch_size or -1 - for step in range(repeat): - if self.recompute_adv and step > 0: - batch = self._compute_returns(batch, self._buffer, self._indices) - for minibatch in batch.split(split_batch_size, merge_last=True): - gradient_steps += 1 - # calculate loss for actor - advantages = minibatch.adv - dist = self(minibatch).dist - if self.norm_adv: - mean, std = advantages.mean(), advantages.std() - advantages = (advantages - mean) / (std + self._eps) # per-batch norm - ratios = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() - ratios = ratios.reshape(ratios.size(0), -1).transpose(0, 1) - surr1 = ratios * advantages - surr2 = ratios.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip) * advantages - if self.dual_clip: - clip1 = torch.min(surr1, surr2) - clip2 = torch.max(clip1, self.dual_clip * advantages) - clip_loss = -torch.where(advantages < 0, clip2, clip1).mean() - else: - clip_loss = -torch.min(surr1, surr2).mean() - # calculate loss for critic - value = self.critic(minibatch.obs).flatten() - if self.value_clip: - v_clip = minibatch.v_s + (value - minibatch.v_s).clamp( - -self.eps_clip, - self.eps_clip, - ) - vf1 = (minibatch.returns - value).pow(2) - vf2 = (minibatch.returns - v_clip).pow(2) - vf_loss = torch.max(vf1, vf2).mean() - else: - vf_loss = (minibatch.returns - value).pow(2).mean() - # calculate regularization and overall loss - ent_loss = dist.entropy().mean() - loss = clip_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss - self.optim.zero_grad() - loss.backward() - if self.max_grad_norm: # clip large gradient - nn.utils.clip_grad_norm_( - self._actor_critic.parameters(), - max_norm=self.max_grad_norm, - ) - self.optim.step() - clip_losses.append(clip_loss.item()) - vf_losses.append(vf_loss.item()) - ent_losses.append(ent_loss.item()) - losses.append(loss.item()) - - return PPOTrainingStats.from_sequences( # type: ignore[return-value] - losses=losses, - clip_losses=clip_losses, - vf_losses=vf_losses, - ent_losses=ent_losses, - gradient_steps=gradient_steps, - ) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py deleted file mode 100644 index 71c36de0c..000000000 --- a/tianshou/policy/modelfree/qrdqn.py +++ /dev/null @@ -1,131 +0,0 @@ -import warnings -from dataclasses import dataclass -from typing import Any, Generic, TypeVar - -import gymnasium as gym -import numpy as np -import torch -import torch.nn.functional as F - -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import DQNPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNTrainingStats - - -@dataclass(kw_only=True) -class QRDQNTrainingStats(DQNTrainingStats): - pass - - -TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats) - - -class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): - """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. - - :param model: a model following the rules (s -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param action_space: Env's action space. - :param discount_factor: in [0, 1]. 
- :param num_quantiles: the number of quantile midpoints in the inverse - cumulative distribution function of the value. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - model: torch.nn.Module, - optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - num_quantiles: int = 200, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" - super().__init__( - model=model, - optim=optim, - action_space=action_space, - discount_factor=discount_factor, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - self.num_quantiles = num_quantiles - tau = torch.linspace(0, 1, self.num_quantiles + 1) - self.tau_hat = torch.nn.Parameter( - ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), - requires_grad=False, - ) - warnings.filterwarnings("ignore", message="Using a target size") - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - if self._target: - act = self(obs_next_batch).act - next_dist = self(obs_next_batch, model="model_old").logits - else: - next_batch = self(obs_next_batch) - act = next_batch.act - next_dist = next_batch.logits - return next_dist[np.arange(len(act)), act, :] - - def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: - return super().compute_q_value(logits.mean(2), mask) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats: - if self._target and self._iter % self.freq == 0: - self.sync_weight() - self.optim.zero_grad() - weight = batch.pop("weight", 1.0) - curr_dist = self(batch).logits - act = batch.act - curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) - target_dist = batch.returns.unsqueeze(1) - # calculate each element's difference between curr_dist and target_dist - dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = ( - (dist_diff * (self.tau_hat - (target_dist - curr_dist).detach().le(0.0).float()).abs()) - .sum(-1) - .mean(1) - ) - loss = (huber_loss * weight).mean() - # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ - # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 - batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer - loss.backward() - self.optim.step() - 
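# How the distributional pieces above fit together: compute_q_value reduces quantiles to
# Q-values by averaging, the greedy action is taken w.r.t. those means, and that action's
# quantile vector becomes the regression target. Shapes are illustrative.
import torch

batch, n_actions, n_quantiles = 5, 3, 200
next_quantiles = torch.randn(batch, n_actions, n_quantiles)
q_values = next_quantiles.mean(dim=2)                         # (batch, n_actions)
act = q_values.argmax(dim=1)                                  # greedy next actions
target_dist = next_quantiles[torch.arange(batch), act, :]     # (batch, n_quantiles)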
self._iter += 1 - - return QRDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py deleted file mode 100644 index fad567cd2..000000000 --- a/tianshou/policy/modelfree/rainbow.py +++ /dev/null @@ -1,59 +0,0 @@ -from dataclasses import dataclass -from typing import Any, TypeVar - -from torch import nn - -from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import C51Policy -from tianshou.policy.modelfree.c51 import C51TrainingStats -from tianshou.utils.net.discrete import NoisyLinear - - -# TODO: this is a hacky thing interviewing side-effects and a return. Should improve. -def _sample_noise(model: nn.Module) -> bool: - """Sample the random noises of NoisyLinear modules in the model. - - Returns True if at least one NoisyLinear submodule was found. - - :param model: a PyTorch module which may have NoisyLinear submodules. - :returns: True if model has at least one NoisyLinear submodule; - otherwise, False. - """ - sampled_any_noise = False - for m in model.modules(): - if isinstance(m, NoisyLinear): - m.sample() - sampled_any_noise = True - return sampled_any_noise - - -@dataclass(kw_only=True) -class RainbowTrainingStats(C51TrainingStats): - loss: float - - -TRainbowTrainingStats = TypeVar("TRainbowTrainingStats", bound=RainbowTrainingStats) - - -# TODO: is this class worth keeping? It barely does anything -class RainbowPolicy(C51Policy[TRainbowTrainingStats]): - """Implementation of Rainbow DQN. arXiv:1710.02298. - - Same parameters as :class:`~tianshou.policy.C51Policy`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.C51Policy` for more detailed - explanation. - """ - - def learn( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> TRainbowTrainingStats: - _sample_noise(self.model) - if self._target and _sample_noise(self.model_old): - self.model_old.train() # so that NoisyLinear takes effect - return super().learn(batch, **kwargs) diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py deleted file mode 100644 index 25f299733..000000000 --- a/tianshou/policy/modelfree/redq.py +++ /dev/null @@ -1,239 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Literal, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -from torch.distributions import Independent, Normal - -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol -from tianshou.exploration import BaseNoise -from tianshou.policy import DDPGPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.ddpg import DDPGTrainingStats -from tianshou.utils.net.continuous import ActorProb - - -@dataclass -class REDQTrainingStats(DDPGTrainingStats): - """A data structure for storing loss statistics of the REDQ learn step.""" - - alpha: float | None = None - alpha_loss: float | None = None - - -TREDQTrainingStats = TypeVar("TREDQTrainingStats", bound=REDQTrainingStats) - - -class REDQPolicy(DDPGPolicy[TREDQTrainingStats]): - """Implementation of REDQ. arXiv:2101.05982. - - :param actor: The actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> model_output) - :param actor_optim: The optimizer for actor network. - :param critic: The critic network. (s, a -> Q(s, a)) - :param critic_optim: The optimizer for critic network. - :param action_space: Env's action space. 
- :param ensemble_size: Number of sub-networks in the critic ensemble. - :param subset_size: Number of networks in the subset. - :param tau: Param for soft update of the target network. - :param gamma: Discount factor, in [0, 1]. - :param alpha: entropy regularization coefficient. - If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then - alpha is automatically tuned. - :param exploration_noise: The exploration noise, added to the action. Defaults - to ``GaussianNoise(sigma=0.1)``. - :param estimation_step: The number of steps to look ahead. - :param actor_delay: Number of critic updates before an actor update. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - actor: torch.nn.Module | ActorProb, - actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, - action_space: gym.spaces.Box, - ensemble_size: int = 10, - subset_size: int = 2, - tau: float = 0.005, - gamma: float = 0.99, - alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, - estimation_step: int = 1, - actor_delay: int = 20, - exploration_noise: BaseNoise | Literal["default"] | None = None, - deterministic_eval: bool = True, - target_mode: Literal["mean", "min"] = "min", - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - if target_mode not in ("min", "mean"): - raise ValueError(f"Unsupported target_mode: {target_mode}") - if not 0 < subset_size <= ensemble_size: - raise ValueError( - f"Invalid choice of ensemble size or subset size. " - f"Should be 0 < {subset_size=} <= {ensemble_size=}", - ) - super().__init__( - actor=actor, - actor_optim=actor_optim, - critic=critic, - critic_optim=critic_optim, - action_space=action_space, - tau=tau, - gamma=gamma, - exploration_noise=exploration_noise, - estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - self.ensemble_size = ensemble_size - self.subset_size = subset_size - - self.target_mode = target_mode - self.critic_gradient_step = 0 - self.actor_delay = actor_delay - self.deterministic_eval = deterministic_eval - self.__eps = np.finfo(np.float32).eps.item() - - self._last_actor_loss = 0.0 # only for logging purposes - - # TODO: reduce duplication with SACPolicy - self.alpha: float | torch.Tensor - self._is_auto_alpha = not isinstance(alpha, float) - if self._is_auto_alpha: - # TODO: why doesn't mypy understand that this must be a tuple? 
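# A sketch of the automatic entropy-temperature setup that this alpha tuple encodes
# (REDQ shares the mechanism with SACPolicy further below). The numbers, the -dim(A)
# heuristic for target_entropy and the Adam learning rate are illustrative assumptions.
import torch

action_dim = 6
target_entropy = -float(action_dim)            # common heuristic for continuous control
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
# e.g. alpha=(target_entropy, log_alpha, alpha_optim) passed to the policy constructor

# inside learn(): tune alpha so the policy entropy tracks target_entropy
log_prob = torch.randn(32, 1)                  # stand-in for the actor's log-probs
alpha_loss = -(log_alpha * (log_prob.detach() + target_entropy)).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
alpha = log_alpha.detach().exp()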
- alpha = cast(tuple[float, torch.Tensor, torch.optim.Optimizer], alpha) - if alpha[1].shape != torch.Size([1]): - raise ValueError( - f"Expected log_alpha to have shape torch.Size([1]), " - f"but got {alpha[1].shape} instead.", - ) - if not alpha[1].requires_grad: - raise ValueError("Expected log_alpha to require gradient, but it doesn't.") - - self.target_entropy, self.log_alpha, self.alpha_optim = alpha - self.alpha = self.log_alpha.detach().exp() - else: - # TODO: make mypy undestand this, or switch to something like pyright... - alpha = cast(float, alpha) - self.alpha = alpha - - @property - def is_auto_alpha(self) -> bool: - return self._is_auto_alpha - - # TODO: why override from the base class? - def sync_weight(self) -> None: - for o, n in zip(self.critic_old.parameters(), self.critic.parameters(), strict=True): - o.data.copy_(o.data * (1.0 - self.tau) + n.data * self.tau) - - def forward( # type: ignore - self, - batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, - **kwargs: Any, - ) -> Batch: - (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) - dist = Independent(Normal(loc_B, scale_B), 1) - if self.deterministic_eval and not self.is_within_training_step: - act_B = dist.mode - else: - act_B = dist.rsample() - log_prob = dist.log_prob(act_B).unsqueeze(-1) - # apply correction for Tanh squashing when computing logprob from Gaussian - # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. - # in appendix C to get some understanding of this equation. - squashed_action = torch.tanh(act_B) - log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( - -1, - keepdim=True, - ) - return Batch( - logits=(loc_B, scale_B), - act=squashed_action, - state=h_BH, - dist=dist, - log_prob=log_prob, - ) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - obs_next_result = self(obs_next_batch) - a_ = obs_next_result.act - sample_ensemble_idx = np.random.choice(self.ensemble_size, self.subset_size, replace=False) - qs = self.critic_old(obs_next_batch.obs, a_)[sample_ensemble_idx, ...] 
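# At this step qs holds the target values of a randomly chosen critic subset (for an
# ensemble critic returning roughly (ensemble_size, batch, 1)); the REDQ target then
# reduces over that subset. A small illustration with made-up tensors:
import numpy as np
import torch

ensemble_size, subset_size, batch = 10, 2, 4
qs_all = torch.randn(ensemble_size, batch, 1)
idx = np.random.choice(ensemble_size, subset_size, replace=False)
qs = qs_all[idx, ...]
target_q_min = qs.min(dim=0).values            # target_mode == "min" (the default)
target_q_mean = qs.mean(dim=0)                 # target_mode == "mean"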
- if self.target_mode == "min": - target_q, _ = torch.min(qs, dim=0) - elif self.target_mode == "mean": - target_q = torch.mean(qs, dim=0) - - target_q -= self.alpha * obs_next_result.log_prob - - return target_q - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TREDQTrainingStats: # type: ignore - # critic ensemble - weight = getattr(batch, "weight", 1.0) - current_qs = self.critic(batch.obs, batch.act).flatten(1) - target_q = batch.returns.flatten() - td = current_qs - target_q - critic_loss = (td.pow(2) * weight).mean() - self.critic_optim.zero_grad() - critic_loss.backward() - self.critic_optim.step() - batch.weight = torch.mean(td, dim=0) # prio-buffer - self.critic_gradient_step += 1 - - alpha_loss = None - # actor - if self.critic_gradient_step % self.actor_delay == 0: - obs_result = self(batch) - a = obs_result.act - current_qa = self.critic(batch.obs, a).mean(dim=0).flatten() - actor_loss = (self.alpha * obs_result.log_prob.flatten() - current_qa).mean() - self.actor_optim.zero_grad() - actor_loss.backward() - self.actor_optim.step() - - if self.is_auto_alpha: - log_prob = obs_result.log_prob.detach() + self._target_entropy - alpha_loss = -(self._log_alpha * log_prob).mean() - self.alpha_optim.zero_grad() - alpha_loss.backward() - self.alpha_optim.step() - self.alpha = self.log_alpha.detach().exp() - - self.sync_weight() - - if self.critic_gradient_step % self.actor_delay == 0: - self._last_actor_loss = actor_loss.item() - if self.is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) - - return REDQTrainingStats( # type: ignore[return-value] - actor_loss=self._last_actor_loss, - critic_loss=critic_loss.item(), - alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, - alpha_loss=alpha_loss, - ) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py deleted file mode 100644 index c1f19eff7..000000000 --- a/tianshou/policy/modelfree/sac.py +++ /dev/null @@ -1,262 +0,0 @@ -from copy import deepcopy -from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar, cast - -import gymnasium as gym -import numpy as np -import torch -from torch.distributions import Independent, Normal - -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import ( - DistLogProbBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.exploration import BaseNoise -from tianshou.policy import DDPGPolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.utils.conversion import to_optional_float -from tianshou.utils.net.continuous import ActorProb -from tianshou.utils.optim import clone_optimizer - - -def correct_log_prob_gaussian_tanh( - log_prob: torch.Tensor, - tanh_squashed_action: torch.Tensor, - eps: float = np.finfo(np.float32).eps.item(), -) -> torch.Tensor: - """Apply correction for Tanh squashing when computing `log_prob` from Gaussian. - - See equation 21 in the original `SAC paper `_. 
- - :param log_prob: log probability of the action - :param tanh_squashed_action: action squashed to values in (-1, 1) range by tanh - :param eps: epsilon for numerical stability - """ - log_prob_correction = torch.log(1 - tanh_squashed_action.pow(2) + eps).sum(-1, keepdim=True) - return log_prob - log_prob_correction - - -@dataclass(kw_only=True) -class SACTrainingStats(TrainingStats): - actor_loss: float - critic1_loss: float - critic2_loss: float - alpha: float | None = None - alpha_loss: float | None = None - - -TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats) - - -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] - """Implementation of Soft Actor-Critic. arXiv:1812.05905. - - :param actor: the actor network following the rules (s -> dist_input_BD) - :param actor_optim: the optimizer for actor network. - :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. - :param alpha: entropy regularization coefficient. - If a tuple (target_entropy, log_alpha, alpha_optim) is provided, - then alpha is automatically tuned. - :param estimation_step: The number of steps to look ahead. - :param exploration_noise: add noise to action for exploration. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). - :param deterministic_eval: whether to use deterministic action - (mode of Gaussian policy) in evaluation mode instead of stochastic - action sampled by the policy. Does not affect training. - :param action_scaling: whether to map actions from range [-1, 1] - to range[action_spaces.low, action_spaces.high]. - :param action_bound_method: method to bound action to range [-1, 1], - can be either "clip" (for simply clipping the action) - or empty string for no bounding. Only used if the action_space is continuous. - This parameter is ignored in SAC, which used tanh squashing after sampling - unbounded from the gaussian policy (as in (arXiv 1801.01290): Equation 21.). - :param observation_space: Env's observation space. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. 
- """ - - def __init__( - self, - *, - actor: torch.nn.Module | ActorProb, - actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, - action_space: gym.Space, - critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, - tau: float = 0.005, - gamma: float = 0.99, - alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, - estimation_step: int = 1, - exploration_noise: BaseNoise | Literal["default"] | None = None, - deterministic_eval: bool = True, - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", - observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - actor=actor, - actor_optim=actor_optim, - critic=critic, - critic_optim=critic_optim, - action_space=action_space, - tau=tau, - gamma=gamma, - exploration_noise=exploration_noise, - estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - critic2 = critic2 or deepcopy(critic) - critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) - self.critic2, self.critic2_old = critic2, deepcopy(critic2) - self.critic2_old.eval() - self.critic2_optim = critic2_optim - self.deterministic_eval = deterministic_eval - - self.alpha: float | torch.Tensor - self._is_auto_alpha = not isinstance(alpha, float) - if self._is_auto_alpha: - # TODO: why doesn't mypy understand that this must be a tuple? - alpha = cast(tuple[float, torch.Tensor, torch.optim.Optimizer], alpha) - if alpha[1].shape != torch.Size([1]): - raise ValueError( - f"Expected log_alpha to have shape torch.Size([1]), " - f"but got {alpha[1].shape} instead.", - ) - if not alpha[1].requires_grad: - raise ValueError("Expected log_alpha to require gradient, but it doesn't.") - - self.target_entropy, self.log_alpha, self.alpha_optim = alpha - self.alpha = self.log_alpha.detach().exp() - else: - alpha = cast( - float, - alpha, - ) # can we convert alpha to a constant tensor here? then mypy wouldn't complain - self.alpha = alpha - - # TODO or not TODO: add to BasePolicy? - self._check_field_validity() - - def _check_field_validity(self) -> None: - if not isinstance(self.action_space, gym.spaces.Box): - raise ValueError( - f"SACPolicy only supports gym.spaces.Box, but got {self.action_space=}." 
- f"Please use DiscreteSACPolicy for discrete action spaces.", - ) - - @property - def is_auto_alpha(self) -> bool: - return self._is_auto_alpha - - def train(self, mode: bool = True) -> Self: - self.training = mode - self.actor.train(mode) - self.critic.train(mode) - self.critic2.train(mode) - return self - - def sync_weight(self) -> None: - self.soft_update(self.critic_old, self.critic, self.tau) - self.soft_update(self.critic2_old, self.critic2, self.tau) - - # TODO: violates Liskov substitution principle - def forward( # type: ignore - self, - batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, - **kwargs: Any, - ) -> DistLogProbBatchProtocol: - (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info) - dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) - if self.deterministic_eval and not self.is_within_training_step: - act_B = dist.mode - else: - act_B = dist.rsample() - log_prob = dist.log_prob(act_B).unsqueeze(-1) - - squashed_action = torch.tanh(act_B) - log_prob = correct_log_prob_gaussian_tanh(log_prob, squashed_action) - result = Batch( - logits=(loc_B, scale_B), - act=squashed_action, - state=hidden_BH, - dist=dist, - log_prob=log_prob, - ) - return cast(DistLogProbBatchProtocol, result) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - obs_next_result = self(obs_next_batch) - act_ = obs_next_result.act - return ( - torch.min( - self.critic_old(obs_next_batch.obs, act_), - self.critic2_old(obs_next_batch.obs, act_), - ) - - self.alpha * obs_next_result.log_prob - ) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore - # critic 1&2 - td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) - td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) - batch.weight = (td1 + td2) / 2.0 # prio-buffer - - # actor - obs_result = self(batch) - act = obs_result.act - current_q1a = self.critic(batch.obs, act).flatten() - current_q2a = self.critic2(batch.obs, act).flatten() - actor_loss = ( - self.alpha * obs_result.log_prob.flatten() - torch.min(current_q1a, current_q2a) - ).mean() - self.actor_optim.zero_grad() - actor_loss.backward() - self.actor_optim.step() - alpha_loss = None - - if self.is_auto_alpha: - log_prob = obs_result.log_prob.detach() + self.target_entropy - # please take a look at issue #258 if you'd like to change this line - alpha_loss = -(self.log_alpha * log_prob).mean() - self.alpha_optim.zero_grad() - alpha_loss.backward() - self.alpha_optim.step() - self.alpha = self.log_alpha.detach().exp() - - self.sync_weight() - - return SACTrainingStats( # type: ignore[return-value] - actor_loss=actor_loss.item(), - critic1_loss=critic1_loss.item(), - critic2_loss=critic2_loss.item(), - alpha=to_optional_float(self.alpha), - alpha_loss=to_optional_float(alpha_loss), - ) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py deleted file mode 100644 index 8c2ae8c98..000000000 --- a/tianshou/policy/modelfree/td3.py +++ /dev/null @@ -1,163 +0,0 @@ -from copy import deepcopy -from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar - -import gymnasium as gym -import numpy as np -import torch - -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import RolloutBatchProtocol -from tianshou.exploration 
import BaseNoise -from tianshou.policy import DDPGPolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.utils.optim import clone_optimizer - - -@dataclass(kw_only=True) -class TD3TrainingStats(TrainingStats): - actor_loss: float - critic1_loss: float - critic2_loss: float - - -TTD3TrainingStats = TypeVar("TTD3TrainingStats", bound=TD3TrainingStats) - - -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: ignore[type-var] - """Implementation of TD3, arXiv:1802.09477. - - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> actions) - :param actor_optim: the optimizer for actor network. - :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. - :param exploration_noise: add noise to action for exploration. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). - :param policy_noise: the noise used in updating policy network. - :param update_actor_freq: the update frequency of actor network. - :param noise_clip: the clipping range used in updating policy network. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - - def __init__( - self, - *, - actor: torch.nn.Module, - actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, - action_space: gym.Space, - critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, - tau: float = 0.005, - gamma: float = 0.99, - exploration_noise: BaseNoise | Literal["default"] | None = "default", - policy_noise: float = 0.2, - update_actor_freq: int = 2, - noise_clip: float = 0.5, - estimation_step: int = 1, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - # TODO: reduce duplication with SAC. - # Some intermediate class, like TwoCriticPolicy? 
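# The core of the TD3 target computed further below: perturb the target actor's action
# with clipped Gaussian noise, then take the element-wise minimum of the two target
# critics. All tensors here are illustrative stand-ins.
import torch

act_target = torch.tanh(torch.randn(8, 2))     # stand-in for actor_old(s')
policy_noise, noise_clip = 0.2, 0.5
noise = torch.randn_like(act_target) * policy_noise
if noise_clip > 0.0:
    noise = noise.clamp(-noise_clip, noise_clip)
smoothed_act = act_target + noise
q1_old = torch.randn(8, 1)                     # stand-in for critic_old(s', smoothed_act)
q2_old = torch.randn(8, 1)                     # stand-in for critic2_old(s', smoothed_act)
target_q = torch.min(q1_old, q2_old)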
- super().__init__( - actor=actor, - actor_optim=actor_optim, - critic=critic, - critic_optim=critic_optim, - action_space=action_space, - tau=tau, - gamma=gamma, - exploration_noise=exploration_noise, - estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, - lr_scheduler=lr_scheduler, - ) - if critic2 and not critic2_optim: - raise ValueError("critic2_optim must be provided if critic2 is provided") - critic2 = critic2 or deepcopy(critic) - critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) - self.critic2, self.critic2_old = critic2, deepcopy(critic2) - self.critic2_old.eval() - self.critic2_optim = critic2_optim - - self.policy_noise = policy_noise - self.update_actor_freq = update_actor_freq - self.noise_clip = noise_clip - self._cnt = 0 - self._last = 0 - - def train(self, mode: bool = True) -> Self: - self.training = mode - self.actor.train(mode) - self.critic.train(mode) - self.critic2.train(mode) - return self - - def sync_weight(self) -> None: - self.soft_update(self.critic_old, self.critic, self.tau) - self.soft_update(self.critic2_old, self.critic2, self.tau) - self.soft_update(self.actor_old, self.actor, self.tau) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - act_ = self(obs_next_batch, model="actor_old").act - noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise - if self.noise_clip > 0.0: - noise = noise.clamp(-self.noise_clip, self.noise_clip) - act_ += noise - return torch.min( - self.critic_old(obs_next_batch.obs, act_), - self.critic2_old(obs_next_batch.obs, act_), - ) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3TrainingStats: # type: ignore - # critic 1&2 - td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) - td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) - batch.weight = (td1 + td2) / 2.0 # prio-buffer - - # actor - if self._cnt % self.update_actor_freq == 0: - actor_loss = -self.critic(batch.obs, self(batch, eps=0.0).act).mean() - self.actor_optim.zero_grad() - actor_loss.backward() - self._last = actor_loss.item() - self.actor_optim.step() - self.sync_weight() - self._cnt += 1 - - return TD3TrainingStats( # type: ignore[return-value] - actor_loss=self._last, - critic1_loss=critic1_loss.item(), - critic2_loss=critic2_loss.item(), - ) diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py deleted file mode 100644 index 51a2d7cf0..000000000 --- a/tianshou/policy/modelfree/trpo.py +++ /dev/null @@ -1,200 +0,0 @@ -import warnings -from dataclasses import dataclass -from typing import Any, Literal, TypeVar - -import gymnasium as gym -import torch -import torch.nn.functional as F -from torch.distributions import kl_divergence - -from tianshou.data import Batch, SequenceSummaryStats -from tianshou.policy import NPGPolicy -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.npg import NPGTrainingStats -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor -from tianshou.utils.net.discrete import Critic as DiscreteCritic - - -@dataclass(kw_only=True) -class 
TRPOTrainingStats(NPGTrainingStats): - step_size: SequenceSummaryStats - - -TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats) - - -class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): - """Implementation of Trust Region Policy Optimization. arXiv:1502.05477. - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for the critic network only. The actor network - is optimized via natural gradients internally. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space - :param max_kl: max kl-divergence used to constrain each actor network update. - :param backtrack_coeff: Coefficient to be multiplied by step size when - constraints are not met. - :param max_backtracks: Max number of backtracking times in linesearch. - :param optim_critic_iters: Number of times to optimize critic network per update. - :param actor_step_size: step size for actor update in natural gradient direction. - :param advantage_normalization: whether to do per mini-batch advantage - normalization. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. - :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - :param lr_scheduler: if not None, will be called in `policy.update()`. - """ - - def __init__( - self, - *, - actor: torch.nn.Module | ActorProb | DiscreteActor, - critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, - max_kl: float = 0.01, - backtrack_coeff: float = 0.8, - max_backtracks: int = 10, - optim_critic_iters: int = 5, - actor_step_size: float = 0.5, - advantage_normalization: bool = True, - gae_lambda: float = 0.95, - max_batchsize: int = 256, - discount_factor: float = 0.99, - # TODO: rename to return_normalization? 
- reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: - super().__init__( - actor=actor, - critic=critic, - optim=optim, - dist_fn=dist_fn, - action_space=action_space, - optim_critic_iters=optim_critic_iters, - actor_step_size=actor_step_size, - advantage_normalization=advantage_normalization, - gae_lambda=gae_lambda, - max_batchsize=max_batchsize, - discount_factor=discount_factor, - reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - self.max_backtracks = max_backtracks - self.max_kl = max_kl - self.backtrack_coeff = backtrack_coeff - - def learn( # type: ignore - self, - batch: Batch, - batch_size: int | None, - repeat: int, - **kwargs: Any, - ) -> TTRPOTrainingStats: - actor_losses, vf_losses, step_sizes, kls = [], [], [], [] - split_batch_size = batch_size or -1 - for _ in range(repeat): - for minibatch in batch.split(split_batch_size, merge_last=True): - # optimize actor - # direction: calculate villia gradient - dist = self(minibatch).dist # TODO could come from batch - ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() - ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) - actor_loss = -(ratio * minibatch.adv).mean() - flat_grads = self._get_flat_grad(actor_loss, self.actor, retain_graph=True).detach() - - # direction: calculate natural gradient - with torch.no_grad(): - old_dist = self(minibatch).dist - - kl = kl_divergence(old_dist, dist).mean() - # calculate first order gradient of kl with respect to theta - flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) - search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) - - # stepsize: calculate max stepsize constrained by kl bound - step_size = torch.sqrt( - 2 - * self.max_kl - / (search_direction * self._MVP(search_direction, flat_kl_grad)).sum( - 0, - keepdim=True, - ), - ) - - # stepsize: linesearch stepsize - with torch.no_grad(): - flat_params = torch.cat( - [param.data.view(-1) for param in self.actor.parameters()], - ) - for i in range(self.max_backtracks): - new_flat_params = flat_params + step_size * search_direction - self._set_from_flat_params(self.actor, new_flat_params) - # calculate kl and if in bound, loss actually down - new_dist = self(minibatch).dist - new_dratio = ( - (new_dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() - ) - new_dratio = new_dratio.reshape(new_dratio.size(0), -1).transpose(0, 1) - new_actor_loss = -(new_dratio * minibatch.adv).mean() - kl = kl_divergence(old_dist, new_dist).mean() - - if kl < self.max_kl and new_actor_loss < actor_loss: - if i > 0: - warnings.warn(f"Backtracking to step {i}.") - break - if i < self.max_backtracks - 1: - step_size = step_size * self.backtrack_coeff - else: - self._set_from_flat_params(self.actor, new_flat_params) - step_size = torch.tensor([0.0]) - warnings.warn( - "Line search failed! 
It seems hyperparamters" - " are poor and need to be changed.", - ) - - # optimize critic - # TODO: remove type-ignore once the top-level type-ignore is removed - for _ in range(self.optim_critic_iters): # type: ignore - value = self.critic(minibatch.obs).flatten() - vf_loss = F.mse_loss(minibatch.returns, value) - self.optim.zero_grad() - vf_loss.backward() - self.optim.step() - - actor_losses.append(actor_loss.item()) - vf_losses.append(vf_loss.item()) - step_sizes.append(step_size.item()) - kls.append(kl.item()) - - actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) - vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) - kl_summary_stat = SequenceSummaryStats.from_sequence(kls) - step_size_stat = SequenceSummaryStats.from_sequence(step_sizes) - - return TRPOTrainingStats( # type: ignore[return-value] - actor_loss=actor_loss_summary_stat, - vf_loss=vf_loss_summary_stat, - kl=kl_summary_stat, - step_size=step_size_stat, - ) diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py deleted file mode 100644 index bf665bc2b..000000000 --- a/tianshou/policy/random.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Any, TypeVar, cast - -import numpy as np - -from tianshou.data import Batch -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy -from tianshou.policy.base import TrainingStats - - -class MARLRandomTrainingStats(TrainingStats): - pass - - -TMARLRandomTrainingStats = TypeVar("TMARLRandomTrainingStats", bound=MARLRandomTrainingStats) - - -class MARLRandomPolicy(BasePolicy[TMARLRandomTrainingStats]): - """A random agent used in multi-agent learning. - - It randomly chooses an action from the legal action. - """ - - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> ActBatchProtocol: - """Compute the random action over the given batch data. - - The input should contain a mask in batch.obs, with "True" to be - available and "False" to be unavailable. For example, - ``batch.obs.mask == np.array([[False, True, False]])`` means with batch - size 1, action "1" is available but action "0" and "2" are unavailable. - - :return: A :class:`~tianshou.data.Batch` with "act" key, containing - the random action. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. 
- """ - mask = batch.obs.mask # type: ignore - logits = np.random.rand(*mask.shape) - logits[~mask] = -np.inf - result = Batch(act=logits.argmax(axis=-1)) - return cast(ActBatchProtocol, result) - - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TMARLRandomTrainingStats: # type: ignore - """Since a random agent learns nothing, it returns an empty dict.""" - return MARLRandomTrainingStats() # type: ignore[return-value] diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 5946555a2..702ce6bf4 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,18 +1,12 @@ """Trainer package.""" -from tianshou.trainer.base import ( - BaseTrainer, +from .trainer import ( OfflineTrainer, - OffpolicyTrainer, - OnpolicyTrainer, + OfflineTrainerParams, + OffPolicyTrainer, + OffPolicyTrainerParams, + OnPolicyTrainer, + OnPolicyTrainerParams, + Trainer, + TrainerParams, ) -from tianshou.trainer.utils import gather_info, test_episode - -__all__ = [ - "BaseTrainer", - "OffpolicyTrainer", - "OnpolicyTrainer", - "OfflineTrainer", - "test_episode", - "gather_info", -] diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py deleted file mode 100644 index ec4645741..000000000 --- a/tianshou/trainer/base.py +++ /dev/null @@ -1,823 +0,0 @@ -import logging -import time -from abc import ABC, abstractmethod -from collections import defaultdict, deque -from collections.abc import Callable -from dataclasses import asdict -from functools import partial - -import numpy as np -import torch -import tqdm - -from tianshou.data import ( - AsyncCollector, - CollectStats, - EpochStats, - InfoStats, - ReplayBuffer, - SequenceSummaryStats, -) -from tianshou.data.buffer.base import MalformedBufferError -from tianshou.data.collector import BaseCollector, CollectStatsBase -from tianshou.policy import BasePolicy -from tianshou.policy.base import TrainingStats -from tianshou.trainer.utils import gather_info, test_episode -from tianshou.utils import ( - BaseLogger, - LazyLogger, - MovAvg, -) -from tianshou.utils.determinism import TraceLogger, torch_param_hash -from tianshou.utils.logging import set_numerical_fields_to_precision -from tianshou.utils.torch_utils import policy_within_training_step - -log = logging.getLogger(__name__) - - -class BaseTrainer(ABC): - """An iterator base class for trainers. - - Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results - on every epoch. - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param batch_size: the batch size of sample data, which is going to feed in - the policy network. If None, will use the whole buffer in each gradient step. - :param train_collector: the collector used for training. - :param test_collector: the collector used for testing. If it's None, - then no testing will be performed. - :param buffer: the replay buffer used for off-policy algorithms or for pre-training. - If a policy overrides the ``process_buffer`` method, the replay buffer will - be pre-processed before training. - :param max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` - is set. - :param step_per_epoch: the number of transitions collected per epoch. - :param repeat_per_collect: the number of repeat time for policy learning, - for example, set it to 2 means the policy needs to learn each given batch - data twice. 
Only used in on-policy algorithms - :param episode_per_test: the number of episodes for one policy evaluation. - :param update_per_step: only used in off-policy algorithms. - How many gradient steps to perform per step in the environment - (i.e., per sample added to the buffer). - :param step_per_collect: the number of transitions the collector would - collect before the network update, i.e., trainer will collect - "step_per_collect" transitions and do some policy network update repeatedly - in each epoch. - :param episode_per_collect: the number of episodes the collector would - collect before the network update, i.e., trainer will collect - "episode_per_collect" episodes and do some policy network update repeatedly - in each epoch. - :param train_fn: a hook called at the beginning of training in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param compute_score_fn: Calculate the test batch performance score to - determine whether it is the best model, the mean reward will be used as score if not provided. - :param save_best_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature - ``f(policy: BasePolicy) -> None``. - :param save_checkpoint_fn: a function to save training process and - return the saved checkpoint path, with the signature ``f(epoch: int, - env_step: int, gradient_step: int) -> str``; you can save whatever you want. - :param resume_from_log: resume env_step/gradient_step and other metadata - from existing tensorboard log. - :param stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param reward_metric: a function with signature - ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray - with shape (num_episode,)``, used in multi-agent RL. We need to return a - single scalar for each episode's result to monitor training in the - multi-agent RL setting. This function specifies what is the desired metric, - e.g., the reward of agent 1 or the average reward over all agents. - :param logger: A logger that logs statistics during - training/testing/updating. To not log anything, keep the default logger. - :param verbose: whether to print status information to stdout. - If set to False, status information will still be logged (provided that - logging is enabled via the `logging` module). - :param show_progress: whether to display a progress bar when training. - :param test_in_train: whether to test in the training phase. - """ - - __doc__: str - - @staticmethod - def gen_doc(learning_type: str) -> str: - """Document string for subclass trainer.""" - step_means = f'The "step" in {learning_type} trainer means ' - if learning_type != "offline": - step_means += "an environment step (a.k.a. transition)." - else: # offline - step_means += "a gradient step." - - trainer_name = learning_type.capitalize() + "Trainer" - - return f"""An iterator class for {learning_type} trainer procedure. - - Returns an iterator that yields a 3-tuple (epoch, stats, info) of - train results on every epoch. - - {step_means} - - Example usage: - - :: - - trainer = {trainer_name}(...) 
- for epoch, epoch_stat, info in trainer: - print("Epoch:", epoch) - print(epoch_stat) - print(info) - do_something_with_policy() - query_something_about_policy() - make_a_plot_with(epoch_stat) - display(info) - - - epoch int: the epoch number - - epoch_stat dict: a large collection of metrics of the current epoch - - info dict: result returned from :func:`~tianshou.trainer.gather_info` - - You can even iterate on several trainers at the same time: - - :: - - trainer1 = {trainer_name}(...) - trainer2 = {trainer_name}(...) - for result1, result2, ... in zip(trainer1, trainer2, ...): - compare_results(result1, result2, ...) - """ - - def __init__( - self, - policy: BasePolicy, - max_epoch: int, - batch_size: int | None, - train_collector: BaseCollector | None = None, - test_collector: BaseCollector | None = None, - buffer: ReplayBuffer | None = None, - step_per_epoch: int | None = None, - repeat_per_collect: int | None = None, - episode_per_test: int | None = None, - update_per_step: float = 1.0, - step_per_collect: int | None = None, - episode_per_collect: int | None = None, - train_fn: Callable[[int, int], None] | None = None, - test_fn: Callable[[int, int | None], None] | None = None, - stop_fn: Callable[[float], bool] | None = None, - compute_score_fn: Callable[[CollectStats], float] | None = None, - save_best_fn: Callable[[BasePolicy], None] | None = None, - save_checkpoint_fn: Callable[[int, int, int], str] | None = None, - resume_from_log: bool = False, - reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, - logger: BaseLogger | None = None, - verbose: bool = True, - show_progress: bool = True, - test_in_train: bool = True, - ): - logger = logger or LazyLogger() - self.policy = policy - - if buffer is not None: - buffer = policy.process_buffer(buffer) - self.buffer = buffer - - self.train_collector = train_collector - self.test_collector = test_collector - - self.logger = logger - self.start_time = time.time() - self.stat: defaultdict[str, MovAvg] = defaultdict(MovAvg) - self.best_score = 0.0 - self.best_reward = 0.0 - self.best_reward_std = 0.0 - self.start_epoch = 0 - # This is only used for logging but creeps into the implementations - # of the trainers. 
I believe it would be better to remove - self._gradient_step = 0 - self.env_step = 0 - self.env_episode = 0 - self.policy_update_time = 0.0 - self.max_epoch = max_epoch - assert ( - step_per_epoch is not None - ), "The trainer requires step_per_epoch to be set, sorry for the wrong type hint" - self.step_per_epoch: int = step_per_epoch - - # either on of these two - self.step_per_collect = step_per_collect - self.episode_per_collect = episode_per_collect - - self.update_per_step = update_per_step - self.repeat_per_collect = repeat_per_collect - - self.episode_per_test = episode_per_test - - self.batch_size = batch_size - - self.train_fn = train_fn - self.test_fn = test_fn - self.stop_fn = stop_fn - self.compute_score_fn: Callable[[CollectStats], float] - if compute_score_fn is None: - - def compute_score_fn(stat: CollectStats) -> float: - assert stat.returns_stat is not None # for mypy - return stat.returns_stat.mean - - self.compute_score_fn = compute_score_fn - self.save_best_fn = save_best_fn - self.save_checkpoint_fn = save_checkpoint_fn - - self.reward_metric = reward_metric - self.verbose = verbose - self.show_progress = show_progress - self.test_in_train = test_in_train - self.resume_from_log = resume_from_log - - self.is_run = False - self.last_rew, self.last_len = 0.0, 0.0 - - self.epoch = self.start_epoch - self.best_epoch = self.start_epoch - self.stop_fn_flag = False - self.iter_num = 0 - - @property - def _pbar(self) -> type[tqdm.tqdm]: - """Use as context manager or iterator, i.e., `with self._pbar(...) as t:` or `for _ in self._pbar(...):`.""" - return partial( - tqdm.tqdm, - dynamic_ncols=True, - ascii=True, - disable=not self.show_progress, - ) # type: ignore[return-value] - - def _reset_collectors(self, reset_buffer: bool = False) -> None: - if self.train_collector is not None: - self.train_collector.reset(reset_buffer=reset_buffer) - if self.test_collector is not None: - self.test_collector.reset(reset_buffer=reset_buffer) - - def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None: - """Initialize or reset the instance to yield a new iterator from zero.""" - TraceLogger.log(log, lambda: "Trainer reset") - self.is_run = False - self.env_step = 0 - if self.resume_from_log: - ( - self.start_epoch, - self.env_step, - self._gradient_step, - ) = self.logger.restore_data() - - self.last_rew, self.last_len = 0.0, 0.0 - self.start_time = time.time() - - if reset_collectors: - self._reset_collectors(reset_buffer=reset_buffer) - - if self.train_collector is not None and ( - self.train_collector.policy != self.policy or self.test_collector is None - ): - self.test_in_train = False - - if self.test_collector is not None: - assert self.episode_per_test is not None - assert not isinstance(self.test_collector, AsyncCollector) # Issue 700 - test_result = test_episode( - self.test_collector, - self.test_fn, - self.start_epoch, - self.episode_per_test, - self.logger, - self.env_step, - self.reward_metric, - ) - assert test_result.returns_stat is not None # for mypy - self.best_epoch = self.start_epoch - self.best_reward, self.best_reward_std = ( - test_result.returns_stat.mean, - test_result.returns_stat.std, - ) - self.best_score = self.compute_score_fn(test_result) - if self.save_best_fn: - self.save_best_fn(self.policy) - - self.epoch = self.start_epoch - self.stop_fn_flag = False - self.iter_num = 0 - - self._log_params(self.policy) - - def _log_params(self, module: torch.nn.Module) -> None: - """Logs the parameters of the module to the trace logger by 
subcomponent (if the trace logger is enabled).""" - from tianshou.utils.net.common import ActorCritic - - if not TraceLogger.is_enabled: - return - - def module_has_params(m: torch.nn.Module) -> bool: - return any(p.requires_grad for p in m.parameters()) - - relevant_modules = {} - - def gather_modules(m: torch.nn.Module) -> None: - for name, submodule in m.named_children(): - if isinstance(submodule, ActorCritic): - gather_modules(submodule) - else: - if module_has_params(submodule): - relevant_modules[name] = submodule - - gather_modules(module) - - for name, module in sorted(relevant_modules.items()): - TraceLogger.log( - log, - lambda: f"Params[{name}]: {torch_param_hash(module)}", - ) - - def __iter__(self): # type: ignore - return self - - def __next__(self) -> EpochStats: - """Perform one epoch (both train and eval).""" - self.epoch += 1 - self.iter_num += 1 - - if self.iter_num > 1: - # iterator exhaustion check - if self.epoch > self.max_epoch: - raise StopIteration - - # exit flag 1, when stop_fn succeeds in train_step or test_step - if self.stop_fn_flag: - raise StopIteration - - # perform n step_per_epoch - steps_done_in_this_epoch = 0 - with self._pbar(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", position=1) as t: - TraceLogger.log(log, lambda: f"Epoch #{self.epoch} start") - collect_stats: CollectStatsBase - while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: - TraceLogger.log(log, lambda: "Training step") - collect_stats, training_stats, self.stop_fn_flag = self.training_step() - TraceLogger.log( - log, - lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict() if training_stats else None}", - ) - self._log_params(self.policy) - - if isinstance(collect_stats, CollectStats): - pbar_data_dict = { - "env_step": str(self.env_step), - "env_episode": str(self.env_episode), - "rew": f"{self.last_rew:.2f}", - "len": str(int(self.last_len)), - "n/ep": str(collect_stats.n_collected_episodes), - "n/st": str(collect_stats.n_collected_steps), - } - - # t might be disabled, we track the steps manually - t.update(collect_stats.n_collected_steps) - steps_done_in_this_epoch += collect_stats.n_collected_steps - - if self.stop_fn_flag: - t.set_postfix(**pbar_data_dict) - else: - # TODO: there is no iteration happening here, it's the offline case - # Code should be restructured! 
- pbar_data_dict = {} - assert self.buffer, "No train_collector or buffer specified" - collect_stats = CollectStatsBase( - n_collected_steps=len(self.buffer), - ) - - # t might be disabled, we track the steps manually - t.update() - steps_done_in_this_epoch += 1 - - pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) - pbar_data_dict["gradient_step"] = str(self._gradient_step) - t.set_postfix(**pbar_data_dict) - - if self.stop_fn_flag: - break - - if steps_done_in_this_epoch <= self.step_per_epoch and not self.stop_fn_flag: - # t might be disabled, we track the steps manually - t.update() - steps_done_in_this_epoch += 1 - - # for offline RL - if self.train_collector is None: - assert self.buffer is not None - batch_size = self.batch_size or len(self.buffer) - self.env_step = self._gradient_step * batch_size - - test_stat = None - if not self.stop_fn_flag: - self.logger.save_data( - self.epoch, - self.env_step, - self._gradient_step, - self.save_checkpoint_fn, - ) - # test - if self.test_collector is not None: - test_stat, self.stop_fn_flag = self.test_step() - - info_stat = gather_info( - start_time=self.start_time, - policy_update_time=self.policy_update_time, - gradient_step=self._gradient_step, - best_score=self.best_score, - best_reward=self.best_reward, - best_reward_std=self.best_reward_std, - train_collector=self.train_collector, - test_collector=self.test_collector, - ) - - self.logger.log_info_data(asdict(info_stat), self.epoch) - - # in case trainer is used with run(), epoch_stat will not be returned - return EpochStats( - epoch=self.epoch, - train_collect_stat=collect_stats, - test_collect_stat=test_stat, - training_stat=training_stats, - info_stat=info_stat, - ) - - def test_step(self) -> tuple[CollectStats, bool]: - """Perform one testing step.""" - assert self.episode_per_test is not None - assert self.test_collector is not None - stop_fn_flag = False - test_stat = test_episode( - self.test_collector, - self.test_fn, - self.epoch, - self.episode_per_test, - self.logger, - self.env_step, - self.reward_metric, - ) - assert test_stat.returns_stat is not None # for mypy - rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std - score = self.compute_score_fn(test_stat) - if self.best_epoch < 0 or self.best_score < score: - self.best_score = score - self.best_epoch = self.epoch - self.best_reward = float(rew) - self.best_reward_std = rew_std - if self.save_best_fn: - self.save_best_fn(self.policy) - cur_info, best_info = "", "" - if score != rew: - cur_info, best_info = f", score: {score: .6f}", f", best_score: {self.best_score:.6f}" - log_msg = ( - f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},{cur_info}" - f" best_reward: {self.best_reward:.6f} ± " - f"{self.best_reward_std:.6f}{best_info} in #{self.best_epoch}" - ) - log.info(log_msg) - if self.verbose: - print(log_msg, flush=True) - - if self.stop_fn and self.stop_fn(self.best_reward): - stop_fn_flag = True - - return test_stat, stop_fn_flag - - def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: - """Perform one training iteration. - - A training iteration includes collecting data (for online RL), determining whether to stop training, - and performing a policy update if the training iteration should continue. - - :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training. - If training is to be stopped, no gradient steps will be performed and the training stats will be `None`. 
- """ - with policy_within_training_step(self.policy): - should_stop_training = False - - collect_stats: CollectStatsBase | CollectStats - if self.train_collector is not None: - collect_stats = self._collect_training_data() - should_stop_training = self._update_best_reward_and_return_should_stop_training( - collect_stats, - ) - else: - assert self.buffer is not None, "Either train_collector or buffer must be provided." - collect_stats = CollectStatsBase( - n_collected_episodes=len(self.buffer), - ) - - if not should_stop_training: - training_stats = self.policy_update_fn(collect_stats) - else: - training_stats = None - - return collect_stats, training_stats, should_stop_training - - def _collect_training_data(self) -> CollectStats: - """Performs training data collection. - - :return: the data collection stats - """ - assert self.episode_per_test is not None - assert self.train_collector is not None - if self.train_fn: - self.train_fn(self.epoch, self.env_step) - collect_stats = self.train_collector.collect( - n_step=self.step_per_collect, - n_episode=self.episode_per_collect, - ) - TraceLogger.log( - log, - lambda: f"Collected {collect_stats.n_collected_steps} steps, {collect_stats.n_collected_episodes} episodes", - ) - - if self.train_collector.raise_on_nan_in_buffer and self.train_collector.buffer.hasnull(): - from tianshou.data.collector import EpisodeRolloutHook - from tianshou.env import DummyVectorEnv - - raise MalformedBufferError( - f"Encountered NaNs in buffer after {self.env_step} steps." - f"Such errors are usually caused by either a bug in the environment or by " - f"problematic implementations {EpisodeRolloutHook.__class__.__name__}. " - f"For debugging such issues it is recommended to run the training in a single process, " - f"e.g., by using {DummyVectorEnv.__class__.__name__}.", - ) - - self.env_step += collect_stats.n_collected_steps - self.env_episode += collect_stats.n_collected_episodes - - if collect_stats.n_collected_episodes > 0: - assert collect_stats.returns_stat is not None # for mypy - assert collect_stats.lens_stat is not None # for mypy - self.last_rew = collect_stats.returns_stat.mean - self.last_len = collect_stats.lens_stat.mean - if self.reward_metric: # TODO: move inside collector - rew = self.reward_metric(collect_stats.returns) - collect_stats.returns = rew - collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) - - self.logger.log_train_data(asdict(collect_stats), self.env_step) - return collect_stats - - # TODO (maybe): separate out side effect, simplify name? - def _update_best_reward_and_return_should_stop_training( - self, - collect_stats: CollectStats, - ) -> bool: - """If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data. - Then, if the `stop_fn` is True there, will collect test data also compute the stop_fn of the mean return - on it. - Finally, if the latter is also True, will return True. - - **NOTE:** has a side effect of updating the best reward and corresponding std. 
- - - :param collect_stats: the data collection stats - :return: flag indicating whether to stop training - """ - should_stop_training = False - - # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics - with policy_within_training_step(self.policy, enabled=False): - if ( - collect_stats.n_collected_episodes > 0 - and self.test_in_train - and self.stop_fn - and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore - ): - assert self.test_collector is not None - assert self.episode_per_test is not None and self.episode_per_test > 0 - test_result = test_episode( - self.test_collector, - self.test_fn, - self.epoch, - self.episode_per_test, - self.logger, - self.env_step, - ) - assert test_result.returns_stat is not None # for mypy - if self.stop_fn(test_result.returns_stat.mean): - should_stop_training = True - self.best_reward = test_result.returns_stat.mean - self.best_reward_std = test_result.returns_stat.std - self.best_score = self.compute_score_fn(test_result) - - return should_stop_training - - # TODO: move moving average computation and logging into its own logger - # TODO: maybe think about a command line logger instead of always printing data dict - def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStats) -> None: - """Log losses, update moving average stats, and also modify the smoothed_loss in update_stat.""" - cur_losses_dict = update_stat.get_loss_stats_dict() - update_stat.smoothed_loss = self._update_moving_avg_stats_and_get_averaged_data( - cur_losses_dict, - ) - self.logger.log_update_data(asdict(update_stat), self._gradient_step) - - # TODO: seems convoluted, there should be a better way of dealing with the moving average stats - def _update_moving_avg_stats_and_get_averaged_data( - self, - data: dict[str, float], - ) -> dict[str, float]: - """Add entries to the moving average object in the trainer and retrieve the averaged results. - - :param data: any entries to be tracked in the moving average object. - :return: A dictionary containing the averaged values of the tracked entries. - - """ - smoothed_data = {} - for key, loss_item in data.items(): - self.stat[key].add(loss_item) - smoothed_data[key] = self.stat[key].get() - return smoothed_data - - @abstractmethod - def policy_update_fn( - self, - collect_stats: CollectStatsBase, - ) -> TrainingStats: - """Policy update function for different trainer implementation. - - :param collect_stats: provides info about the most recent collection. In the offline case, this will contain - stats of the whole dataset - """ - - def run(self, reset_collectors: bool = True, reset_buffer: bool = False) -> InfoStats: - """Consume iterator. - - See itertools - recipes. Use functions that consume iterators at C speed - (feed the entire iterator into a zero-length deque). - - :param reset_collectors: whether to reset the collectors prior to starting the training process. - Specifically, this will reset the environments in the collectors (starting new episodes), - and the statistics stored in the collector. Whether the contained buffers will be reset/cleared - is determined by the `reset_buffer` parameter. - :param reset_collector_buffers: whether, for the case where the collectors are reset, to reset/clear the - contained buffers as well. - This has no effect if `reset_collectors` is False. 
- """ - self.reset(reset_collectors=reset_collectors, reset_buffer=reset_buffer) - try: - self.is_run = True - deque(self, maxlen=0) # feed the entire iterator into a zero-length deque - info = gather_info( - start_time=self.start_time, - policy_update_time=self.policy_update_time, - gradient_step=self._gradient_step, - best_score=self.best_score, - best_reward=self.best_reward, - best_reward_std=self.best_reward_std, - train_collector=self.train_collector, - test_collector=self.test_collector, - ) - finally: - self.is_run = False - - return info - - def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: - """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" - self._gradient_step += 1 - # Note: since sample_size=batch_size, this will perform - # exactly one gradient step. This is why we don't need to calculate the - # number of gradient steps, like in the on-policy case. - update_stat = self.policy.update(sample_size=self.batch_size, buffer=buffer) - self._update_moving_avg_stats_and_log_update_data(update_stat) - return update_stat - - -class OfflineTrainer(BaseTrainer): - """Offline trainer, samples mini-batches from buffer and passes them to update. - - Uses a buffer directly and usually does not have a collector. - """ - - # for mypy - assert isinstance(BaseTrainer.__doc__, str) - __doc__ += BaseTrainer.gen_doc("offline") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) - - def policy_update_fn( - self, - collect_stats: CollectStatsBase | None = None, - ) -> TrainingStats: - """Perform one off-line policy update.""" - assert self.buffer - update_stat = self._sample_and_update(self.buffer) - # logging - self.policy_update_time += update_stat.train_time - return update_stat - - -class OffpolicyTrainer(BaseTrainer): - """Offpolicy trainer, samples mini-batches from buffer and passes them to update. - - Note that with this trainer, it is expected that the policy's `learn` method - does not perform additional mini-batching but just updates params from the received - mini-batch. - """ - - # for mypy - assert isinstance(BaseTrainer.__doc__, str) - __doc__ += BaseTrainer.gen_doc("offpolicy") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) - - def policy_update_fn( - self, - # TODO: this is the only implementation where collect_stats is actually needed. Maybe change interface? - collect_stats: CollectStatsBase, - ) -> TrainingStats: - """Perform `update_per_step * n_collected_steps` gradient steps by sampling mini-batches from the buffer. - - :param collect_stats: the :class:`~TrainingStats` instance returned by the last gradient step. Some values - in it will be replaced by their moving averages. - """ - assert self.train_collector is not None - n_collected_steps = collect_stats.n_collected_steps - n_gradient_steps = round(self.update_per_step * n_collected_steps) - if n_gradient_steps == 0: - raise ValueError( - f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, " - f"update_per_step={self.update_per_step}", - ) - - for _ in self._pbar( - range(n_gradient_steps), - desc="Offpolicy gradient update", - position=0, - leave=False, - ): - update_stat = self._sample_and_update(self.train_collector.buffer) - self.policy_update_time += update_stat.train_time - # TODO: only the last update_stat is returned, should be improved - return update_stat - - -class OnpolicyTrainer(BaseTrainer): - """On-policy trainer, passes the entire buffer to .update and resets it after. 
- - Note that it is expected that the learn method of a policy will perform - batching when using this trainer. - """ - - # for mypy - assert isinstance(BaseTrainer.__doc__, str) - __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) - - def policy_update_fn( - self, - result: CollectStatsBase | None = None, - ) -> TrainingStats: - """Perform one on-policy update by passing the entire buffer to the policy's update method.""" - assert self.train_collector is not None - # TODO: add logging like in off-policy. Iteration over minibatches currently happens in the learn implementation of - # on-policy algos like PG or PPO - log.info( - f"Performing on-policy update on buffer of length {len(self.train_collector.buffer)}", - ) - training_stat = self.policy.update( - sample_size=0, - buffer=self.train_collector.buffer, - # Note: sample_size is None, so the whole buffer is used for the update. - # The kwargs are in the end passed to the .learn method, which uses - # batch_size to iterate through the buffer in mini-batches - # Off-policy algos typically don't use the batch_size kwarg at all - batch_size=self.batch_size, - repeat=self.repeat_per_collect, - ) - - # just for logging, no functional role - self.policy_update_time += training_stat.train_time - # TODO: remove the gradient step counting in trainers? Doesn't seem like - # it's important and it adds complexity - self._gradient_step += 1 - if self.batch_size is None: - self._gradient_step += 1 - elif self.batch_size > 0: - self._gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size) - - # Note 1: this is the main difference to the off-policy trainer! - # The second difference is that batches of data are sampled without replacement - # during training, whereas in off-policy or offline training, the batches are - # sampled with replacement (and potentially custom prioritization). - # Note 2: in the policy-update we modify the buffer, which is not very clean. - # currently the modification will erase previous samples but keep things like - # _ep_rew and _ep_len. This means that such quantities can no longer be computed - # from samples still contained in the buffer, which is also not clean - # TODO: improve this situation - self.train_collector.reset_buffer(keep_statistics=True) - - # The step is the number of mini-batches used for the update, so essentially - self._update_moving_avg_stats_and_log_update_data(training_stat) - - return training_stat diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py new file mode 100644 index 000000000..adec7e066 --- /dev/null +++ b/tianshou/trainer/trainer.py @@ -0,0 +1,1109 @@ +""" +This module contains Tianshou's trainer classes, which orchestrate the training and call upon an RL algorithm's +specific network updating logic to perform the actual gradient updates. + +Training is structured as follows (hierarchical glossary): + +- **epoch**: the outermost iteration level of the training loop. Each epoch consists of a number of training steps + and one test step (see :attr:`TrainerParams.max_epoch` for a detailed explanation). + + - **training step**: a training step performs the steps necessary in order to apply a single update of the neural + network components as defined by the underlying RL algorithm (:class:`Algorithm`). This involves the following sub-steps: + + - for online learning algorithms: + + - **collection step**: collecting environment steps/transitions to be used for training. 
+ + - (Potentially) a test step (see below) if the early stopping criterion is satisfied based on + the data collected (see :attr:`OnlineTrainerParams.test_in_train`). + + - **update step**: applying the actual gradient updates using the RL algorithm. + The update is based on either: + + - data from only the preceding collection step (on-policy learning), + - data from the collection step and previously collected data (off-policy learning), or + - data from the user-provided replay buffer (offline learning). + + For offline learning algorithms, a training step is thus equivalent to an update step. + + - **test step**: collects test episodes from dedicated test environments which are used to evaluate the performance + of the policy. Optionally, the performance result can be used to determine whether training shall stop early + (see :attr:`TrainerParams.stop_fn`). +""" + +import logging +import time +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Callable +from dataclasses import asdict, dataclass +from functools import partial +from typing import Generic, TypeVar + +import numpy as np +import torch +import tqdm +from sensai.util.helper import count_none +from sensai.util.string import ToStringMixin + +from tianshou.algorithm.algorithm_base import ( + Algorithm, + OfflineAlgorithm, + OffPolicyAlgorithm, + OnPolicyAlgorithm, + TrainingStats, +) +from tianshou.data import ( + AsyncCollector, + CollectStats, + EpochStats, + InfoStats, + ReplayBuffer, + SequenceSummaryStats, + TimingStats, +) +from tianshou.data.buffer.buffer_base import MalformedBufferError +from tianshou.data.collector import BaseCollector, CollectStatsBase +from tianshou.utils import ( + BaseLogger, + LazyLogger, + MovAvg, +) +from tianshou.utils.determinism import TraceLogger, torch_param_hash +from tianshou.utils.logging import set_numerical_fields_to_precision +from tianshou.utils.torch_utils import policy_within_training_step + +log = logging.getLogger(__name__) + + +@dataclass(kw_only=True) +class TrainerParams(ToStringMixin): + max_epochs: int = 100 + """ + the (maximum) number of epochs to run training for. An **epoch** is the outermost iteration level and each + epoch consists of a number of training steps and one test step, where each training step + + * [for the online case] collects environment steps/transitions (**collection step**), + adding them to the (replay) buffer (see :attr:`collection_step_num_env_steps` and :attr:`collection_step_num_episodes`) + * performs an **update step** via the RL algorithm being used, which can involve + one or more actual gradient updates, depending on the algorithm + + and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate + agent performance. + + Training may be stopped early if the stop criterion is met (see :attr:`stop_fn`). + + For online training, the number of training steps in each epoch is indirectly determined by + :attr:`epoch_num_steps`: As many training steps will be performed as are required in + order to reach :attr:`epoch_num_steps` total steps in the training environments. + Specifically, if the number of transitions collected per step is `c` (see + :attr:`collection_step_num_env_steps`) and :attr:`epoch_num_steps` is set to `s`, then the number + of training steps per epoch is `ceil(s / c)`. + Therefore, if `max_epochs = e`, the total number of environment steps taken during training + can be computed as `e * ceil(s / c) * c`. 
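
As a quick sanity check on the step arithmetic described above for `max_epochs` and `epoch_num_steps`, here is a minimal sketch with purely illustrative numbers; the variables mirror the symbols `e`, `s` and `c` used in the docstring::

    import math

    e = 100     # max_epochs (illustrative)
    s = 30_000  # epoch_num_steps (illustrative)
    c = 2_048   # transitions collected per collection step (illustrative)

    training_steps_per_epoch = math.ceil(s / c)         # ceil(30000 / 2048) = 15
    total_env_steps = e * training_steps_per_epoch * c  # 100 * 15 * 2048 = 3_072_000
    print(training_steps_per_epoch, total_env_steps)
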
+ + For offline training, the number of training steps per epoch is equal to :attr:`epoch_num_steps`. + """ + + epoch_num_steps: int = 30000 + """ + For an online algorithm, this is the total number of environment steps to be collected per epoch, and, + for an offline algorithm, it is the total number of training steps to take per epoch. + See :attr:`max_epochs` for an explanation of epoch semantics. + """ + + test_collector: BaseCollector | None = None + """ + the collector to use for test episode collection (test steps); if None, perform no test steps. + """ + + test_step_num_episodes: int = 1 + """the number of episodes to collect in each test step. + """ + + train_fn: Callable[[int, int], None] | None = None + """ + a callback function which is called at the beginning of each training step. + It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + """ + + test_fn: Callable[[int, int | None], None] | None = None + """ + a callback function to be called at the beginning of each test step. + It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + """ + + stop_fn: Callable[[float], bool] | None = None + """ + a callback function with signature ``f(score: float) -> bool``, which + is used to decide whether training shall be stopped early based on the score + achieved in a test step. + The score it receives is computed by the :attr:`compute_score_fn` callback + (which defaults to the mean reward if the function is not provided). + + Requires test steps to be activated and thus :attr:`test_collector` to be set. + + Note: The function is also used when :attr:`test_in_train` is activated (see docstring). + """ + + compute_score_fn: Callable[[CollectStats], float] | None = None + """ + the callback function to use in order to compute the test batch performance score, which is used to + determine what the best model is (score is maximized); if None, use the mean reward. + """ + + save_best_fn: Callable[["Algorithm"], None] | None = None + """ + the callback function to call in order to save the best model whenever a new best score (see :attr:`compute_score_fn`) + is achieved in a test step. It should have the signature ``f(algorithm: Algorithm) -> None``. + """ + + save_checkpoint_fn: Callable[[int, int, int], str] | None = None + """ + the callback function with which to save checkpoint data after each training step, + which can save whatever data is desired to a file and returns the path of the file. + Signature: ``f(epoch: int, env_step: int, gradient_step: int) -> str``. + """ + + resume_from_log: bool = False + """ + whether to load env_step/gradient_step and other metadata from the existing log, + which is given in :attr:`logger`. + """ + + multi_agent_return_reduction: Callable[[np.ndarray], np.ndarray] | None = None + """ + a function with signature + ``f(returns: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + which is used in multi-agent RL. We need to return a single scalar for each episode's return + to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, + e.g., the return achieved by agent 1 or the average return over all agents. + """ + + logger: BaseLogger | None = None + """ + the logger with which to log statistics during training/testing/updating. To not log anything, use None. 
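
The callback hooks documented above (`stop_fn`, `save_best_fn`, `save_checkpoint_fn`) are plain functions with the stated signatures. The following is only a sketch of how they could be filled in: the score threshold and file names are arbitrary, and saving via `state_dict` rests on the assumption that `Algorithm` is a `torch.nn.Module` (as suggested by the trainer's parameter logging)::

    import torch

    from tianshou.algorithm.algorithm_base import Algorithm

    def stop_fn(score: float) -> bool:
        # stop training early once the test-step score reaches a target value
        return score >= 195.0

    def save_best_fn(algorithm: Algorithm) -> None:
        # persist the algorithm's parameters whenever a new best score is achieved
        torch.save(algorithm.state_dict(), "best_algorithm.pt")

    def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
        # save minimal progress metadata after each training step and return the file path
        path = f"checkpoint_epoch_{epoch}.pt"
        torch.save({"epoch": epoch, "env_step": env_step, "gradient_step": gradient_step}, path)
        return path
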
+ + Relevant step types for logger update intervals: + * `update_interval`: update step + * `train_interval`: env step + * `test_interval`: env step + """ + + verbose: bool = True + """ + whether to print status information to stdout. + If set to False, status information will still be logged (provided that logging is enabled via the + `logging` Python module). + """ + + show_progress: bool = True + """ + whether to display a progress bars during training. + """ + + def __post_init__(self) -> None: + if self.resume_from_log and self.logger is None: + raise ValueError("Cannot resume from log without a logger being provided") + if self.test_collector is None: + if self.stop_fn is not None: + raise ValueError( + "stop_fn cannot be activated without test steps being enabled (test_collector being set)" + ) + if self.test_fn is not None: + raise ValueError( + "test_fn is set while test steps are disabled (test_collector is None)" + ) + if self.save_best_fn is not None: + raise ValueError( + "save_best_fn is set while test steps are disabled (test_collector is None)" + ) + else: + if self.test_step_num_episodes < 1: + raise ValueError( + "test_step_num_episodes must be positive if test steps are enabled " + "(test_collector not None)" + ) + + +@dataclass(kw_only=True) +class OnlineTrainerParams(TrainerParams): + train_collector: BaseCollector + """ + the collector with which to gather new data for training in each training step + """ + + collection_step_num_env_steps: int | None = 2048 + """ + the number of environment steps/transitions to collect in each collection step before the + network update within each training step. + + This is mutually exclusive with :attr:`collection_step_num_episodes`, and one of the two must be set. + + Note that the exact number can be reached only if this is a multiple of the number of + training environments being used, as each training environment will produce the same + (non-zero) number of transitions. + Specifically, if this is set to `n` and `m` training environments are used, then the total + number of transitions collected per collection step is `ceil(n / m) * m =: c`. + + See :attr:`max_epochs` for information on the total number of environment steps being + collected during training. + """ + + collection_step_num_episodes: int | None = None + """ + the number of episodes to collect in each collection step before the network update within + each training step. If this is set, the number of environment steps collected in each + collection step is the sum of the lengths of the episodes collected. + + This is mutually exclusive with :attr:`collection_step_num_env_steps`, and one of the two must be set. + """ + + test_in_train: bool = False + """ + Whether to apply a test step within a training step depending on the early stopping criterion + (given by :attr:`stop_fn`) being satisfied based on the data collected within the training step. + Specifically, after each collect step, we check whether the early stopping criterion (:attr:`stop_fn`) + would be satisfied by data we collected (provided that at least one episode was indeed completed, such + that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step + (collecting :attr:`test_step_num_episodes` episodes in order to evaluate performance), and if the early + stopping criterion is also satisfied based on the test data, we stop training early. 
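
A small worked example of the per-collection-step transition count stated above for `collection_step_num_env_steps` (`ceil(n / m) * m`); the number of training environments is an assumption made for illustration::

    import math

    n = 2_048  # collection_step_num_env_steps
    m = 10     # number of training environments (assumed)

    transitions_per_collection_step = math.ceil(n / m) * m  # ceil(204.8) * 10 = 2050
    print(transitions_per_collection_step)
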
+ """ + + def __post_init__(self) -> None: + super().__post_init__() + if count_none(self.collection_step_num_env_steps, self.collection_step_num_episodes) != 1: + raise ValueError( + "Exactly one of {collection_step_num_env_steps, collection_step_num_episodes} must be set" + ) + if self.test_in_train and (self.test_collector is None or self.stop_fn is None): + raise ValueError("test_in_train requires test_collector and stop_fn to be set") + + +@dataclass(kw_only=True) +class OnPolicyTrainerParams(OnlineTrainerParams): + batch_size: int | None = 64 + """ + Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, + a form of regularization). + Set ``batch_size=None`` for the full buffer that was collected within the training step to be + used for the gradient update (no mini-batching). + """ + + update_step_num_repetitions: int = 1 + """ + controls, within one update step of an on-policy algorithm, the number of times + the full collected data is applied for gradient updates, i.e. if the parameter is + 5, then the collected data shall be used five times to update the policy within the same + update step. + """ + + +@dataclass(kw_only=True) +class OffPolicyTrainerParams(OnlineTrainerParams): + batch_size: int = 64 + """ + the the number of environment steps/transitions to sample from the buffer for a gradient update. + """ + + update_step_num_gradient_steps_per_sample: float = 1.0 + """ + the number of gradient steps to perform per sample collected (see :attr:`collection_step_num_env_steps`). + Specifically, if this is set to `u` and the number of samples collected in the preceding + collection step is `n`, then `round(u * n)` gradient steps will be performed. + """ + + +@dataclass(kw_only=True) +class OfflineTrainerParams(TrainerParams): + buffer: ReplayBuffer + """ + the replay buffer with environment steps to use as training data for offline learning. + This buffer will be pre-processed using the RL algorithm's pre-processing + function (if any) before training. + """ + + batch_size: int = 64 + """ + the number of environment steps/transitions to sample from the buffer for a gradient update. + """ + + +TTrainerParams = TypeVar("TTrainerParams", bound=TrainerParams) +TOnlineTrainerParams = TypeVar("TOnlineTrainerParams", bound=OnlineTrainerParams) +TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) + + +class Trainer(Generic[TAlgorithm, TTrainerParams], ABC): + """ + Base class for trainers in Tianshou, which orchestrate the training process and call upon an RL algorithm's + specific network updating logic to perform the actual gradient updates. + + The base class already implements the fundamental epoch logic and fully implements the test step + logic, which is common to all trainers. The training step logic is left to be implemented by subclasses. 
+ """ + + def __init__( + self, + algorithm: TAlgorithm, + params: TTrainerParams, + ): + self.algorithm = algorithm + self.params = params + + self._logger = params.logger or LazyLogger() + + self._start_time = time.time() + self._stat: defaultdict[str, MovAvg] = defaultdict(MovAvg) + self._start_epoch = 0 + + self._epoch = self._start_epoch + + # initialize stats on the best model found during a test step + # NOTE: The values don't matter, as in the first test step (which is taken in reset() + # at the beginning of the training process), these will all be updated + self._best_score = 0.0 + self._best_reward = 0.0 + self._best_reward_std = 0.0 + self._best_epoch = self._start_epoch + + self._current_update_step = 0 + """ + the current (1-based) update step/training step number (to be incremented before the actual step is taken) + """ + + self._env_step = 0 + """ + the step counter which is used to track progress of the training process. + For online learning (i.e. on-policy and off-policy learning), this is the total number of + environment steps collected, and for offline training, it is the total number of environment + steps that have been sampled from the replay buffer to perform gradient updates. + """ + + self._policy_update_time = 0.0 + + self._compute_score_fn: Callable[[CollectStats], float] = ( + params.compute_score_fn or self._compute_score_fn_default + ) + + self._stop_fn_flag = False + + @staticmethod + def _compute_score_fn_default(stat: CollectStats) -> float: + """ + The default score function, which returns the mean return/reward. + + :param stat: the collection stats + :return: the mean return + """ + assert stat.returns_stat is not None # for mypy + return stat.returns_stat.mean + + @property + def _pbar(self) -> Callable[..., tqdm.tqdm]: + """Use as context manager or iterator, i.e., `with self._pbar(...) as t:` or `for _ in self._pbar(...):`.""" + return partial( + tqdm.tqdm, + dynamic_ncols=True, + ascii=True, + disable=not self.params.show_progress, + ) + + def _reset_collectors(self, reset_buffer: bool = False) -> None: + if self.params.test_collector is not None: + self.params.test_collector.reset(reset_buffer=reset_buffer) + + def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = False) -> None: + """Initializes the training process. + + :param reset_collectors: whether to reset the collectors prior to starting the training process. + Specifically, this will reset the environments in the collectors (starting new episodes), + and the statistics stored in the collector. Whether the contained buffers will be reset/cleared + is determined by the `reset_buffer` parameter. + :param reset_collector_buffers: whether, for the case where the collectors are reset, to reset/clear the + contained buffers as well. + This has no effect if `reset_collectors` is False. 
+ """ + TraceLogger.log(log, lambda: "Trainer reset") + self._env_step = 0 + self._current_update_step = 0 + + if self.params.resume_from_log: + ( + self._start_epoch, + self._env_step, + self._current_update_step, + ) = self._logger.restore_data() + + self._epoch = self._start_epoch + + self._start_time = time.time() + + if reset_collectors: + self._reset_collectors(reset_buffer=reset_collector_buffers) + + # make an initial test step to determine the initial best model + if self.params.test_collector is not None: + assert self.params.test_step_num_episodes is not None + assert not isinstance(self.params.test_collector, AsyncCollector) # Issue 700 + self._test_step(force_update_best=True, log_msg_prefix="Initial test step") + + self._stop_fn_flag = False + + self._log_params(self.algorithm) + + def _log_params(self, module: torch.nn.Module) -> None: + """Logs the parameters of the module to the trace logger by subcomponent (if the trace logger is enabled).""" + if not TraceLogger.is_enabled: + return + + def module_has_params(m: torch.nn.Module) -> bool: + return any(p.requires_grad for p in m.parameters()) + + relevant_modules = {} + + def gather_modules(m: torch.nn.Module) -> None: + for name, submodule in m.named_children(): + if name == "policy": + gather_modules(submodule) + else: + if module_has_params(submodule): + relevant_modules[name] = submodule + + gather_modules(module) + + for name, module in sorted(relevant_modules.items()): + TraceLogger.log( + log, + lambda: f"Params[{name}]: {torch_param_hash(module)}", + ) + + class _TrainingStepResult(ABC): + @abstractmethod + def get_steps_in_epoch_advancement(self) -> int: + """ + :return: the number of steps that were done within the epoch, where the concrete semantics + of what a step is depend on the type of algorithm. See docstring of `TrainerParams.epoch_num_steps`. + """ + + @abstractmethod + def get_collect_stats(self) -> CollectStats | None: + pass + + @abstractmethod + def get_training_stats(self) -> TrainingStats | None: + pass + + @abstractmethod + def is_training_done(self) -> bool: + """:return: whether the early stopping criterion is satisfied and training shall stop.""" + + @abstractmethod + def get_env_step_advancement(self) -> int: + """ + :return: the number of steps by which to advance the env_step counter in the trainer (see docstring + of trainer attribute). The semantics depend on the type of the algorithm. 
+ """ + + @abstractmethod + def _create_epoch_pbar_data_dict( + self, training_step_result: _TrainingStepResult + ) -> dict[str, str]: + pass + + def _create_info_stats( + self, + ) -> InfoStats: + test_collector = self.params.test_collector + if isinstance(self.params, OnlineTrainerParams): + train_collector = self.params.train_collector + else: + train_collector = None + + duration = max(0.0, time.time() - self._start_time) + test_time = 0.0 + update_speed = 0.0 + train_time_collect = 0.0 + if test_collector is not None: + test_time = test_collector.collect_time + + if train_collector is not None: + train_time_collect = train_collector.collect_time + update_speed = train_collector.collect_step / (duration - test_time) + + timing_stat = TimingStats( + total_time=duration, + train_time=duration - test_time, + train_time_collect=train_time_collect, + train_time_update=self._policy_update_time, + test_time=test_time, + update_speed=update_speed, + ) + + return InfoStats( + update_step=self._current_update_step, + best_score=self._best_score, + best_reward=self._best_reward, + best_reward_std=self._best_reward_std, + train_step=train_collector.collect_step if train_collector is not None else 0, + train_episode=train_collector.collect_episode if train_collector is not None else 0, + test_step=test_collector.collect_step if test_collector is not None else 0, + test_episode=test_collector.collect_episode if test_collector is not None else 0, + timing=timing_stat, + ) + + def execute_epoch(self) -> EpochStats: + self._epoch += 1 + TraceLogger.log(log, lambda: f"Epoch #{self._epoch} start") + + # perform the required number of steps for the epoch (`epoch_num_steps`) + steps_done_in_this_epoch = 0 + train_collect_stats, training_stats = None, None + with self._pbar( + total=self.params.epoch_num_steps, desc=f"Epoch #{self._epoch}", position=1 + ) as t: + while steps_done_in_this_epoch < self.params.epoch_num_steps and not self._stop_fn_flag: + # perform a training step and update progress + TraceLogger.log(log, lambda: "Training step") + self._current_update_step += 1 + training_step_result = self._training_step() + steps_done_in_this_epoch += training_step_result.get_steps_in_epoch_advancement() + t.update(training_step_result.get_steps_in_epoch_advancement()) + self._stop_fn_flag = training_step_result.is_training_done() + self._env_step += training_step_result.get_env_step_advancement() + training_stats = training_step_result.get_training_stats() + TraceLogger.log( + log, + lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict() if training_stats is not None else None}", + ) + self._log_params(self.algorithm) + + collect_stats = training_step_result.get_collect_stats() + if collect_stats is not None: + self._logger.log_train_data(asdict(collect_stats), self._env_step) + + pbar_data_dict = self._create_epoch_pbar_data_dict(training_step_result) + pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) + pbar_data_dict["update_step"] = str(self._current_update_step) + t.set_postfix(**pbar_data_dict) + + test_collect_stats = None + if not self._stop_fn_flag: + self._logger.save_data( + self._epoch, + self._env_step, + self._current_update_step, + self.params.save_checkpoint_fn, + ) + + # test step + if self.params.test_collector is not None: + test_collect_stats, self._stop_fn_flag = self._test_step() + + info_stats = self._create_info_stats() + + self._logger.log_info_data(asdict(info_stats), self._epoch) + + return EpochStats( + epoch=self._epoch, + 
train_collect_stat=train_collect_stats, + test_collect_stat=test_collect_stats, + training_stat=training_stats, + info_stat=info_stats, + ) + + def _should_stop_training_early( + self, *, score: float | None = None, collect_stats: CollectStats | None = None + ) -> bool: + """ + Determine whether, given the early stopping criterion stop_fn, training shall be stopped early + based on the score achieved or the collection stats (from which the score could be computed). + """ + # If no stop criterion is defined, we can never stop training early + if self.params.stop_fn is None: + return False + + if score is None: + if collect_stats is None: + raise ValueError("Must provide collect_stats if score is not given") + + # If no episodes were collected, we have no episode returns and thus cannot compute a score + if collect_stats.n_collected_episodes == 0: + return False + + score = self._compute_score_fn(collect_stats) + + return self.params.stop_fn(score) + + def _collect_test_episodes( + self, + ) -> CollectStats: + assert self.params.test_collector is not None + collector = self.params.test_collector + collector.reset(reset_stats=False) + if self.params.test_fn: + self.params.test_fn(self._epoch, self._env_step) + result = collector.collect(n_episode=self.params.test_step_num_episodes) + if self.params.multi_agent_return_reduction: + rew = self.params.multi_agent_return_reduction(result.returns) + result.returns = rew + result.returns_stat = SequenceSummaryStats.from_sequence(rew) + if self._logger and self._env_step is not None: + assert result.n_collected_episodes > 0 + self._logger.log_test_data(asdict(result), self._env_step) + return result + + def _test_step( + self, force_update_best: bool = False, log_msg_prefix: str | None = None + ) -> tuple[CollectStats, bool]: + """Performs one test step. + + :param log_msg_prefix: a prefix to prepend to the log message, which is to establish the context within + which the test step is being carried out + :param force_update_best: whether to force updating of the best model stats (best score, reward, etc.) 
+            and call the `save_best_fn` callback
+        """
+        assert self.params.test_step_num_episodes is not None
+        assert self.params.test_collector is not None
+
+        # collect test episodes
+        test_stat = self._collect_test_episodes()
+        assert test_stat.returns_stat is not None  # for mypy
+
+        # check whether we have a new best score and, if so, update stats and save the model
+        # (or if forced)
+        rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std
+        score = self._compute_score_fn(test_stat)
+        if score > self._best_score or force_update_best:
+            self._best_score = score
+            self._best_epoch = self._epoch
+            self._best_reward = float(rew)
+            self._best_reward_std = rew_std
+            if self.params.save_best_fn:
+                self.params.save_best_fn(self.algorithm)
+
+        # log results
+        cur_info, best_info = "", ""
+        if score != rew:
+            cur_info, best_info = f", score: {score:.6f}", f", best_score: {self._best_score:.6f}"
+        if log_msg_prefix is None:
+            log_msg_prefix = f"Epoch #{self._epoch}"
+        log_msg = (
+            f"{log_msg_prefix}: test_reward: {rew:.6f} ± {rew_std:.6f}{cur_info},"
+            f" best_reward: {self._best_reward:.6f} ± "
+            f"{self._best_reward_std:.6f}{best_info} in #{self._best_epoch}"
+        )
+        log.info(log_msg)
+        if self.params.verbose:
+            print(log_msg, flush=True)
+
+        # determine whether training shall be stopped early
+        stop_fn_flag = self._should_stop_training_early(score=self._best_score)
+
+        return test_stat, stop_fn_flag
+
+    @abstractmethod
+    def _training_step(self) -> _TrainingStepResult:
+        """Performs one training step."""
+
+    def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStats) -> None:
+        """Log losses, update moving average stats, and also modify the smoothed_loss in update_stat."""
+        cur_losses_dict = update_stat.get_loss_stats_dict()
+        update_stat.smoothed_loss = self._update_moving_avg_stats_and_get_averaged_data(
+            cur_losses_dict,
+        )
+        self._logger.log_update_data(asdict(update_stat), self._current_update_step)
+
+    # TODO: seems convoluted, there should be a better way of dealing with the moving average stats
+    def _update_moving_avg_stats_and_get_averaged_data(
+        self,
+        data: dict[str, float],
+    ) -> dict[str, float]:
+        """Add entries to the moving average object in the trainer and retrieve the averaged results.
+
+        :param data: any entries to be tracked in the moving average object.
+        :return: A dictionary containing the averaged values of the tracked entries.
+        """
+        smoothed_data = {}
+        for key, loss_item in data.items():
+            self._stat[key].add(loss_item)
+            smoothed_data[key] = self._stat[key].get()
+        return smoothed_data
+
+    def run(
+        self, reset_collectors: bool = True, reset_collector_buffers: bool = False
+    ) -> InfoStats:
+        """Runs the training process with the configuration given at construction.
+
+        :param reset_collectors: whether to reset the collectors prior to starting the training process.
+            Specifically, this will reset the environments in the collectors (starting new episodes),
+            and the statistics stored in the collector. Whether the contained buffers will be reset/cleared
+            is determined by the `reset_collector_buffers` parameter.
+        :param reset_collector_buffers: whether, for the case where the collectors are reset, to reset/clear the
+            contained buffers as well.
+            This has no effect if `reset_collectors` is False.
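Trainers are no longer iterators, so manual stepping amounts to `reset` followed by repeated `execute_epoch` calls; the `run` method (continued below) wraps this pattern and additionally stops once the early-stopping flag is set. A minimal sketch, assuming `trainer` is an already constructed trainer instance:

```python
def run_epochs_manually(trainer, num_epochs: int) -> None:
    """Drive a trainer by hand instead of calling `trainer.run()`."""
    trainer.reset(reset_collectors=True, reset_collector_buffers=False)
    for _ in range(num_epochs):
        epoch_stats = trainer.execute_epoch()
        # custom per-epoch handling could go here, e.g. inspecting epoch_stats.info_stat
```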
+ """ + self.reset( + reset_collectors=reset_collectors, reset_collector_buffers=reset_collector_buffers + ) + + while self._epoch < self.params.max_epochs and not self._stop_fn_flag: + self.execute_epoch() + + return self._create_info_stats() + + +class OfflineTrainer(Trainer[OfflineAlgorithm, OfflineTrainerParams]): + """An offline trainer, which samples mini-batches from a given buffer and passes them to + the algorithm's update function. + """ + + def __init__( + self, + algorithm: OfflineAlgorithm, + params: OfflineTrainerParams, + ): + super().__init__(algorithm, params) + self._buffer = algorithm.process_buffer(self.params.buffer) + + class _TrainingStepResult(Trainer._TrainingStepResult): + def __init__(self, training_stats: TrainingStats, env_step_advancement: int): + self._training_stats = training_stats + self._env_step_advancement = env_step_advancement + + def get_steps_in_epoch_advancement(self) -> int: + return 1 + + def get_collect_stats(self) -> None: + return None + + def get_training_stats(self) -> TrainingStats: + return self._training_stats + + def is_training_done(self) -> bool: + return False + + def get_env_step_advancement(self) -> int: + return self._env_step_advancement + + def _training_step(self) -> _TrainingStepResult: + with policy_within_training_step(self.algorithm.policy): + # Note: since sample_size=batch_size, this will perform + # exactly one gradient step. This is why we don't need to calculate the + # number of gradient steps, like in the on-policy case. + training_stats = self.algorithm.update( + sample_size=self.params.batch_size, buffer=self._buffer + ) + self._update_moving_avg_stats_and_log_update_data(training_stats) + self._policy_update_time += training_stats.train_time + return self._TrainingStepResult( + training_stats=training_stats, env_step_advancement=self.params.batch_size + ) + + def _create_epoch_pbar_data_dict( + self, training_step_result: Trainer._TrainingStepResult + ) -> dict[str, str]: + return {} + + +class OnlineTrainer( + Trainer[TAlgorithm, TOnlineTrainerParams], Generic[TAlgorithm, TOnlineTrainerParams], ABC +): + """ + An online trainer, which collects data from the environment in each training step and + uses the collected data to perform an update step, the nature of which is to be defined + in subclasses. + """ + + def __init__( + self, + algorithm: TAlgorithm, + params: TOnlineTrainerParams, + ): + super().__init__(algorithm, params) + self._env_episode = 0 + """ + the total number of episodes collected in the environment + """ + + def _reset_collectors(self, reset_buffer: bool = False) -> None: + super()._reset_collectors(reset_buffer=reset_buffer) + self.params.train_collector.reset(reset_buffer=reset_buffer) + + def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = False) -> None: + super().reset( + reset_collectors=reset_collectors, reset_collector_buffers=reset_collector_buffers + ) + + if ( + self.params.test_in_train + and self.params.train_collector.policy is not self.algorithm.policy + ): + log.warning( + "The training data collector's policy is not the same as the one being trained, " + "yet test_in_train is enabled. This may lead to unexpected results." 
+ ) + + self._env_episode = 0 + + class _TrainingStepResult(Trainer._TrainingStepResult): + def __init__( + self, + collect_stats: CollectStats, + training_stats: TrainingStats | None, + is_training_done: bool, + ): + self._collect_stats = collect_stats + self._training_stats = training_stats + self._is_training_done = is_training_done + + def get_steps_in_epoch_advancement(self) -> int: + return self.get_env_step_advancement() + + def get_collect_stats(self) -> CollectStats: + return self._collect_stats + + def get_training_stats(self) -> TrainingStats | None: + return self._training_stats + + def is_training_done(self) -> bool: + return self._is_training_done + + def get_env_step_advancement(self) -> int: + return self._collect_stats.n_collected_steps + + def _training_step(self) -> _TrainingStepResult: + """Perform one training step. + + For an online algorithm, a training step involves: + * collecting data + * for the case where `test_in_train` is activated, + determining whether the stop condition has been reached + (and returning without performing any actual training if so) + * performing a gradient update step + """ + with policy_within_training_step(self.algorithm.policy): + # collect data + collect_stats = self._collect_training_data() + + # determine whether we should stop training based on the data collected + should_stop_training = False + if self.params.test_in_train: + should_stop_training = self._test_in_train(collect_stats) + + # perform gradient update step (if not already done) + training_stats: TrainingStats | None = None + if not should_stop_training: + training_stats = self._update_step(collect_stats) + + return self._TrainingStepResult( + collect_stats=collect_stats, + training_stats=training_stats, + is_training_done=should_stop_training, + ) + + def _collect_training_data(self) -> CollectStats: + """Performs training data collection. + + :return: the data collection stats + """ + assert self.params.test_step_num_episodes is not None + assert self.params.train_collector is not None + + if self.params.train_fn: + self.params.train_fn(self._epoch, self._env_step) + + collect_stats = self.params.train_collector.collect( + n_step=self.params.collection_step_num_env_steps, + n_episode=self.params.collection_step_num_episodes, + ) + TraceLogger.log( + log, + lambda: f"Collected {collect_stats.n_collected_steps} steps, {collect_stats.n_collected_episodes} episodes", + ) + + if self.params.train_collector.buffer.hasnull(): + from tianshou.data.collector import EpisodeRolloutHook + from tianshou.env import DummyVectorEnv + + raise MalformedBufferError( + f"Encountered NaNs in buffer after {self._env_step} steps." + f"Such errors are usually caused by either a bug in the environment or by " + f"problematic implementations {EpisodeRolloutHook.__class__.__name__}. 
" + f"For debugging such issues it is recommended to run the training in a single process, " + f"e.g., by using {DummyVectorEnv.__class__.__name__}.", + ) + + if collect_stats.n_collected_episodes > 0: + assert collect_stats.returns_stat is not None # for mypy + assert collect_stats.lens_stat is not None # for mypy + if self.params.multi_agent_return_reduction: + rew = self.params.multi_agent_return_reduction(collect_stats.returns) + collect_stats.returns = rew + collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) + + # update collection stats specific to this specialization + self._env_episode += collect_stats.n_collected_episodes + + return collect_stats + + def _test_in_train( + self, + train_collect_stats: CollectStats, + ) -> bool: + """ + Performs a test step if the data collected in the current training step suggests that performance + is good enough to stop training early. If the test step confirms that performance is indeed good + enough, returns True, and False otherwise. + + Specifically, applies the early stopping criterion to the data collected in the current training step, + and if the criterion is satisfied, performs a test step which returns the relevant result. + + :param train_collect_stats: the data collection stats from the preceding collection step + :return: flag indicating whether to stop training early + """ + should_stop_training = False + + # check whether the stop criterion is satisfied based on the data collected in the training step + # (if any full episodes were indeed collected) + if train_collect_stats.n_collected_episodes > 0 and self._should_stop_training_early( + collect_stats=train_collect_stats + ): + # apply a test step, temporarily switching out of "is_training_step" semantics such that the policy can + # be evaluated, in order to determine whether we should stop training + with policy_within_training_step(self.algorithm.policy, enabled=False): + _, should_stop_training = self._test_step( + log_msg_prefix=f"Test step triggered by train stats (env_step={self._env_step})" + ) + + return should_stop_training + + @abstractmethod + def _update_step( + self, + collect_stats: CollectStatsBase, + ) -> TrainingStats: + """Performs a gradient update step, calling the algorithm's update method accordingly. + + :param collect_stats: provides info about the preceding data collection step. + """ + + def _create_epoch_pbar_data_dict( + self, training_step_result: Trainer._TrainingStepResult + ) -> dict[str, str]: + collect_stats = training_step_result.get_collect_stats() + assert collect_stats is not None + result = { + "env_step": str(self._env_step), + "env_episode": str(self._env_episode), + "n_ep": str(collect_stats.n_collected_episodes), + "n_st": str(collect_stats.n_collected_steps), + } + # return and episode length info is only available if at least one episode was completed + if collect_stats.n_collected_episodes > 0: + assert collect_stats.returns_stat is not None + assert collect_stats.lens_stat is not None + result.update( + { + "rew": f"{collect_stats.returns_stat.mean:.2f}", + "len": str(int(collect_stats.lens_stat.mean)), + } + ) + return result + + +class OffPolicyTrainer(OnlineTrainer[OffPolicyAlgorithm, OffPolicyTrainerParams]): + """An off-policy trainer, which samples mini-batches from the buffer of collected data and passes them to + algorithm's `update` function. + + The algorithm's `update` method is expected to not perform additional mini-batching but just update + model parameters from the received mini-batch. 
+    """
+
+    def _update_step(
+        self,
+        collect_stats: CollectStatsBase,
+    ) -> TrainingStats:
+        """Perform `update_step_num_gradient_steps_per_sample * n_collected_steps` gradient steps by sampling
+        mini-batches from the buffer.
+
+        :param collect_stats: the :class:`~CollectStatsBase` instance from the preceding data collection step,
+            which provides the number of environment steps that were collected.
+        """
+        assert self.params.train_collector is not None
+        n_collected_steps = collect_stats.n_collected_steps
+        n_gradient_steps = round(
+            self.params.update_step_num_gradient_steps_per_sample * n_collected_steps
+        )
+        if n_gradient_steps == 0:
+            raise ValueError(
+                f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, "
+                f"update_step_num_gradient_steps_per_sample={self.params.update_step_num_gradient_steps_per_sample}",
+            )
+
+        update_stat = None
+        for _ in self._pbar(
+            range(n_gradient_steps),
+            desc="Off-policy gradient update",
+            position=0,
+            leave=False,
+        ):
+            update_stat = self._sample_and_update(self.params.train_collector.buffer)
+            self._policy_update_time += update_stat.train_time
+
+        # TODO: only the last update_stat is returned, should be improved
+        assert update_stat is not None
+        return update_stat
+
+    def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats:
+        """Sample a single mini-batch from the given buffer and perform one gradient step on it."""
+        # Note: since sample_size=batch_size, this will perform
+        # exactly one gradient step. This is why we don't need to calculate the
+        # number of gradient steps, like in the on-policy case.
+        update_stat = self.algorithm.update(sample_size=self.params.batch_size, buffer=buffer)
+        self._update_moving_avg_stats_and_log_update_data(update_stat)
+        return update_stat
+
+
+class OnPolicyTrainer(OnlineTrainer[OnPolicyAlgorithm, OnPolicyTrainerParams]):
+    """An on-policy trainer, which passes the entire buffer to the algorithm's `update` method and
+    resets the buffer thereafter.
+
+    Note that it is expected that the update method of the algorithm will perform
+    batching when using this trainer.
+    """
+
+    def _update_step(
+        self,
+        collect_stats: CollectStatsBase | None = None,
+    ) -> TrainingStats:
+        """Perform one on-policy update by passing the entire buffer to the algorithm's update method."""
+        assert self.params.train_collector is not None
+        log.info(
+            f"Performing on-policy update on buffer of length {len(self.params.train_collector.buffer)}",
+        )
+        training_stat = self.algorithm.update(
+            buffer=self.params.train_collector.buffer,
+            batch_size=self.params.batch_size,
+            repeat=self.params.update_step_num_repetitions,
+        )
+
+        # just for logging, no functional role
+        self._policy_update_time += training_stat.train_time
+
+        # Note: in the policy update, we modify the buffer, which is not very clean.
+        # Currently the modification will erase previous samples but keep things like
+        # _ep_rew and _ep_len (b/c keep_statistics=True). This is needed since the collection might have stopped
+        # in the middle of an episode and in the next collect iteration we need these numbers to compute correct
+        # return and episode length values.
With the current code structure, this means that after an update and buffer reset + # such quantities can no longer be computed + # from samples still contained in the buffer, which is also not clean + self.params.train_collector.reset_buffer(keep_statistics=True) + + # The step is the number of mini-batches used for the update, so essentially + self._update_moving_avg_stats_and_log_update_data(training_stat) + + return training_stat diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py deleted file mode 100644 index 1f4369f72..000000000 --- a/tianshou/trainer/utils.py +++ /dev/null @@ -1,87 +0,0 @@ -import time -from collections.abc import Callable -from dataclasses import asdict - -import numpy as np - -from tianshou.data import ( - CollectStats, - InfoStats, - SequenceSummaryStats, - TimingStats, -) -from tianshou.data.collector import BaseCollector -from tianshou.utils import BaseLogger - - -def test_episode( - collector: BaseCollector, - test_fn: Callable[[int, int | None], None] | None, - epoch: int, - n_episode: int, - logger: BaseLogger | None = None, - global_step: int | None = None, - reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, -) -> CollectStats: - """A simple wrapper of testing policy in collector.""" - collector.reset(reset_stats=False) - if test_fn: - test_fn(epoch, global_step) - result = collector.collect(n_episode=n_episode) - if reward_metric: # TODO: move into collector - rew = reward_metric(result.returns) - result.returns = rew - result.returns_stat = SequenceSummaryStats.from_sequence(rew) - if logger and global_step is not None: - assert result.n_collected_episodes > 0 - logger.log_test_data(asdict(result), global_step) - return result - - -def gather_info( - start_time: float, - policy_update_time: float, - gradient_step: int, - best_score: float, - best_reward: float, - best_reward_std: float, - train_collector: BaseCollector | None = None, - test_collector: BaseCollector | None = None, -) -> InfoStats: - """A simple wrapper of gathering information from collectors. - - :return: InfoStats object with times computed based on the `start_time` and - episode/step counts read off the collectors. No computation of - expensive statistics is done here. 
- """ - duration = max(0.0, time.time() - start_time) - test_time = 0.0 - update_speed = 0.0 - train_time_collect = 0.0 - if test_collector is not None: - test_time = test_collector.collect_time - - if train_collector is not None: - train_time_collect = train_collector.collect_time - update_speed = train_collector.collect_step / (duration - test_time) - - timing_stat = TimingStats( - total_time=duration, - train_time=duration - test_time, - train_time_collect=train_time_collect, - train_time_update=policy_update_time, - test_time=test_time, - update_speed=update_speed, - ) - - return InfoStats( - gradient_step=gradient_step, - best_score=best_score, - best_reward=best_reward, - best_reward_std=best_reward_std, - train_step=train_collector.collect_step if train_collector is not None else 0, - train_episode=train_collector.collect_episode if train_collector is not None else 0, - test_step=test_collector.collect_step if test_collector is not None else 0, - test_episode=test_collector.collect_episode if test_collector is not None else 0, - timing=timing_stat, - ) diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 47a3c4497..a23841b36 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,9 +1,8 @@ """Utils package.""" -from tianshou.utils.logger.base import BaseLogger, LazyLogger +from tianshou.utils.logger.logger_base import BaseLogger, LazyLogger from tianshou.utils.logger.tensorboard import TensorboardLogger from tianshou.utils.logger.wandb import WandbLogger -from tianshou.utils.lr_scheduler import MultipleLRSchedulers from tianshou.utils.progress_bar import DummyTqdm, tqdm_config from tianshou.utils.statistics import MovAvg, RunningMeanStd from tianshou.utils.warning import deprecation @@ -18,5 +17,4 @@ "TensorboardLogger", "LazyLogger", "WandbLogger", - "MultipleLRSchedulers", ] diff --git a/tianshou/utils/lagged_network.py b/tianshou/utils/lagged_network.py new file mode 100644 index 000000000..37a2aa71b --- /dev/null +++ b/tianshou/utils/lagged_network.py @@ -0,0 +1,87 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Self + +import torch + + +def polyak_parameter_update(tgt: torch.nn.Module, src: torch.nn.Module, tau: float) -> None: + """Softly updates the parameters of a target network `tgt` with the parameters of a source network `src` + using Polyak averaging: `tau * src + (1 - tau) * tgt`. + + :param tgt: the target network that receives the parameter update + :param src: the source network whose parameters are used for the update + :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being + the fraction with which to retain the target network's parameters. + """ + for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): + tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) + + +class EvalModeModuleWrapper(torch.nn.Module): + """ + A wrapper around a torch.nn.Module that forces the module to eval mode. + + The wrapped module supports only the forward method, attribute access is not supported. + **NOTE**: It is *not* recommended to support attribute/method access beyond this via `__getattr__`, + because torch.nn.Module already heavily relies on `__getattr__` to provides its own attribute access. + Overriding it naively will cause problems! + But it's also not necessary for our use cases; forward is enough. 
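To make the soft-update semantics of `polyak_parameter_update` (defined above) concrete, a small self-contained check; it assumes the function is importable from the new module path introduced in this diff:

```python
import torch

from tianshou.utils.lagged_network import polyak_parameter_update

src, tgt = torch.nn.Linear(2, 2), torch.nn.Linear(2, 2)
tau = 0.005
with torch.no_grad():
    expected_weight = tau * src.weight + (1 - tau) * tgt.weight

polyak_parameter_update(tgt, src, tau)
assert torch.allclose(tgt.weight, expected_weight)  # tau * src + (1 - tau) * tgt
```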
+ """ + + def __init__(self, m: torch.nn.Module): + super().__init__() + m.eval() + self.module = m + + def forward(self, *args, **kwargs): # type: ignore + self.module.eval() + return self.module(*args, **kwargs) + + def train(self, mode: bool = True) -> Self: + super().train(mode=mode) + self.module.eval() # force eval mode + return self + + +@dataclass +class LaggedNetworkPair: + target: torch.nn.Module + source: torch.nn.Module + + +class LaggedNetworkCollection: + def __init__(self) -> None: + self._lagged_network_pairs: list[LaggedNetworkPair] = [] + + def add_lagged_network(self, source: torch.nn.Module) -> EvalModeModuleWrapper: + """ + Adds a lagged network to the collection, returning the target network, which + is forced to eval mode. The target network is a copy of the source network, + which, however, supports only the forward method (hence the type torch.nn.Module); + attribute access is not supported. + + :param source: the source network whose parameters are to be copied to the target network + :return: the target network, which supports only the forward method and is forced to eval mode + """ + target = deepcopy(source) + self._lagged_network_pairs.append(LaggedNetworkPair(target, source)) + return EvalModeModuleWrapper(target) + + def polyak_parameter_update(self, tau: float) -> None: + """Softly updates the parameters of each target network `tgt` with the parameters of a source network `src` + using Polyak averaging: `tau * src + (1 - tau) * tgt`. + + :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being + the fraction with which to retain the target network's parameters. + """ + for pair in self._lagged_network_pairs: + polyak_parameter_update(pair.target, pair.source, tau) + + def full_parameter_update(self) -> None: + """Fully updates the target networks with the source networks' parameters (exact copy).""" + for pair in self._lagged_network_pairs: + for tgt_param, src_param in zip( + pair.target.parameters(), pair.source.parameters(), strict=True + ): + tgt_param.data.copy_(src_param.data) diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/logger_base.py similarity index 96% rename from tianshou/utils/logger/base.py rename to tianshou/utils/logger/logger_base.py index 2ff6c6760..305606ef0 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/logger_base.py @@ -104,7 +104,7 @@ def log_update_data(self, log_data: dict, step: int) -> None: # TODO: move interval check to calling method if step - self.last_log_update_step >= self.update_interval: log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.UPDATE}/gradient_step", step, log_data) + self.write(f"{DataScope.UPDATE}/update_step", step, log_data) self.last_log_update_step = step def log_info_data(self, log_data: dict, step: int) -> None: @@ -125,14 +125,14 @@ def save_data( self, epoch: int, env_step: int, - gradient_step: int, + update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. :param epoch: the epoch in trainer. :param env_step: the env_step in trainer. - :param gradient_step: the gradient_step in trainer. + :param update_step: the update step count in the trainer. :param function save_checkpoint_fn: a hook defined by user, see trainer documentation for detail. 
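A minimal sketch of a user-defined `save_checkpoint_fn` matching the `(epoch, env_step, update_step)` call order used by the trainer and loggers in this changeset; a real hook would typically also persist the algorithm's state dict, and the path layout below is purely illustrative:

```python
import os

import torch


def save_checkpoint_fn(epoch: int, env_step: int, update_step: int) -> str:
    """Write a checkpoint and return its path, as expected by `save_data`."""
    path = os.path.join("log", f"checkpoint_epoch{epoch}.pth")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({"epoch": epoch, "env_step": env_step, "update_step": update_step}, path)
    return path
```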
""" @@ -144,7 +144,7 @@ def restore_data(self) -> tuple[int, int, int]: If it finds nothing or an error occurs during the recover process, it will return the default parameters. - :return: epoch, env_step, gradient_step. + :return: epoch, env_step, update_step. """ @staticmethod @@ -180,7 +180,7 @@ def save_data( self, epoch: int, env_step: int, - gradient_step: int, + update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: pass diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index ef504cb58..dba11b555 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -6,7 +6,7 @@ from tensorboard.backend.event_processing import event_accumulator from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.logger.base import ( +from tianshou.utils.logger.logger_base import ( VALID_LOG_VALS, VALID_LOG_VALS_TYPE, BaseLogger, @@ -106,18 +106,18 @@ def save_data( self, epoch: int, env_step: int, - gradient_step: int, + update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: self.last_save_step = epoch - save_checkpoint_fn(epoch, env_step, gradient_step) + save_checkpoint_fn(epoch, env_step, update_step) self.write("save/epoch", epoch, {"save/epoch": epoch}) self.write("save/env_step", env_step, {"save/env_step": env_step}) self.write( "save/gradient_step", - gradient_step, - {"save/gradient_step": gradient_step}, + update_step, + {"save/gradient_step": update_step}, ) def restore_data(self) -> tuple[int, int, int]: diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 9172bf54b..f92c3fd2c 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.utils import BaseLogger, TensorboardLogger -from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, TRestoredData +from tianshou.utils.logger.logger_base import VALID_LOG_VALS_TYPE, TRestoredData with contextlib.suppress(ImportError): import wandb @@ -26,8 +26,6 @@ class WandbLogger(BaseLogger): logger = WandbLogger() logger.load(SummaryWriter(log_path)) - result = OnpolicyTrainer(policy, train_collector, test_collector, - logger=logger).run() :param train_interval: the log interval in log_train_data(). Default to 1000. :param test_interval: the log interval in log_test_data(). Default to 1. @@ -132,20 +130,20 @@ def save_data( self, epoch: int, env_step: int, - gradient_step: int, + update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. :param epoch: the epoch in trainer. :param env_step: the env_step in trainer. - :param gradient_step: the gradient_step in trainer. + :param update_step: the gradient_step in trainer. :param function save_checkpoint_fn: a hook defined by user, see trainer documentation for detail. 
""" if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: self.last_save_step = epoch - checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step) + checkpoint_path = save_checkpoint_fn(epoch, env_step, update_step) checkpoint_artifact = wandb.Artifact( "run_" + self.wandb_run.id + "_checkpoint", # type: ignore @@ -153,7 +151,7 @@ def save_data( metadata={ "save/epoch": epoch, "save/env_step": env_step, - "save/gradient_step": gradient_step, + "save/gradient_step": update_step, "checkpoint_path": str(checkpoint_path), }, ) diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py deleted file mode 100644 index 59890c75c..000000000 --- a/tianshou/utils/lr_scheduler.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch - - -class MultipleLRSchedulers: - """A wrapper for multiple learning rate schedulers. - - Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called, - it calls the step() method of each of the schedulers that it contains. - Example usage: - :: - - scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2) - scheduler2 = ExponentialLR(opt2, gamma=0.9) - scheduler = MultipleLRSchedulers(scheduler1, scheduler2) - policy = PPOPolicy(..., lr_scheduler=scheduler) - """ - - def __init__(self, *args: torch.optim.lr_scheduler.LRScheduler): - self.schedulers = args - - def step(self) -> None: - """Take a step in each of the learning rate schedulers.""" - for scheduler in self.schedulers: - scheduler.step() - - def state_dict(self) -> list[dict]: - """Get state_dict for each of the learning rate schedulers. - - :return: A list of state_dict of learning rate schedulers. - """ - return [s.state_dict() for s in self.schedulers] - - def load_state_dict(self, state_dict: list[dict]) -> None: - """Load states from state_dict. - - :param state_dict: A list of learning rate scheduler - state_dict, in the same order as the schedulers. - """ - for s, sd in zip(self.schedulers, state_dict, strict=True): - s.__dict__.update(sd) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 243a04093..b6da01b5f 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -7,9 +7,10 @@ from gymnasium import spaces from torch import nn -from tianshou.data.batch import Batch, BatchProtocol -from tianshou.data.types import RecurrentStateBatch +from tianshou.data.batch import Batch +from tianshou.data.types import RecurrentStateBatch, TObs from tianshou.utils.space_info import ActionSpaceInfo +from tianshou.utils.torch_utils import torch_device ModuleType = type[nn.Module] ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]] @@ -46,33 +47,52 @@ def miniblock( return layers -class MLP(nn.Module): - """Simple MLP backbone. - - Create a MLP of size input_dim * hidden_sizes[0] * hidden_sizes[1] * ... - * hidden_sizes[-1] * output_dim +class ModuleWithVectorOutput(nn.Module): + """ + A module that outputs a vector of a known size. - :param input_dim: dimension of the input vector. - :param output_dim: dimension of the output vector. If set to 0, there - is no final linear layer. - :param hidden_sizes: shape of MLP passed in as a list, not including - input_dim and output_dim. - :param norm_layer: use which normalization before activation, e.g., - ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. - You can also pass a list of normalization modules with the same length - of hidden_sizes, to use different normalization module in different - layers. 
Default to no normalization. - :param activation: which activation to use after each layer, can be both - the same activation for all layers if passed in nn.Module, or different - activation for different Modules if passed in a list. Default to - nn.ReLU. - :param device: which device to create this model on. Default to None. - :param linear_layer: use this module as linear layer. Default to nn.Linear. - :param flatten_input: whether to flatten input data. Default to True. + Use `from_module` to adapt a module to this interface. """ + def __init__(self, output_dim: int) -> None: + """:param output_dim: the dimension of the output vector.""" + super().__init__() + self.output_dim = output_dim + + @staticmethod + def from_module(module: nn.Module, output_dim: int) -> "ModuleWithVectorOutput": + """ + :param module: the module to adapt. + :param output_dim: dimension of the output vector produced by the module. + """ + return ModuleWithVectorOutputAdapter(module, output_dim) + + def get_output_dim(self) -> int: + """:return: the dimension of the output vector.""" + return self.output_dim + + +class ModuleWithVectorOutputAdapter(ModuleWithVectorOutput): + """Adapts a module with vector output to provide the :class:`ModuleWithVectorOutput` interface.""" + + def __init__(self, module: nn.Module, output_dim: int) -> None: + """ + :param module: the module to adapt. + :param output_dim: the dimension of the output vector produced by the module. + """ + super().__init__(output_dim) + self.module = module + + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.module(*args, **kwargs) + + +class MLP(ModuleWithVectorOutput): + """Simple MLP backbone.""" + def __init__( self, + *, input_dim: int, output_dim: int = 0, hidden_sizes: Sequence[int] = (), @@ -80,12 +100,27 @@ def __init__( norm_args: ArgsType | None = None, activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU, act_args: ArgsType | None = None, - device: str | int | torch.device | None = None, linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, ) -> None: - super().__init__() - self.device = device + """ + :param input_dim: dimension of the input vector. + :param output_dim: dimension of the output vector. If set to 0, there + is no explicit final linear layer and the output dimension is the last hidden layer's dimension. + :param hidden_sizes: shape of MLP passed in as a list, not including + input_dim and output_dim. + :param norm_layer: use which normalization before activation, e.g., + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. + You can also pass a list of normalization modules with the same length + of hidden_sizes, to use different normalization module in different + layers. Default to no normalization. + :param activation: which activation to use after each layer, can be both + the same activation for all layers if passed in nn.Module, or different + activation for different Modules if passed in a list. Default to + nn.ReLU. + :param linear_layer: use this module as linear layer. Default to nn.Linear. + :param flatten_input: whether to flatten input data. Default to True. 
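A short usage sketch of the reworked `MLP`: arguments are keyword-only and there is no `device` parameter anymore, since the device is determined from the module itself at forward time (via `torch_device`). The concrete sizes below are arbitrary:

```python
import numpy as np
import torch

from tianshou.utils.net.common import MLP

net = MLP(input_dim=4, output_dim=2, hidden_sizes=(64, 64), activation=torch.nn.Tanh)
logits = net(np.zeros((8, 4), dtype=np.float32))  # numpy input is converted internally
assert logits.shape == (8, 2)
# to run on another device, move the module as usual, e.g. net.to("cuda")
```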
+ """ if norm_layer: if isinstance(norm_layer, list): assert len(norm_layer) == len(hidden_sizes) @@ -130,13 +165,14 @@ def __init__( model += miniblock(in_dim, out_dim, norm, norm_args, activ, act_args, linear_layer) if output_dim > 0: model += [linear_layer(hidden_sizes[-1], output_dim)] - self.output_dim = output_dim or hidden_sizes[-1] + super().__init__(output_dim or hidden_sizes[-1]) self.model = nn.Sequential(*model) self.flatten_input = flatten_input @no_type_check def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: - obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) + device = torch_device(self) + obs = torch.as_tensor(obs, device=device, dtype=torch.float32) if self.flatten_input: obs = obs.flatten(1) return self.model(obs) @@ -145,24 +181,70 @@ def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: TRecurrentState = TypeVar("TRecurrentState", bound=Any) -class NetBase(nn.Module, Generic[TRecurrentState], ABC): - """Interface for NNs used in policies.""" +class ActionReprNet(Generic[TRecurrentState], nn.Module, ABC): + """Abstract base class for neural networks used to compute action-related + representations from environment observations, which defines the + signature of the forward method. + + An action-related representation can be a number of things, including: + * a distribution over actions in a discrete action space in the form of a vector of + unnormalized log probabilities (called "logits" in PyTorch jargon) + * the Q-values of all actions in a discrete action space + * the parameters of a distribution (e.g., mean and std. dev. for a Gaussian distribution) + over actions in a continuous action space + """ @abstractmethod def forward( self, - obs: np.ndarray | torch.Tensor, + obs: TObs, state: TRecurrentState | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, TRecurrentState | None]: - pass + ) -> tuple[torch.Tensor | Sequence[torch.Tensor], TRecurrentState | None]: + """ + The main method for tianshou to compute action representations (such as actions, inputs of distributions, Q-values, etc) + from env observations. + Implementations will always make use of the preprocess_net as the first processing step. + + :param obs: the observations from the environment as retrieved from `ObsBatchProtocol.obs`. + If the environment is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your + env returns tensors). + :param state: the hidden state of the RNN, if applicable + :param info: the info object from the environment step + :return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or + a representation from which it can be retrieved/sampled (e.g., mean and std for a Gaussian distribution), + and hidden_state is the new hidden state of the RNN, if applicable. + """ -class Net(NetBase[Any]): - """Wrapper of MLP to support more specific DRL usage. +class ActionReprNetWithVectorOutput(Generic[T], ActionReprNet[T], ModuleWithVectorOutput): + """A neural network for the computation of action-related representations which outputs + a vector of a known size. + """ - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. 
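As an illustration of the `ActionReprNet` contract defined above, a hypothetical minimal implementation that maps flat observations to one Q-value per discrete action (`TinyQNet` is not part of the library):

```python
from typing import Any

import torch
from torch import nn

from tianshou.utils.net.common import ActionReprNet


class TinyQNet(ActionReprNet[Any]):
    """Map flat observations to one Q-value per discrete action."""

    def __init__(self, obs_dim: int, num_actions: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, num_actions))

    def forward(self, obs, state=None, info=None):
        q_values = self.net(torch.as_tensor(obs, dtype=torch.float32))
        return q_values, state  # no recurrent state is used
```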
+ def __init__(self, output_dim: int) -> None: + super().__init__(output_dim) + + +class Actor(Generic[T], ActionReprNetWithVectorOutput[T], ABC): + @abstractmethod + def get_preprocess_net(self) -> ModuleWithVectorOutput: + """Returns the network component that is used for pre-processing, i.e. + the component which produces a latent representation, which then is transformed + into the final output. + This is, therefore, the first part of the network which processes the input. + For example, a CNN is often used in Atari examples. + + We need this method to be able to share latent representation computations with + other networks (e.g. critics) within an algorithm. + + Actors that do not have a pre-processing stage can return nn.Identity() + (see :class:`RandomActor` for an example). + """ + + +class Net(ActionReprNetWithVectorOutput[Any]): + """A multi-layer perceptron which outputs an action-related representation. :param state_shape: int or a sequence of int of the shape of state. :param action_shape: int or a sequence of int of the shape of action. @@ -176,8 +258,6 @@ class Net(NetBase[Any]): the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. - :param device: specify the device when the network actually runs. Default - to "cpu". :param softmax: whether to apply a softmax layer over the last layer's output. :param concat: whether the input shape is concatenated by state_shape @@ -205,6 +285,7 @@ class Net(NetBase[Any]): def __init__( self, + *, state_shape: int | Sequence[int], action_shape: TActionShape = 0, hidden_sizes: Sequence[int] = (), @@ -212,42 +293,33 @@ def __init__( norm_args: ArgsType | None = None, activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU, act_args: ArgsType | None = None, - device: str | int | torch.device = "cpu", softmax: bool = False, concat: bool = False, num_atoms: int = 1, dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None, linear_layer: TLinearLayer = nn.Linear, ) -> None: - super().__init__() - self.device = device - self.softmax = softmax - self.num_atoms = num_atoms - self.Q: MLP | None = None - self.V: MLP | None = None - input_dim = int(np.prod(state_shape)) action_dim = int(np.prod(action_shape)) * num_atoms if concat: input_dim += action_dim - self.use_dueling = dueling_param is not None - output_dim = action_dim if not self.use_dueling and not concat else 0 - self.model = MLP( - input_dim, - output_dim, - hidden_sizes, - norm_layer, - norm_args, - activation, - act_args, - device, - linear_layer, + use_dueling = dueling_param is not None + model = MLP( + input_dim=input_dim, + output_dim=action_dim if not use_dueling and not concat else 0, + hidden_sizes=hidden_sizes, + norm_layer=norm_layer, + norm_args=norm_args, + activation=activation, + act_args=act_args, + linear_layer=linear_layer, ) - if self.use_dueling: # dueling DQN + Q: MLP | None = None + V: MLP | None = None + if use_dueling: # dueling DQN assert dueling_param is not None kwargs_update = { - "input_dim": self.model.output_dim, - "device": self.device, + "input_dim": model.output_dim, } # Important: don't change the original dict (e.g., don't use .update()) q_kwargs = {**dueling_param[0], **kwargs_update} @@ -255,17 +327,25 @@ def __init__( q_kwargs["output_dim"] = 0 if concat else action_dim v_kwargs["output_dim"] = 0 if concat else num_atoms - self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) - self.output_dim = self.Q.output_dim + Q, V = MLP(**q_kwargs), 
MLP(**v_kwargs) + output_dim = Q.output_dim else: - self.output_dim = self.model.output_dim + output_dim = model.output_dim + + super().__init__(output_dim) + self.use_dueling = use_dueling + self.softmax = softmax + self.num_atoms = num_atoms + self.model = model + self.Q = Q + self.V = V def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: + ) -> tuple[torch.Tensor, T | Any]: """Mapping: obs -> flatten (inside MLP)-> logits. :param obs: @@ -289,7 +369,7 @@ def forward( return logits, state -class Recurrent(NetBase[RecurrentStateBatch]): +class Recurrent(ActionReprNetWithVectorOutput[RecurrentStateBatch]): """Simple Recurrent network based on LSTM. For advanced usage (how to customize the network), please refer to @@ -298,14 +378,14 @@ class Recurrent(NetBase[RecurrentStateBatch]): def __init__( self, + *, layer_num: int, state_shape: int | Sequence[int], action_shape: TActionShape, - device: str | int | torch.device = "cpu", hidden_layer_size: int = 128, ) -> None: - super().__init__() - self.device = device + output_dim = int(np.prod(action_shape)) + super().__init__(output_dim) self.nn = nn.LSTM( input_size=hidden_layer_size, hidden_size=hidden_layer_size, @@ -313,11 +393,14 @@ def __init__( batch_first=True, ) self.fc1 = nn.Linear(int(np.prod(state_shape)), hidden_layer_size) - self.fc2 = nn.Linear(hidden_layer_size, int(np.prod(action_shape))) + self.fc2 = nn.Linear(hidden_layer_size, output_dim) + + def get_preprocess_net(self) -> ModuleWithVectorOutput: + return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim) def forward( self, - obs: np.ndarray | torch.Tensor, + obs: TObs, state: RecurrentStateBatch | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, RecurrentStateBatch]: @@ -340,7 +423,8 @@ def forward( f"Expected to find keys 'hidden' and 'cell' but instead found {state.keys()}", ) - obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) + device = torch_device(self) + obs = torch.as_tensor(obs, device=device, dtype=torch.float32) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. @@ -397,7 +481,7 @@ class DataParallelNet(nn.Module): Tensor. If the input is a nested dictionary, the user should create a similar class to do the same thing. - :param nn.Module net: the network to be distributed in different GPUs. + :param net: the network to be distributed in different GPUs. 
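Like `MLP`, the refactored `Net` above is constructed with keyword-only arguments and no `device` parameter; a brief sketch with illustrative shapes, following the `ActionReprNet` forward contract of returning an (action representation, hidden state) pair:

```python
import numpy as np

from tianshou.utils.net.common import Net

net = Net(state_shape=(4,), action_shape=2, hidden_sizes=(64, 64))
q_values, hidden = net(np.zeros((8, 4), dtype=np.float32))
# q_values holds one entry per action, i.e. shape (8, 2) here; use net.to(...) to change device
```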
""" def __init__(self, net: nn.Module) -> None: @@ -406,13 +490,33 @@ def __init__(self, net: nn.Module) -> None: def forward( self, - obs: np.ndarray | torch.Tensor, + obs: TObs, *args: Any, **kwargs: Any, ) -> tuple[Any, Any]: if not isinstance(obs, torch.Tensor): obs = torch.as_tensor(obs, dtype=torch.float32) - return self.net(obs=obs.cuda(), *args, **kwargs) # noqa: B026 + obs = obs.cuda() + return self.net(obs, *args, **kwargs) + + +# The same functionality as DataParallelNet +# The duplication is worth it because the ActionReprNet abstraction is so important +class ActionReprNetDataParallelWrapper(ActionReprNet): + def __init__(self, net: ActionReprNet) -> None: + super().__init__() + self.net = nn.DataParallel(net) + + def forward( + self, + obs: TObs, + state: TRecurrentState | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, TRecurrentState | None]: + if not isinstance(obs, torch.Tensor): + obs = torch.as_tensor(obs, dtype=torch.float32) + obs = obs.cuda() + return self.net(obs, state=state, info=info) class EnsembleLinear(nn.Module): @@ -450,38 +554,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -# TODO: fix docstring -class BranchingNet(NetBase[Any]): +class BranchingNet(ActionReprNet): """Branching dual Q network. Network for the BranchingDQNPolicy, it uses a common network module, a value module - and action "branches" one for each dimension.It allows for a linear scaling + and action "branches" one for each dimension. It allows for a linear scaling of Q-value the output w.r.t. the number of dimensions in the action space. - For more info please refer to: arXiv:1711.08946. - :param state_shape: int or a sequence of int of the shape of state. - :param action_shape: int or a sequence of int of the shape of action. - :param action_peer_branch: int or a sequence of int of the number of actions in - each dimension. - :param common_hidden_sizes: shape of the common MLP network passed in as a list. - :param value_hidden_sizes: shape of the value MLP network passed in as a list. - :param action_hidden_sizes: shape of the action MLP network passed in as a list. - :param norm_layer: use which normalization before activation, e.g., - ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. - You can also pass a list of normalization modules with the same length - of hidden_sizes, to use different normalization module in different - layers. Default to no normalization. - :param activation: which activation to use after each layer, can be both - the same activation for all layers if passed in nn.Module, or different - activation for different Modules if passed in a list. Default to - nn.ReLU. - :param device: specify the device when the network actually runs. Default - to "cpu". - :param softmax: whether to apply a softmax layer over the last layer's - output. + + This network architecture efficiently handles environments with multiple independent + action dimensions by using a branching structure. Instead of representing all action + combinations (which grows exponentially), it represents each action dimension separately + (linear scaling). + For example, if there are 3 actions with 3 possible values each, then we would normally + need to consider 3^4 = 81 unique actions, whereas with this architecture, we can instead + use 3 branches with 4 actions per dimension, resulting in 3 * 4 = 12 values to be considered. + + Common use cases include multi-joint robotic control tasks, where each joint can be controlled + independently. 
+ + For more information, please refer to: arXiv:1711.08946. """ def __init__( self, + *, state_shape: int | Sequence[int], num_branches: int = 0, action_per_branch: int = 2, @@ -492,41 +588,58 @@ def __init__( norm_args: ArgsType | None = None, activation: ModuleType | None = nn.ReLU, act_args: ArgsType | None = None, - device: str | int | torch.device = "cpu", ) -> None: + """ + :param state_shape: int or a sequence of int of the shape of state. + :param num_branches: number of action dimensions in the environment. + Each branch represents one independent action dimension. + For example, in a robot with 7 joints, you would set this to 7. + :param action_per_branch: Number of possible discrete values for each action dimension. + For example, if each joint can have 3 positions (left, center, right), + you would set this to 3. + :param common_hidden_sizes: shape of the common MLP network passed in as a list. + :param value_hidden_sizes: shape of the value MLP network passed in as a list. + :param action_hidden_sizes: shape of the action MLP network passed in as a list. + :param norm_layer: use which normalization before activation, e.g., + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. + You can also pass a list of normalization modules with the same length + of hidden_sizes, to use different normalization module in different + layers. Default to no normalization. + :param activation: which activation to use after each layer, can be both + the same activation for all layers if passed in nn.Module, or different + activation for different Modules if passed in a list. Default to + nn.ReLU. + """ super().__init__() common_hidden_sizes = common_hidden_sizes or [] value_hidden_sizes = value_hidden_sizes or [] action_hidden_sizes = action_hidden_sizes or [] - self.device = device self.num_branches = num_branches self.action_per_branch = action_per_branch # common network common_input_dim = int(np.prod(state_shape)) common_output_dim = 0 self.common = MLP( - common_input_dim, - common_output_dim, - common_hidden_sizes, - norm_layer, - norm_args, - activation, - act_args, - device, + input_dim=common_input_dim, + output_dim=common_output_dim, + hidden_sizes=common_hidden_sizes, + norm_layer=norm_layer, + norm_args=norm_args, + activation=activation, + act_args=act_args, ) # value network value_input_dim = common_hidden_sizes[-1] value_output_dim = 1 self.value = MLP( - value_input_dim, - value_output_dim, - value_hidden_sizes, - norm_layer, - norm_args, - activation, - act_args, - device, + input_dim=value_input_dim, + output_dim=value_output_dim, + hidden_sizes=value_hidden_sizes, + norm_layer=norm_layer, + norm_args=norm_args, + activation=activation, + act_args=act_args, ) # action branching network action_input_dim = common_hidden_sizes[-1] @@ -534,14 +647,13 @@ def __init__( self.branches = nn.ModuleList( [ MLP( - action_input_dim, - action_output_dim, - action_hidden_sizes, - norm_layer, - norm_args, - activation, - act_args, - device, + input_dim=action_input_dim, + output_dim=action_output_dim, + hidden_sizes=action_hidden_sizes, + norm_layer=norm_layer, + norm_args=norm_args, + activation=activation, + act_args=act_args, ) for _ in range(self.num_branches) ], @@ -549,10 +661,10 @@ def __init__( def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: + ) -> tuple[torch.Tensor, T | None]: """Mapping: obs -> model -> logits.""" common_out = 
self.common(obs) value_out = self.value(common_out) @@ -605,7 +717,7 @@ def preprocess_obs(obs: Batch | dict | torch.Tensor | np.ndarray) -> torch.Tenso @no_type_check def decorator_fn(net_class): class new_net_class(net_class): - def forward(self, obs: np.ndarray | torch.Tensor, *args, **kwargs) -> Any: + def forward(self, obs: TObs, *args, **kwargs) -> Any: return super().forward(preprocess_obs(obs), *args, **kwargs) return new_net_class @@ -613,28 +725,29 @@ def forward(self, obs: np.ndarray | torch.Tensor, *args, **kwargs) -> Any: return decorator_fn, new_state_shape -class BaseActor(nn.Module, ABC): - @abstractmethod - def get_preprocess_net(self) -> nn.Module: - pass +class AbstractContinuousActorProbabilistic(Actor, ABC): + """Type bound for probabilistic actors which output distribution parameters for continuous action spaces.""" - @abstractmethod - def get_output_dim(self) -> int: - pass - @abstractmethod - def forward( - self, - obs: np.ndarray | torch.Tensor, - state: Any = None, - info: dict[str, Any] | None = None, - ) -> tuple[Any, Any]: - # TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction. - # Return type needs to be more specific - pass +class AbstractDiscreteActor(Actor, ABC): + """ + Type bound for discrete actors. + + For on-policy algos like Reinforce, this typically directly outputs unnormalized log + probabilities, which can be interpreted as "logits" in conjunction with a + `torch.distributions.Categorical` instance. + + In Tianshou, discrete actors are also used for computing action distributions within + Q-learning type algorithms (e.g., DQN). In this case, the observations are mapped + to a vector of Q-values (one for each action). In other words, the component is actually + a critic, not an actor in the traditional sense. + Note that when sampling actions, the Q-values can be interpreted as inputs for + a `torch.distributions.Categorical` instance, similar to the on-policy case mentioned + above. + """ -class RandomActor(BaseActor): +class RandomActor(AbstractContinuousActorProbabilistic, AbstractDiscreteActor): """An actor that returns random actions. For continuous action spaces, forward returns a batch of random actions sampled from the action space. 
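Illustration (not part of the patch) of the discrete-actor output semantics described in `AbstractDiscreteActor` above: the network emits unnormalized log-probabilities, which can be passed directly to `torch.distributions.Categorical`. `TinyDiscreteActor` is a hypothetical stand-in, not a Tianshou class; only the logits-to-distribution step reflects the documented contract.

```python
import torch
from torch import nn


class TinyDiscreteActor(nn.Module):
    """Hypothetical stand-in for an AbstractDiscreteActor implementation."""

    def __init__(self, obs_dim: int, num_actions: int) -> None:
        super().__init__()
        self.net = nn.Linear(obs_dim, num_actions)

    def forward(self, obs: torch.Tensor, state=None, info=None):
        # Map observations to unnormalized log-probabilities ("logits").
        return self.net(obs), state


actor = TinyDiscreteActor(obs_dim=4, num_actions=3)
logits, _ = actor(torch.randn(8, 4))
dist = torch.distributions.Categorical(logits=logits)  # logits = unnormalized log-probs
actions = dist.sample()  # shape (8,), one action index per batch element
```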
@@ -643,7 +756,11 @@ class RandomActor(BaseActor): """ def __init__(self, action_space: spaces.Box | spaces.Discrete) -> None: - super().__init__() + if isinstance(action_space, spaces.Discrete): + output_dim = action_space.n + else: + output_dim = np.prod(action_space.shape) + super().__init__(int(output_dim)) self._action_space = action_space self._space_info = ActionSpaceInfo.from_space(action_space) @@ -655,8 +772,8 @@ def action_space(self) -> spaces.Box | spaces.Discrete: def space_info(self) -> ActionSpaceInfo: return self._space_info - def get_preprocess_net(self) -> nn.Module: - return nn.Identity() + def get_preprocess_net(self) -> ModuleWithVectorOutput: + return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim) def get_output_dim(self) -> int: return self.space_info.action_dim @@ -667,58 +784,22 @@ def is_discrete(self) -> bool: def forward( self, - obs: np.ndarray | torch.Tensor | BatchProtocol, - state: Any | None = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[np.ndarray, Any | None]: + ) -> tuple[torch.Tensor, T | None]: batch_size = len(obs) if isinstance(self.action_space, spaces.Box): action = np.stack([self.action_space.sample() for _ in range(batch_size)]) else: # Discrete Actors currently return an n-dimensional array of probabilities for each action action = 1 / self.action_space.n * np.ones((batch_size, self.action_space.n)) - return action, state + return torch.Tensor(action), state - def compute_action_batch(self, obs: np.ndarray | torch.Tensor | BatchProtocol) -> np.ndarray: + def compute_action_batch(self, obs: TObs) -> torch.Tensor: if self.is_discrete: # Different from forward which returns discrete probabilities, see comment there assert isinstance(self.action_space, spaces.Discrete) # for mypy - return np.random.randint(low=0, high=self.action_space.n, size=len(obs)) + return torch.Tensor(np.random.randint(low=0, high=self.action_space.n, size=len(obs))) else: return self.forward(obs)[0] - - -def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T: - """Gets the given attribute from the given object or takes the alternative value if it is not present. - If both are present, they are required to match. - - :param obj: the object from which to obtain the attribute value - :param attr_name: the attribute name - :param alt_value: the alternative value for the case where the attribute is not present, which cannot be None - if the attribute is not present - :return: the value - """ - v = getattr(obj, attr_name) - if v is not None: - if alt_value is not None and v != alt_value: - raise ValueError( - f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})", - ) - return v - else: - if alt_value is None: - raise ValueError( - f"Attribute '{attr_name}' of {obj} is not defined and no fallback given", - ) - return alt_value - - -def get_output_dim(module: nn.Module, alt_value: int | None) -> int: - """Retrieves value the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present. - If both are present, they must match. 
- - :param module: the module - :param alt_value: the alternative value - :return: the value - """ - return getattr_with_matching_alt_value(module, "output_dim", alt_value) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index a0b85ede2..83cddd049 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -1,40 +1,45 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any +from typing import Any, TypeVar import numpy as np import torch from sensai.util.pickle import setstate from torch import nn +from tianshou.data.types import TObs from tianshou.utils.net.common import ( MLP, - BaseActor, - Net, + AbstractContinuousActorProbabilistic, + Actor, + ModuleWithVectorOutput, TActionShape, TLinearLayer, - get_output_dim, ) +from tianshou.utils.torch_utils import torch_device SIGMA_MIN = -20 SIGMA_MAX = 2 +T = TypeVar("T") -class Actor(BaseActor): - """Simple actor network that directly outputs actions for continuous action space. - Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`. + +class AbstractContinuousActorDeterministic(Actor, ABC): + """Marker interface for continuous deterministic actors (DDPG like).""" + + +class ContinuousActorDeterministic(AbstractContinuousActorDeterministic): + """Actor network that directly outputs actions for continuous action space. + Used primarily in DDPG and its variants. It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape. - :param preprocess_net: a self-defined preprocess_net, see usage. - Typically, an instance of :class:`~tianshou.utils.net.common.Net`. + :param preprocess_net: first part of input processing. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. :param max_action: the scale for the final action. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
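A minimal construction sketch (illustration only, not part of the patch): with the keyword-only constructor, the input dimension is taken from the preprocessing network itself, so the removed `device` and `preprocess_net_output_dim` arguments are no longer needed. `ObsEncoder` is a hypothetical helper that assumes the usual `(obs, state) -> (features, state)` preprocessing contract; in practice one would typically pass an instance of :class:`~tianshou.utils.net.common.Net`.

```python
import torch
from torch import nn

from tianshou.utils.net.common import ModuleWithVectorOutput
from tianshou.utils.net.continuous import ContinuousActorDeterministic


class ObsEncoder(ModuleWithVectorOutput):
    """Hypothetical preprocess net following the (obs, state) -> (features, state) convention."""

    def __init__(self, obs_dim: int, feature_dim: int) -> None:
        super().__init__(output_dim=feature_dim)  # declare the known output dimension
        self.body = nn.Sequential(nn.Linear(obs_dim, feature_dim), nn.ReLU())

    def forward(self, obs: torch.Tensor, state=None):
        return self.body(obs), state


encoder = ObsEncoder(obs_dim=8, feature_dim=64)
actor = ContinuousActorDeterministic(
    preprocess_net=encoder,  # input dim is read via encoder.get_output_dim()
    action_shape=2,
    hidden_sizes=(64,),
    max_action=1.0,
)
actions, _ = actor(torch.randn(5, 8))  # tanh-squashed actions of shape (5, 2)
```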
@@ -42,27 +47,24 @@ class Actor(BaseActor): def __init__( self, - preprocess_net: nn.Module | Net, + *, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, - device: str | int | torch.device = "cpu", - preprocess_net_output_dim: int | None = None, ) -> None: - super().__init__() - self.device = device + output_dim = int(np.prod(action_shape)) + super().__init__(output_dim) self.preprocess = preprocess_net - self.output_dim = int(np.prod(action_shape)) - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + input_dim = preprocess_net.get_output_dim() self.last = MLP( - input_dim, - self.output_dim, - hidden_sizes, - device=self.device, + input_dim=input_dim, + output_dim=self.output_dim, + hidden_sizes=hidden_sizes, ) self.max_action = max_action - def get_preprocess_net(self) -> nn.Module: + def get_preprocess_net(self) -> ModuleWithVectorOutput: return self.preprocess def get_output_dim(self) -> int: @@ -70,10 +72,10 @@ def get_output_dim(self) -> int: def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: + ) -> tuple[torch.Tensor, T | None]: """Mapping: s_B -> action_values_BA, hidden_state_BH | None. Returns a tensor representing the actions directly, i.e, of shape @@ -86,7 +88,7 @@ def forward( return action_BA, hidden_BH -class CriticBase(nn.Module, ABC): +class AbstractContinuousCritic(ModuleWithVectorOutput, ABC): @abstractmethod def forward( self, @@ -97,17 +99,15 @@ def forward( """Mapping: (s_B, a_B) -> Q(s, a)_B.""" -class Critic(CriticBase): +class ContinuousCritic(AbstractContinuousCritic): """Simple critic network. It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value). - :param preprocess_net: a self-defined preprocess_net, see usage. + :param preprocess_net: the pre-processing network, which returns a vector of a known dimension. Typically, an instance of :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. :param linear_layer: use this module as linear layer. :param flatten_input: whether to flatten input data for the last layer. 
:param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before @@ -120,32 +120,28 @@ class Critic(CriticBase): def __init__( self, - preprocess_net: nn.Module | Net, + *, + preprocess_net: ModuleWithVectorOutput, hidden_sizes: Sequence[int] = (), - device: str | int | torch.device = "cpu", - preprocess_net_output_dim: int | None = None, linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, apply_preprocess_net_to_obs_only: bool = False, ) -> None: - super().__init__() - self.device = device + super().__init__(output_dim=1) self.preprocess = preprocess_net - self.output_dim = 1 self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + input_dim = preprocess_net.get_output_dim() self.last = MLP( - input_dim, - 1, - hidden_sizes, - device=self.device, + input_dim=input_dim, + output_dim=1, + hidden_sizes=hidden_sizes, linear_layer=linear_layer, flatten_input=flatten_input, ) def __setstate__(self, state: dict) -> None: setstate( - Critic, + ContinuousCritic, self, state, new_default_properties={"apply_preprocess_net_to_obs_only": False}, @@ -158,9 +154,10 @@ def forward( info: dict[str, Any] | None = None, ) -> torch.Tensor: """Mapping: (s_B, a_B) -> Q(s, a)_B.""" + device = torch_device(self) obs = torch.as_tensor( obs, - device=self.device, + device=device, dtype=torch.float32, ) if self.apply_preprocess_net_to_obs_only: @@ -169,7 +166,7 @@ def forward( if act is not None: act = torch.as_tensor( act, - device=self.device, + device=device, dtype=torch.float32, ).flatten(1) obs = torch.cat([obs, act], dim=1) @@ -178,13 +175,12 @@ def forward( return self.last(obs) -class ActorProb(BaseActor): +class ContinuousActorProbabilistic(AbstractContinuousActorProbabilistic): """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian). Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`. - :param preprocess_net: a self-defined preprocess_net, see usage. - Typically, an instance of :class:`~tianshou.utils.net.common.Net`. + :param preprocess_net: the pre-processing network, which returns a vector of a known dimension. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. @@ -192,60 +188,50 @@ class ActorProb(BaseActor): :param unbounded: whether to apply tanh activation on final logits. :param conditioned_sigma: True when sigma is calculated from the input, False when sigma is an independent parameter. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
""" - # TODO: force kwargs, adjust downstream code def __init__( self, - preprocess_net: nn.Module | Net, + *, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, - device: str | int | torch.device = "cpu", unbounded: bool = False, conditioned_sigma: bool = False, - preprocess_net_output_dim: int | None = None, ) -> None: - super().__init__() + output_dim = int(np.prod(action_shape)) + super().__init__(output_dim) if unbounded and not np.isclose(max_action, 1.0): warnings.warn("Note that max_action input will be discarded when unbounded is True.") max_action = 1.0 self.preprocess = preprocess_net - self.device = device - self.output_dim = int(np.prod(action_shape)) - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) - self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) + input_dim = preprocess_net.get_output_dim() + self.mu = MLP(input_dim=input_dim, output_dim=output_dim, hidden_sizes=hidden_sizes) self._c_sigma = conditioned_sigma if conditioned_sigma: self.sigma = MLP( - input_dim, - self.output_dim, - hidden_sizes, - device=self.device, + input_dim=input_dim, + output_dim=output_dim, + hidden_sizes=hidden_sizes, ) else: - self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) + self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1)) self.max_action = max_action self._unbounded = unbounded - def get_preprocess_net(self) -> nn.Module: + def get_preprocess_net(self) -> ModuleWithVectorOutput: return self.preprocess - def get_output_dim(self) -> int: - return self.output_dim - def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]: - """Mapping: obs -> logits -> (mu, sigma).""" + ) -> tuple[tuple[torch.Tensor, torch.Tensor], T | None]: if info is None: info = {} logits, hidden = self.preprocess(obs, state) @@ -270,12 +256,12 @@ class RecurrentActorProb(nn.Module): def __init__( self, + *, layer_num: int, state_shape: Sequence[int], action_shape: Sequence[int], hidden_layer_size: int = 128, max_action: float = 1.0, - device: str | int | torch.device = "cpu", unbounded: bool = False, conditioned_sigma: bool = False, ) -> None: @@ -283,7 +269,6 @@ def __init__( if unbounded and not np.isclose(max_action, 1.0): warnings.warn("Note that max_action input will be discarded when unbounded is True.") max_action = 1.0 - self.device = device self.nn = nn.LSTM( input_size=int(np.prod(state_shape)), hidden_size=hidden_layer_size, @@ -309,9 +294,10 @@ def forward( """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" if info is None: info = {} + device = torch_device(self) obs = torch.as_tensor( obs, - device=self.device, + device=device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -360,14 +346,12 @@ def __init__( self, layer_num: int, state_shape: Sequence[int], - action_shape: Sequence[int] = [0], - device: str | int | torch.device = "cpu", + action_shape: Sequence[int] = (0,), hidden_layer_size: int = 128, ) -> None: super().__init__() self.state_shape = state_shape self.action_shape = action_shape - self.device = device self.nn = nn.LSTM( input_size=int(np.prod(state_shape)), hidden_size=hidden_layer_size, @@ -385,9 +369,10 @@ def forward( """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" if info is None: info = {} + device = 
torch_device(self) obs = torch.as_tensor( obs, - device=self.device, + device=device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -400,7 +385,7 @@ def forward( if act is not None: act = torch.as_tensor( act, - device=self.device, + device=device, dtype=torch.float32, ) obs = torch.cat([obs, act], dim=1) @@ -428,15 +413,14 @@ class Perturbation(nn.Module): def __init__( self, + *, preprocess_net: nn.Module, max_action: float, - device: str | int | torch.device = "cpu", phi: float = 0.05, ): # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim super().__init__() self.preprocess_net = preprocess_net - self.device = device self.max_action = max_action self.phi = phi @@ -473,12 +457,12 @@ class VAE(nn.Module): def __init__( self, + *, encoder: nn.Module, decoder: nn.Module, hidden_dim: int, latent_dim: int, max_action: float, - device: str | torch.device = "cpu", ): super().__init__() self.encoder = encoder @@ -490,7 +474,6 @@ def __init__( self.max_action = max_action self.latent_dim = latent_dim - self.device = device def forward( self, @@ -521,8 +504,9 @@ def decode( if latent_z is None: # state.shape[0] may be batch_size # latent vector clipped to [-0.5, 0.5] + device = torch_device(self) latent_z = ( - torch.randn(state.shape[:-1] + (self.latent_dim,)).to(self.device).clamp(-0.5, 0.5) + torch.randn(state.shape[:-1] + (self.latent_dim,)).to(device).clamp(-0.5, 0.5) ) # decode z with state! diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index ab9069801..8f2e11f1e 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any +from typing import Any, TypeVar import numpy as np import torch @@ -7,67 +7,74 @@ from torch import nn from tianshou.data import Batch, to_torch -from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim +from tianshou.data.types import TObs +from tianshou.utils.net.common import ( + MLP, + AbstractDiscreteActor, + ModuleWithVectorOutput, + TActionShape, +) +from tianshou.utils.torch_utils import torch_device +T = TypeVar("T") -class Actor(BaseActor): - """Simple actor network for discrete action spaces. - :param preprocess_net: a self-defined preprocess_net. Typically, an instance of - :class:`~tianshou.utils.net.common.Net`. - :param action_shape: a sequence of int for the shape of action. - :param hidden_sizes: a sequence of int for constructing the MLP after - preprocess_net. Default to empty sequence (where the MLP now contains - only a single linear layer). - :param softmax_output: whether to apply a softmax layer over the last - layer's output. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. +def dist_fn_categorical_from_logits(logits: torch.Tensor) -> torch.distributions.Categorical: + """Default distribution function for categorical actors.""" + return torch.distributions.Categorical(logits=logits) - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. + +class DiscreteActor(AbstractDiscreteActor): + """ + Generic discrete actor which uses a preprocessing network to generate a latent representation + which is subsequently passed to an MLP to compute the output. + + For common output semantics, see :class:`DiscreteActorInterface`. 
""" def __init__( self, - preprocess_net: nn.Module | Net, + *, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), softmax_output: bool = True, - preprocess_net_output_dim: int | None = None, - device: str | int | torch.device = "cpu", ) -> None: - super().__init__() - # TODO: reduce duplication with continuous.py. Probably introducing - # base classes is a good idea. - self.device = device + """ + :param preprocess_net: the preprocessing network, which outputs a vector of a known dimension; + typically an instance of :class:`~tianshou.utils.net.common.Net`. + :param action_shape: a sequence of int for the shape of action. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. Default to empty sequence (where the MLP now contains + only a single linear layer). + :param softmax_output: whether to apply a softmax layer over the last + layer's output. + """ + output_dim = int(np.prod(action_shape)) + super().__init__(output_dim) self.preprocess = preprocess_net - self.output_dim = int(np.prod(action_shape)) - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + input_dim = preprocess_net.get_output_dim() self.last = MLP( - input_dim, - self.output_dim, - hidden_sizes, - device=self.device, + input_dim=input_dim, + output_dim=self.output_dim, + hidden_sizes=hidden_sizes, ) self.softmax_output = softmax_output - def get_preprocess_net(self) -> nn.Module: + def get_preprocess_net(self) -> ModuleWithVectorOutput: return self.preprocess - def get_output_dim(self) -> int: - return self.output_dim - def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None. + ) -> tuple[torch.Tensor, T | None]: + r"""Mapping: (s_B, ...) -> action_values_BA, hidden_state_BH | None. + Returns a tensor representing the values of each action, i.e, of shape - `(n_actions, )`, and + `(n_actions, )` (see class docstring for more info on the meaning of that), and a hidden state (which may be None). If `self.softmax_output` is True, they are the probabilities for taking each action. Otherwise, they will be action values. The hidden state is only @@ -82,17 +89,15 @@ def forward( return output_BA, hidden_BH -class Critic(nn.Module): +class DiscreteCritic(ModuleWithVectorOutput): """Simple critic network for discrete action spaces. - :param preprocess_net: a self-defined preprocess_net. Typically, an instance of - :class:`~tianshou.utils.net.common.Net`. + :param preprocess_net: the preprocessing network, which outputs a vector of a known dimension; + typically an instance of :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. Default to empty sequence (where the MLP now contains only a single linear layer). :param last_size: the output dimension of Critic network. Default to 1. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`.. 
@@ -100,24 +105,22 @@ class Critic(nn.Module): def __init__( self, - preprocess_net: nn.Module | Net, + *, + preprocess_net: ModuleWithVectorOutput, hidden_sizes: Sequence[int] = (), last_size: int = 1, - preprocess_net_output_dim: int | None = None, - device: str | int | torch.device = "cpu", ) -> None: - super().__init__() - self.device = device + super().__init__(output_dim=last_size) self.preprocess = preprocess_net - self.output_dim = last_size - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) - self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device) + input_dim = preprocess_net.get_output_dim() + self.last = MLP(input_dim=input_dim, output_dim=last_size, hidden_sizes=hidden_sizes) - # TODO: make a proper interface! - def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor: + def forward( + self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None + ) -> torch.Tensor: """Mapping: s_B -> V(s)_B.""" # TODO: don't use this mechanism for passing state - logits, _ = self.preprocess(obs, state=kwargs.get("state", None)) + logits, _ = self.preprocess(obs, state=state) return self.last(logits) @@ -158,7 +161,7 @@ def forward(self, taus: torch.Tensor) -> torch.Tensor: return self.net(cosines).view(batch_size, N, self.embedding_dim) -class ImplicitQuantileNetwork(Critic): +class ImplicitQuantileNetwork(DiscreteCritic): """Implicit Quantile Network. :param preprocess_net: a self-defined preprocess_net which output a @@ -169,8 +172,6 @@ class ImplicitQuantileNetwork(Critic): only a single linear layer). :param num_cosines: the number of cosines to use for cosine embedding. Default to 64. - :param preprocess_net_output_dim: the output dimension of - preprocess_net. .. note:: @@ -182,19 +183,20 @@ class ImplicitQuantileNetwork(Critic): def __init__( self, - preprocess_net: nn.Module, + *, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, - preprocess_net_output_dim: int | None = None, - device: str | int | torch.device = "cpu", ) -> None: last_size = int(np.prod(action_shape)) - super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device) - self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) - self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to( - device, + super().__init__( + preprocess_net=preprocess_net, + hidden_sizes=hidden_sizes, + last_size=last_size, ) + self.input_dim = preprocess_net.get_output_dim() + self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim) def forward( # type: ignore self, @@ -262,8 +264,6 @@ class FullQuantileFunction(ImplicitQuantileNetwork): only a single linear layer). :param num_cosines: the number of cosines to use for cosine embedding. Default to 64. - :param preprocess_net_output_dim: the output dimension of - preprocess_net. .. 
note:: @@ -273,20 +273,17 @@ class FullQuantileFunction(ImplicitQuantileNetwork): def __init__( self, - preprocess_net: nn.Module, + *, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, - preprocess_net_output_dim: int | None = None, - device: str | int | torch.device = "cpu", ) -> None: super().__init__( - preprocess_net, - action_shape, - hidden_sizes, - num_cosines, - preprocess_net_output_dim, - device, + preprocess_net=preprocess_net, + action_shape=action_shape, + hidden_sizes=hidden_sizes, + num_cosines=num_cosines, ) def _compute_quantiles(self, obs: torch.Tensor, taus: torch.Tensor) -> torch.Tensor: @@ -341,8 +338,8 @@ def __init__(self, in_features: int, out_features: int, noisy_std: float = 0.5) self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features)) # Factorized noise parameters. - self.register_buffer("eps_p", torch.FloatTensor(in_features)) - self.register_buffer("eps_q", torch.FloatTensor(out_features)) + self.eps_p = nn.Parameter(torch.FloatTensor(in_features), requires_grad=False) + self.eps_q = nn.Parameter(torch.FloatTensor(out_features), requires_grad=False) self.in_features = in_features self.out_features = out_features @@ -386,34 +383,30 @@ class IntrinsicCuriosityModule(nn.Module): :param feature_dim: input dimension of the feature net. :param action_dim: dimension of the action space. :param hidden_sizes: hidden layer sizes for forward and inverse models. - :param device: device for the module. """ def __init__( self, + *, feature_net: nn.Module, feature_dim: int, action_dim: int, hidden_sizes: Sequence[int] = (), - device: str | torch.device = "cpu", ) -> None: super().__init__() self.feature_net = feature_net self.forward_model = MLP( - feature_dim + action_dim, + input_dim=feature_dim + action_dim, output_dim=feature_dim, hidden_sizes=hidden_sizes, - device=device, ) self.inverse_model = MLP( - feature_dim * 2, + input_dim=feature_dim * 2, output_dim=action_dim, hidden_sizes=hidden_sizes, - device=device, ) self.feature_dim = feature_dim self.action_dim = action_dim - self.device = device def forward( self, @@ -423,10 +416,11 @@ def forward( **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor]: r"""Mapping: s1, act, s2 -> mse_loss, act_hat.""" - s1 = to_torch(s1, dtype=torch.float32, device=self.device) - s2 = to_torch(s2, dtype=torch.float32, device=self.device) + device = torch_device(self) + s1 = to_torch(s1, dtype=torch.float32, device=device) + s2 = to_torch(s2, dtype=torch.float32, device=device) phi1, phi2 = self.feature_net(s1), self.feature_net(s2) - act = to_torch(act, dtype=torch.long, device=self.device) + act = to_torch(act, dtype=torch.long, device=device) phi2_hat = self.forward_model( torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1), ) diff --git a/tianshou/utils/optim.py b/tianshou/utils/optim.py deleted file mode 100644 index c69ef71db..000000000 --- a/tianshou/utils/optim.py +++ /dev/null @@ -1,69 +0,0 @@ -from collections.abc import Iterator -from typing import TypeVar - -import torch -from torch import nn - - -def optim_step( - loss: torch.Tensor, - optim: torch.optim.Optimizer, - module: nn.Module | None = None, - max_grad_norm: float | None = None, -) -> None: - """Perform a single optimization step: zero_grad -> backward (-> clip_grad_norm) -> step. 
- - :param loss: - :param optim: - :param module: the module to optimize, required if max_grad_norm is passed - :param max_grad_norm: if passed, will clip gradients using this - """ - optim.zero_grad() - loss.backward() - if max_grad_norm: - if not module: - raise ValueError( - "module must be passed if max_grad_norm is passed. " - "Note: often the module will be the policy, i.e.`self`", - ) - nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm) - optim.step() - - -_STANDARD_TORCH_OPTIMIZERS = [ - torch.optim.Adam, - torch.optim.SGD, - torch.optim.RMSprop, - torch.optim.Adadelta, - torch.optim.AdamW, - torch.optim.Adamax, - torch.optim.NAdam, - torch.optim.SparseAdam, - torch.optim.LBFGS, -] - -TOptim = TypeVar("TOptim", bound=torch.optim.Optimizer) - - -def clone_optimizer( - optim: TOptim, - new_params: nn.Parameter | Iterator[nn.Parameter], -) -> TOptim: - """Clone an optimizer to get a new optim instance with new parameters. - - **WARNING**: This is a temporary measure, and should not be used in downstream code! - Once tianshou interfaces have moved to optimizer factories instead of optimizers, - this will be removed. - - :param optim: the optimizer to clone - :param new_params: the new parameters to use - :return: a new optimizer with the same configuration as the old one - """ - optim_class = type(optim) - # custom optimizers may not behave as expected - if optim_class not in _STANDARD_TORCH_OPTIMIZERS: - raise ValueError( - f"Cannot clone optimizer {optim} of type {optim_class}" - f"Currently, only standard torch optimizers are supported.", - ) - return optim_class(new_params, **optim.defaults) diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 1ffb9fcd8..6273a41b4 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -8,7 +8,7 @@ from torch import nn if TYPE_CHECKING: - from tianshou.policy import BasePolicy + from tianshou.algorithm import algorithm_base @contextmanager @@ -23,7 +23,9 @@ def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]: @contextmanager -def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]: +def policy_within_training_step( + policy: "algorithm_base.Policy", enabled: bool = True +) -> Iterator[None]: """Temporarily switch to `policy.is_within_training_step=enabled`. Enabling this ensures that the policy is able to adapt its behavior, @@ -61,7 +63,7 @@ def create_uniform_action_dist( ) -> dist.Uniform | dist.Categorical: """Create a Distribution such that sampling from it is equivalent to sampling a batch with `action_space.sample()`. - :param action_space: The action space of the environment. + :param action_space: the environment's action_space. :param batch_size: The number of environments or batch size for sampling. :return: A PyTorch distribution for sampling actions. """ @@ -75,3 +77,14 @@ def create_uniform_action_dist( else: raise ValueError(f"Unsupported action space type: {type(action_space)}") + + +def torch_device(module: torch.nn.Module) -> torch.device: + """Gets the device of a torch module by retrieving the device of the parameters. + + If parameters are empty, it returns the CPU device as a fallback. + """ + try: + return next(module.parameters()).device + except StopIteration: + return torch.device("cpu")
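A short sketch (not part of the patch) of the device handling that `torch_device` enables: networks no longer receive a `device` constructor argument; instead they derive the device from their parameters at call time, so moving a module is just a regular `.to(...)` call.

```python
import torch
from torch import nn

from tianshou.utils.torch_utils import torch_device

net = nn.Linear(4, 2)
print(torch_device(net))  # cpu

if torch.cuda.is_available():
    net = net.to("cuda")
    print(torch_device(net))  # cuda:0 -- inputs are converted inside forward() based on this

print(torch_device(nn.Flatten()))  # parameterless module -> falls back to cpu
```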