diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 7396e5cd3..af114c9cc 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -15,53 +15,94 @@ on: required: false default: false + +# This job runs the test suite in two environments: +# - py_pinned: uses Python 3.11 with the existing poetry.lock file (our stable, pinned dev environment) +# - py_latest: the latest Python version we want to support, without the lock file, so that the newest dependency versions are installed as well +# +# This ensures compatibility with both our controlled dev setup and the latest upstream packages, +# helping catch issues introduced by dependency updates or newer Python versions. jobs: cpu: runs-on: ubuntu-latest if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: - python-version: ["3.11"] + include: + - env_name: py_pinned + python-version: "3.11" + use_lock: true + - env_name: py_latest + python-version: "3.13" + use_lock: false + steps: - name: Setup tmate session uses: mxschmitt/action-tmate@v3 if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }} + - name: Cancel previous run uses: styfle/cancel-workflow-action@0.11.0 with: access_token: ${{ github.token }} + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - # use poetry and cache installed packages, see https://github.com/marketplace/actions/python-poetry-action + - name: Install poetry uses: abatilo/actions-poetry@v2 + - name: Setup a local virtual environment (if no poetry.toml file) run: | poetry config virtualenvs.create true --local poetry config virtualenvs.in-project true --local - - uses: actions/cache@v3 - name: Define a cache for the virtual environment based on the dependencies lock file + + - name: Remove poetry.lock for latest dependency test + if: ${{ !matrix.use_lock }} + run: rm -f poetry.lock + + - name: Define a cache for the virtual environment based on the dependencies lock file + if: matrix.use_lock + uses: actions/cache@v3 with: path: ./.venv - key: venv-${{ hashFiles('poetry.lock') }} + key: venv-${{ matrix.env_name }}-${{ hashFiles('poetry.lock') }} + restore-keys: | + venv-${{ matrix.env_name }}- + - name: Install the project dependencies run: | - poetry install --with dev --extras "envpool eval" + if [ "${{ matrix.env_name }}" = "py_latest" ]; then + poetry install --with dev --extras "eval" + else + poetry install --with dev --extras "envpool eval" + fi + + - name: List installed packages + run: | + poetry run pip list + - name: wandb login run: | poetry run wandb login e2366d661b89f2bee877c40bee15502d67b7abef + - name: Test with pytest - # ignore test/throughput which only profiles the code run: | - poetry run poe test + if [ "${{ matrix.env_name }}" = "py_pinned" ]; then + poetry run poe test + else + poetry run poe test-nocov + fi - name: Upload coverage to Codecov + if: matrix.env_name == 'py_pinned' uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV }} file: ./coverage.xml - flags: unittests - name: codecov-umbrella + flags: ${{ matrix.env_name }} + name: codecov-${{ matrix.env_name }} fail_ci_if_error: false diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e4f8b703..daba7ed9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,14 @@ Developers: * Dr. 
Dominik Jain (@opcode81) * Michael Panchenko (@MischaPanch) +### Runtime Environment Compatibility + +Tianshou v2 is now compatible with + * Python 3.12 and Python 3.13 #1274 + * newer versions of gymnasium (v1+) and numpy (v2+) + +Our main test environment remains Python 3.11-based for the time being (see `poetry.lock`). + ### Trainer Abstraction * The trainer logic and configuration is now properly separated between the three cases of on-policy, off-policy @@ -228,6 +236,9 @@ Developers: contain parameter `repeat_per_collect`). * All parameter names have been aligned with the new names used by `TrainerParams` (see above). +* Add option to customize the factory for the collector (`ExperimentBuilder.with_collector_factory`), + adding the abstraction `CollectorFactory`. #1256 + ### Peripheral Changes * The `Actor` classes have been renamed for clarity (#1091): diff --git a/docs/01_tutorials/00_dqn.rst b/docs/01_tutorials/00_dqn.rst deleted file mode 100644 index 3c28e7163..000000000 --- a/docs/01_tutorials/00_dqn.rst +++ /dev/null @@ -1,337 +0,0 @@ -Deep Q Network -============== - -Deep reinforcement learning has achieved significant successes in various applications. -**Deep Q Network** (DQN) :cite:`DQN` is the pioneer one. -In this tutorial, we will show how to train a DQN agent on CartPole with Tianshou step by step. -The full script is at `test/discrete/test_dqn.py `_. - -Contrary to existing Deep RL libraries such as `RLlib `_, which could only accept a config specification of hyperparameters, network, and others, Tianshou provides an easy way of construction through the code-level. - - -Overview --------- - -In reinforcement learning, the agent interacts with environments to improve itself. - -.. image:: /_static/images/rl-loop.jpg - :align: center - :height: 200 - -There are three types of data flow in RL training pipeline: - -1. Agent to environment: ``action`` will be generated by agent and sent to environment; -2. Environment to agent: ``env.step`` takes action, and returns a tuple of ``(observation, reward, done, info)``; -3. Agent-environment interaction to agent training: the data generated by interaction will be stored and sent to the learner of agent. - -In the following sections, we will set up (vectorized) environments, policy (with neural network), collector (with buffer), and trainer to successfully run the RL training and evaluation pipeline. -Here is the overall system: - -.. image:: /_static/images/pipeline.png - :align: center - :height: 300 - - -Make an Environment -------------------- - -First of all, you have to make an environment for your agent to interact with. You can use ``gym.make(environment_name)`` to make an environment for your agent. For environment interfaces, we follow the convention of `Gymnasium `_. In your Python code, simply import Tianshou and make the environment: -:: - - import gymnasium as gym - import tianshou as ts - - env = gym.make('CartPole-v1') - -CartPole-v1 includes a cart carrying a pole moving on a track. This is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both. 
- -Here is the detail of useful fields of CartPole-v1: - -- ``state``: the position of the cart, the velocity of the cart, the angle of the pole and the velocity of the tip of the pole; -- ``action``: can only be one of ``[0, 1, 2]``, for moving the cart left, no move, and right; -- ``reward``: each timestep you last, you will receive a +1 ``reward``; -- ``done``: if CartPole is out-of-range or timeout (the pole is more than 15 degrees from vertical, or the cart moves more than 2.4 units from the center, or you last over 200 timesteps); -- ``info``: extra info from environment simulation. - -The goal is to train a good policy that can get the highest reward in this environment. - - -Setup Vectorized Environment ----------------------------- - -If you want to use the original ``gym.Env``: -:: - - train_envs = gym.make('CartPole-v1') - test_envs = gym.make('CartPole-v1') - -Tianshou supports vectorized environment for all algorithms. It provides four types of vectorized environment wrapper: - -- :class:`~tianshou.env.DummyVectorEnv`: the sequential version, using a single-thread for-loop; -- :class:`~tianshou.env.SubprocVectorEnv`: use python multiprocessing and pipe for concurrent execution; -- :class:`~tianshou.env.ShmemVectorEnv`: use share memory instead of pipe based on SubprocVectorEnv; -- :class:`~tianshou.env.RayVectorEnv`: use Ray for concurrent activities and is currently the only choice for parallel simulation in a cluster with multiple machines. It can be used as follows: (more explanation can be found at :ref:`parallel_sampling`) - -:: - - train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(10)]) - test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(100)]) - -Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``. - -You can also try the super-fast vectorized environment `EnvPool `_ by - -:: - - import envpool - train_envs = envpool.make_gymnasium("CartPole-v1", num_envs=10) - test_envs = envpool.make_gymnasium("CartPole-v1", num_envs=100) - -For the demonstration, here we use the second code-block. - -.. warning:: - - If you use your own environment, please make sure the ``seed`` method is set up properly, e.g., - - :: - - def seed(self, seed): - np.random.seed(seed) - - Otherwise, the outputs of these envs may be the same with each other. - - -.. _build_the_network: - -Build the Network ------------------ - -Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of course, the inputs and outputs must comply with Tianshou's API. 
Here is an example: -:: - - import torch, numpy as np - from torch import nn - - class MLPActor(nn.Module): - def __init__(self, state_shape, action_shape): - super().__init__() - self.model = nn.Sequential( - nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True), - nn.Linear(128, 128), nn.ReLU(inplace=True), - nn.Linear(128, 128), nn.ReLU(inplace=True), - nn.Linear(128, np.prod(action_shape)), - ) - - def forward(self, obs, state=None, info={}): - if not isinstance(obs, torch.Tensor): - obs = torch.tensor(obs, dtype=torch.float) - batch = obs.shape[0] - logits = self.model(obs.view(batch, -1)) - return logits, state - - from tianshou.utils.net.common import Net - from tianshou.utils.space_info import SpaceInfo - from tianshou.algorithm.optim import AdamOptimizerFactory - - space_info = SpaceInfo.from_env(env) - state_shape = space_info.observation_info.obs_shape - action_shape = space_info.action_info.action_shape - net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128, 128]) - optim = AdamOptimizerFactory(lr=1e-3) - -You can also use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: - -1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment. -2. Output: some ``logits``, the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``, or some other useful variables or results during the policy forwarding procedure. It depends on how the policy class process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. - -.. note:: - - The logits here indicates the raw output of the network. In supervised learning, the raw output of prediction/classification model is called logits, and here we extend this definition to any raw output of the neural network. - - -Setup Policy ------------- - -We use the defined ``net`` and ``optim`` above, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with a target network: -:: - - from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy - from tianshou.algorithm import DQN - - policy = DiscreteQLearningPolicy( - model=net, - action_space=env.action_space, - observation_space=env.observation_space, - eps_training=0.1, - eps_inference=0.05, - ) - algorithm = DQN( - policy=policy, - optim=optim, - gamma=0.9, - n_step_return_horizon=3, - target_update_freq=320, - ) - - -Setup Collector ---------------- - -The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently. -In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer. - -The following code shows how to set up a collector in practice. It is worth noticing that VectorReplayBuffer is to be used in vectorized environment scenarios, and the number of buffers, in the following case 10, is preferred to be set as the number of environments. 
- -:: - - from tianshou.data import Collector, CollectStats, VectorReplayBuffer - - buf = VectorReplayBuffer(20000, buffer_num=len(train_envs)) - train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) - -The main function of collector is the collect function, which can be summarized in the following lines: - -:: - - result = self.policy(self.data, last_state) # the agent predicts the batch action from batch observation - act = to_numpy(result.act) - self.data.update(act=act) # update the data with new action/policy - result = self.env.step(act, ready_env_ids) # apply action to environment - obs_next, rew, done, info = result - self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) # update the data with new state/reward/done/info - - -Train Policy with a Trainer ---------------------------- - -Tianshou provides :class:`~tianshou.trainer.OnPolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, -and :class:`~tianshou.trainer.OfflineTrainer`. The trainer will automatically stop training when the policy -reaches the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the -:class:`~tianshou.trainer.OffpolicyTrainer` as follows: -:: - - from tianshou.trainer import OffPolicyTrainerParams - - def train_fn(epoch, env_step): - policy.set_eps_training(0.1) - - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold - - result = algorithm.run_training( - OffPolicyTrainerParams( - train_collector=train_collector, - test_collector=test_collector, - max_epochs=10, - epoch_num_steps=10000, - collection_step_num_env_steps=10, - test_step_num_episodes=100, - batch_size=64, - update_step_num_gradient_steps_per_sample=0.1, - train_fn=train_fn, - stop_fn=stop_fn, - ) - ) - print(f'Finished training! Use {result.duration}') - -The meaning of each parameter is as follows (full description can be found at :class:`~tianshou.trainer.OffpolicyTrainer`): - -* ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``; -* ``epoch_num_steps``: The number of environment step (a.k.a. transition) collected per epoch; -* ``collection_step_num_env_steps``: The number of transition the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update"; -* ``episode_per_test``: The number of episodes for one policy evaluation. -* ``batch_size``: The batch size of sample data, which is going to feed in the policy network. -* ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". -* ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing". -* ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. -* ``logger``: See below. - -The trainer supports `TensorBoard `_ for logging. 
It can be used as: -:: - - from torch.utils.tensorboard import SummaryWriter - from tianshou.utils import TensorboardLogger - writer = SummaryWriter('log/dqn') - logger = TensorboardLogger(writer) - -Pass the logger into the trainer, and the training result will be recorded into the TensorBoard. - -The returned result is a dictionary as follows: -:: - - { - TrainingResult object with attributes like: - best_reward: 199.03 - duration: 4.01s - And other training statistics - -It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean returns over 100 consecutive episodes is 199.03. - - -Save/Load Policy ----------------- - -Since the policy inherits the class ``torch.nn.Module``, saving and loading the policy are exactly the same as a torch module: -:: - - torch.save(policy.state_dict(), 'dqn.pth') - policy.load_state_dict(torch.load('dqn.pth')) - - -Watch the Agent's Performance ------------------------------ - -:class:`~tianshou.data.Collector` supports rendering. Here is the example of watching the agent's performance in 35 FPS: -:: - - policy.eval() - policy.set_eps_inference(0.05) - collector = ts.data.Collector(algorithm, env, exploration_noise=True) - collector.collect(n_episode=1, render=1 / 35) - -If you'd like to manually see the action generated by a well-trained agent: -:: - - # assume obs is a single environment observation - action = policy(Batch(obs=np.array([obs]))).act[0] - - -.. _customized_trainer: - -Train a Policy with Customized Codes ------------------------------------- - -"I don't want to use your provided trainer. I want to customize it!" - -Tianshou supports user-defined training code. Here is the code snippet: -:: - - # pre-collect at least 5000 transitions with random action before training - train_collector.collect(n_step=5000, random=True) - - policy.set_eps_training(0.1) - for i in range(int(1e6)): # total step - collect_result = train_collector.collect(n_step=10) - - # once if the collected episodes' mean returns reach the threshold, - # or every 1000 steps, we test it on test_collector - if collect_result['rews'].mean() >= env.spec.reward_threshold or i % 1000 == 0: - policy.set_eps_inference(0.05) - result = test_collector.collect(n_episode=100) - if result['rews'].mean() >= env.spec.reward_threshold: - print(f'Finished training! Test mean returns: {result["rews"].mean()}') - break - else: - # back to training eps - policy.set_eps_training(0.1) - - # train policy with a sampled batch data from buffer - losses = algorithm.update(64, train_collector.buffer) - -For further usage, you can refer to the :doc:`/01_tutorials/07_cheatsheet`. - -.. rubric:: References - -.. bibliography:: /refs.bib - :style: unsrtalpha diff --git a/docs/01_tutorials/00_training_process.md b/docs/01_tutorials/00_training_process.md new file mode 100644 index 000000000..5fe77b17c --- /dev/null +++ b/docs/01_tutorials/00_training_process.md @@ -0,0 +1,100 @@ +# Understanding the Reinforcement Learning Loop + +The following diagram illustrates the key mechanisms underlying the learning process in model-free reinforcement learning algorithms. +It shows how the agent interacts with the environment, collects experiences, and periodically updates its policy based on those experiences. + + + + + +Accordingly, the key entities involved in the learning process are: + * The **environment**: This is the system the agent interacts with. + It provides the agent with observable states and rewards based on the actions taken by the agent. 
+ * The agent's **policy**: This is the strategy used by the agent to decide which action to take in a given state. + The policy can be deterministic or stochastic and is typically represented by a neural network in deep reinforcement learning. + * The **replay buffer**: This is a data structure used to store the agent's experiences, which consist of state transitions, + actions taken, and rewards received. + The agent learns from past experience by sampling mini-batches from the buffer during the policy update phase. + * The **learning algorithm**: This defines how the agent updates its policy based on the experiences stored in the replay buffer. + Different algorithms have different update mechanisms, which can significantly affect the learning performance. + In some cases, the algorithm may also involve additional components (specifically neural networks), such as target networks or value + functions. + +These entities have direct correspondences in Tianshou's codebase: + * The environment is represented by an instance of a class that inherits from `gymnasium.Env`, which is a standard interface for + reinforcement learning environments. + In practice, environments are typically vectorized to enable parallel interactions, increasing efficiency. + * The policy is encapsulated in the `Policy` class, which provides methods for action selection. + * The replay buffer is implemented in the `ReplayBuffer` class. + A `Collector` instance is used to manage the addition of new experiences to the replay buffer as the agent interacts with the + environment. + During the learning phase, the replay buffer may be sampled, providing an instance of `Batch` for the policy update. + * The abstraction for learning algorithms is given by the `Algorithm` class, which defines how to update the policy using data from the + replay buffer. + +## The Training Process + +The learning process itself is reified in Tianshou's `Trainer` class, which orchestrates the interaction between the agent and the +environment, manages the replay buffer, and coordinates the policy updates according to the specified learning algorithm. + +In general, the process can be described as executing a number of epochs as follows: + +* **Epoch**: + * Repeat until a sufficient number of steps is reached (for online learning, typically environment step count) + * **Training Step**: + * For online learning algorithms … + * **Collection Step**: collect state transitions in the environment by running the agent + * (Optionally) conduct a test step if collected data indicates promising behaviour + * **Update Step**: Apply gradient updates using the algorithm’s update logic. + The update is based on … + * data from the preceding collection step only (on-policy learning) + * data from the collection step and previous data (off-policy learning) + * data from a user-provided replay buffer (offline learning) + * **Test Step** + * Collect test episodes from dedicated test environments and evaluate agent performance + * (Optionally) stop training early if performance is sufficiently high + +```{admonition} Glossary +:class: note +The above introduces some of the key terms used throughout Tianshou. 
+``` + +Note that the above description encompasses several modes of model-free reinforcement learning, including: + * online learning (where the agent continuously interacts with the environment in order to collect new experiences) + * on-policy learning (where the policy is updated based on data collected using the current policy only) + * off-policy learning (where the policy is updated based on data collected using the current and previous policies) + * offline learning (where the replay buffer is pre-filled and not updated during training) + +In Tianshou, the `Trainer` and `Algorithm` classes are specialised to handle these different modes accordingly. diff --git a/docs/01_tutorials/01_apis.md b/docs/01_tutorials/01_apis.md new file mode 100644 index 000000000..c9f9d670a --- /dev/null +++ b/docs/01_tutorials/01_apis.md @@ -0,0 +1,376 @@ +# Tianshou's Dual API Architecture + +Tianshou provides two distinct APIs to serve different use cases and user preferences: + +1. **High-Level API**: A declarative, configuration-based interface designed for ease of use +2. **Procedural API**: A flexible, imperative interface providing maximum control + +Both APIs access the same underlying algorithm implementations, allowing you to choose the level +of abstraction that best fits your needs without sacrificing functionality. + +## Overview + +### High-Level API + +The high-level API is built around the **builder pattern** and **declarative semantics**. +Instead of writing procedural code that sequentially constructs and connects components, +you declare _what_ you want through configuration objects and let Tianshou handle _how_ to +build and execute the experiment. + +**Key characteristics:** +- Centered around `ExperimentBuilder` classes (e.g., `DQNExperimentBuilder`, `PPOExperimentBuilder`, etc.) +- Uses configuration dataclasses and factories for all relevant parameters +- Automatically handles component creation and "wiring" +- Provides sensible defaults that adapt to the nature of your environment +- Includes built-in persistence, logging, and experiment management +- Excellent IDE support with auto-completion + +### Procedural API + +The procedural API provides explicit control over every component in the RL pipeline. +You manually create environments, networks, policies, algorithms, collectors, and +trainers, then wire them together. 
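+To make this wiring concrete before the complete, commented example further below, here is a condensed sketch of the same DQN setup (all class names and parameter values are taken from that example; nothing here is additional API):
+
+```python
+import gymnasium as gym
+import tianshou as ts
+from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy
+from tianshou.algorithm.optim import AdamOptimizerFactory
+from tianshou.data import CollectStats
+from tianshou.trainer import OffPolicyTrainerParams
+from tianshou.utils.net.common import Net
+from tianshou.utils.space_info import SpaceInfo
+
+# environments -> network -> policy -> algorithm -> collectors -> trainer
+env = gym.make("CartPole-v1")
+train_envs = ts.env.DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(10)])
+test_envs = ts.env.DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(100)])
+
+space_info = SpaceInfo.from_env(env)
+net = Net(
+    state_shape=space_info.observation_info.obs_shape,
+    action_shape=space_info.action_info.action_shape,
+    hidden_sizes=[128, 128, 128],
+)
+policy = DiscreteQLearningPolicy(
+    model=net, action_space=env.action_space, eps_training=0.1, eps_inference=0.05
+)
+algorithm = ts.algorithm.DQN(
+    policy=policy,
+    optim=AdamOptimizerFactory(lr=1e-3),
+    gamma=0.9,
+    n_step_return_horizon=3,
+    target_update_freq=320,
+)
+train_collector = ts.data.Collector[CollectStats](
+    algorithm, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True
+)
+test_collector = ts.data.Collector[CollectStats](algorithm, test_envs, exploration_noise=True)
+
+# the trainer drives collection and update steps according to these parameters
+algorithm.run_training(
+    OffPolicyTrainerParams(
+        train_collector=train_collector,
+        test_collector=test_collector,
+        max_epochs=10,
+        epoch_num_steps=10000,
+        collection_step_num_env_steps=10,
+        test_step_num_episodes=100,
+        batch_size=64,
+        update_step_num_gradient_steps_per_sample=0.1,
+    )
+)
+```
+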
+ +**Key characteristics:** +- Direct instantiation of all components +- Explicit control over the training loop +- Lower-level access to internal mechanisms +- Minimal abstraction (closer to the implementation) +- Ideal for algorithm development and research + +## When to Use Which API + +### Use the High-Level API when: + +- **You're applying existing algorithms** to new problems +- **You want to get started quickly** with minimal boilerplate +- **You need experiment management** with persistence, logging, and reproducibility +- **You prefer declarative code** that focuses on configuration +- **You're building applications** rather than developing new algorithms +- **You want strong IDE support** with auto-completion and type hints + +### Use the Procedural API when: + +- **You're developing new algorithms** or modifying existing ones +- **You need fine-grained control** over the training process +- **You want to understand** the internal workings of Tianshou +- **You're implementing custom components** not supported by the high-level API +- **You prefer imperative programming** where each step is explicit +- **You need maximum flexibility** for experimental research + +## Comparison by Example + +Let's compare both APIs by implementing the same DQN learning task on the CartPole environment. + +### High-Level API Example + +```python +from tianshou.highlevel.config import OffPolicyTrainingConfig +from tianshou.highlevel.env import EnvFactoryRegistered, VectorEnvType +from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig +from tianshou.highlevel.params.algorithm_params import DQNParams +from tianshou.highlevel.trainer import EpochStopCallbackRewardThreshold + +# Build the experiment through configuration +experiment = ( + DQNExperimentBuilder( + # Environment configuration + EnvFactoryRegistered( + task="CartPole-v1", + venv_type=VectorEnvType.DUMMY, + train_seed=0, + test_seed=10, + ), + # Experiment settings + ExperimentConfig( + persistence_enabled=False, + watch=True, + watch_render=1 / 35, + watch_num_episodes=100, + ), + # Training configuration + OffPolicyTrainingConfig( + max_epochs=10, + epoch_num_steps=10000, + batch_size=64, + num_train_envs=10, + num_test_envs=100, + buffer_size=20000, + collection_step_num_env_steps=10, + update_step_num_gradient_steps_per_sample=1 / 10, + ), + ) + # Algorithm-specific parameters + .with_dqn_params( + DQNParams( + lr=1e-3, + gamma=0.9, + n_step_return_horizon=3, + target_update_freq=320, + eps_training=0.3, + eps_inference=0.0, + ), + ) + # Network architecture + .with_model_factory_default(hidden_sizes=(64, 64)) + # Stop condition + .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) + .build() +) + +# Run the experiment +experiment.run() +``` + +**What's happening here:** +1. We create an `ExperimentBuilder` with three main configuration objects +2. We chain builder methods to specify algorithm parameters, model architecture, and callbacks +3. We call `.build()` to construct the experiment +4. 
We call `.run()` to execute the entire training pipeline + +The high-level API handles: +- Creating and configuring environments +- Building the neural network +- Instantiating the policy and algorithm +- Setting up collectors and replay buffer +- Managing the training loop +- Watching the trained agent + +### Procedural API Example + +```python +import gymnasium as gym +import tianshou as ts +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import CollectStats +from tianshou.trainer import OffPolicyTrainerParams +from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo +from torch.utils.tensorboard import SummaryWriter + +# Define hyperparameters +task = "CartPole-v1" +lr, epoch, batch_size = 1e-3, 10, 64 +num_train_envs, num_test_envs = 10, 100 +gamma, n_step, target_freq = 0.9, 3, 320 +buffer_size = 20000 +eps_train, eps_test = 0.1, 0.05 +epoch_num_steps, collection_step_num_env_steps = 10000, 10 + +# Set up logging +logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn")) + +# Create environments +train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)]) +test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) + +# Build the network +env = gym.make(task, render_mode="human") +space_info = SpaceInfo.from_env(env) +state_shape = space_info.observation_info.obs_shape +action_shape = space_info.action_info.action_shape +net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) + +# Create policy and algorithm +policy = DiscreteQLearningPolicy( + model=net, + action_space=env.action_space, + eps_training=eps_train, + eps_inference=eps_test, +) +algorithm = ts.algorithm.DQN( + policy=policy, + optim=AdamOptimizerFactory(lr=lr), + gamma=gamma, + n_step_return_horizon=n_step, + target_update_freq=target_freq, +) + +# Set up collectors +train_collector = ts.data.Collector[CollectStats]( + algorithm, + train_envs, + ts.data.VectorReplayBuffer(buffer_size, num_train_envs), + exploration_noise=True, +) +test_collector = ts.data.Collector[CollectStats]( + algorithm, + test_envs, + exploration_noise=True, +) + +# Define stop condition +def stop_fn(mean_rewards: float) -> bool: + if env.spec and env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + return False + +# Train the algorithm +result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, + collection_step_num_env_steps=collection_step_num_env_steps, + test_step_num_episodes=num_test_envs, + batch_size=batch_size, + update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps, + stop_fn=stop_fn, + logger=logger, + test_in_train=True, + ) +) +print(f"Finished training in {result.timing.total_time} seconds") + +# Watch the trained agent +collector = ts.data.Collector[CollectStats](algorithm, env, exploration_noise=True) +collector.collect(n_episode=100, render=1 / 35) +``` + +**What's happening here:** +1. We explicitly define all hyperparameters as variables +2. We manually create the logger +3. We construct training and test environments +4. We build the neural network by extracting space information from the environment +5. We create the policy and algorithm objects +6. We set up collectors with a replay buffer +7. We define callback functions +8. 
We call `algorithm.run_training()` with explicit parameters +9. We manually set up and run the evaluation collector + +The procedural API requires: +- Explicit creation of every component +- Manual extraction of environment properties +- Direct specification of all connections +- Custom callback function definitions + +## Key Concepts in the High-Level API + +### ExperimentBuilder + +The `ExperimentBuilder` is the core abstraction. +Each algorithm has its own builder (e.g., `DQNExperimentBuilder`, `PPOExperimentBuilder`, `SACExperimentBuilder`). + +**Some methods you will find in experiment builders:** +- `.with_<algorithm>_params()` - Set algorithm-specific parameters +- `.with_model_factory()`, `.with_model_factory_default()` - Configure network architecture +- `.with_critic_factory()` - Configure critic network (for actor-critic methods) +- `.with_epoch_train_callback()` - Add function to be called at the beginning of the training step in each epoch +- `.with_epoch_test_callback()` - Add function to be called at the beginning of the test step in each epoch +- `.with_epoch_stop_callback()` - Define stopping conditions +- `.with_algorithm_wrapper_factory()` - Add algorithm wrappers (e.g., ICM) + +### Configuration Objects + +Three main configuration objects are required when constructing an experiment builder: + +1. **Environment Configuration** (`EnvFactory` subclasses) + - Defines how to create and configure environments + - Existing factories: + - `EnvFactoryRegistered` - For the creation of environments registered in Gymnasium + - `AtariEnvFactory` - For Atari environments with preprocessing + - Custom factories for your own environments can be created by subclassing `EnvFactory` + +2. **Experiment Configuration** (`ExperimentConfig`): + General settings for the experiment, particularly related to + - logging + - randomization + - persistence + - watching the trained agent's performance after training + +3. **Training Configuration** (`OffPolicyTrainingConfig`, `OnPolicyTrainingConfig`): + Defines all parameters related to the training process + +### Parameter Classes + +Algorithm parameters are defined in dataclasses specific to each algorithm (e.g., `DQNParams`, `PPOParams`). +The parameters are extensively documented. + +```{note} +Make sure to use a modern IDE to take advantage of auto-completion and inline documentation! +``` + +### Factories + +The high-level API uses factories extensively: +- **Model Factories**: Create neural networks (e.g., `IntermediateModuleFactoryAtariDQN()`) +- **Environment Factories**: Create and configure environments +- **Optimizer Factories**: Create optimizers with specific configurations + +### Extensibility + +The high-level API is designed to be extensible. +You can create custom factories (e.g. for your own models or your own environments) by subclassing the appropriate base classes +and then use them in the experiment builder. + +If we have created a torch module `CustomNetwork`, which we want to use within our policy, +we simply need to define a factory for it in order to apply it in the high-level API: + +```python +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.core import TDevice +from tianshou.highlevel.module.intermediate import IntermediateModuleFactory, IntermediateModule + +class CustomNetFactory(IntermediateModuleFactory): + def __init__(self, hidden_sizes: tuple[int, ...] 
= (128, 128)): + self.hidden_sizes = hidden_sizes + + def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: + obs_shape = envs.get_observation_shape() + action_shape = envs.get_action_shape() + + # Your custom network creation logic + net = CustomNetwork( + obs_shape=obs_shape, + action_shape=action_shape, + hidden_sizes=self.hidden_sizes, + ).to(device) + + return IntermediateModule(net, net.output_dim) + +experiment = ( + DQNExperimentBuilder(...) + .with_model_factory(CustomNetFactory(hidden_sizes=(256, 256))) + .build() +) +``` + +## Key Concepts in the Procedural API + +### Core Components + +You manually create and connect: + +1. **Environments**: Using `gym.make()` and vectorization (`DummyVectorEnv`, `SubprocVectorEnv`) +2. **Networks**: Using `Net` or custom PyTorch modules +3. **Policies**: Using algorithm-specific policy classes (e.g., `DiscreteQLearningPolicy`) +4. **Algorithms**: Using algorithm classes (e.g., `DQN`, `PPO`, `SAC`) +5. **Collectors**: Using `Collector` to gather experience +6. **Buffers**: Using `VectorReplayBuffer` or `ReplayBuffer` +7. **Trainers**: Using the respective trainer class and corresponding parameter class (e.g., `OffPolicyTrainer` and `OffPolicyTrainerParams`) + +### Training Loop + +The training is executed via `algorithm.run_training()`, which takes a trainer parameter object. +You can alternatively implement custom training loops (or even your own trainer class) for maximum flexibility. + + +## Choosing Your Path + +**Use the high-level API** if ... +- you are new to Tianshou, +- you are focused on applying RL to problems, +- you prefer declarative code. + +**Use the procedural API** if ... +- you are developing new algorithms, +- you need maximum flexibility, +- you are comfortable with RL internals, +- you prefer imperative code. + +## Additional Resources + +- **High-Level API Examples**: See `examples/` directory (scripts ending in `_hl.py`) +- **Procedural API Examples**: See `examples/` directory (scripts without suffix) diff --git a/docs/01_tutorials/01_concepts.rst b/docs/01_tutorials/02_internals.rst similarity index 99% rename from docs/01_tutorials/01_concepts.rst rename to docs/01_tutorials/02_internals.rst index 28b0dc276..b9a3b7ac2 100644 --- a/docs/01_tutorials/01_concepts.rst +++ b/docs/01_tutorials/02_internals.rst @@ -1,5 +1,5 @@ -Basic concepts in Tianshou -========================== +Understanding Tianshou Internals +================================ Tianshou splits a Reinforcement Learning agent training procedure into these parts: algorithm, trainer, collector, policy, a data buffer and batches from the buffer. The algorithm encapsulates the specific RL learning method (e.g., DQN, PPO), which contains a policy and defines how to update it. diff --git a/docs/01_tutorials/07_cheatsheet.rst b/docs/01_tutorials/07_cheatsheet.rst index fc747d66f..c9f3e425d 100644 --- a/docs/01_tutorials/07_cheatsheet.rst +++ b/docs/01_tutorials/07_cheatsheet.rst @@ -12,22 +12,6 @@ you can also use the batch processor :ref:`preprocess_fn` or vectorized environment wrapper :class:`~tianshou.env.VectorEnvWrapper`. -.. _network_api: - -Build Policy Network --------------------- - -See :ref:`build_the_network`. - - -.. _new_policy: - -Build New Policy ----------------- - -See :class:`~tianshou.algorithm.BasePolicy`. - - .. _eval_policy: Manually Evaluate Policy @@ -40,14 +24,6 @@ If you'd like to manually see the action generated by a well-trained agent: action = policy(Batch(obs=np.array([obs]))).act[0] -.. 
_customize_training: - -Customize Training Process --------------------------- - -See :ref:`customized_trainer`. - - .. _resume_training: Resume Training Process diff --git a/docs/04_contributing/05_contributors.rst b/docs/04_contributing/05_contributors.rst index 715c24ab3..e8028c922 100644 --- a/docs/04_contributing/05_contributors.rst +++ b/docs/04_contributing/05_contributors.rst @@ -5,24 +5,22 @@ We always welcome contributions to help make Tianshou better! Tianshou was originally created by the `THU-ML Group `_ at Tsinghua University. Today, it is backed by the `appliedAI Institute for Europe `_, -which is committed to making Tianshou the go-to resource for reinforcement learning research and development, -and guaranteeing its long-term maintenance and support. +a non-profit organization committed to making Tianshou the go-to resource for reinforcement learning research and development, +guaranteeing its long-term maintenance and support. The original creator Jiayi Weng (`Trinkle23897 `_) continues -to be a key contributor to the project. +to be involved in Tianshou development. -The current tianshou maintainers from the appliedAI Institute for Europe are: +The current core developers, who are behind the v1.0 and v2.0 releases of Tianshou, are: -* Michael Panchenko (`MischaPanch `_) * Dominik Jain (`opcode81 `_) +* Michael Panchenko (`MischaPanch `_) - -An incomplete list of the early contributors is: +An incomplete list of early contributors is: * Alexis Duburcq (`duburcqa `_) * Kaichao You (`youkaichao `_) * Huayu Chen (`ChenDRAG `_) * Yi Su (`nuance1979 `_) - You can find more information about contributors `here `_. diff --git a/docs/_config.yml b/docs/_config.yml index fce609211..0f110fb33 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -37,7 +37,7 @@ parse: - colon_fence # - deflist - dollarmath - # - html_admonition + - html_admonition # - html_image - linkify # - replacements @@ -151,4 +151,4 @@ sphinx: R: "{\\mathbb{R}}" abs: ["{\\left| #1 \\right|}", 1] simpl: ["{\\Delta^{#1} }", 1] - amax: "{\\text{argmax}}" \ No newline at end of file + amax: "{\\text{argmax}}" diff --git a/docs/_static/images/agent-env-step1.png b/docs/_static/images/agent-env-step1.png new file mode 100644 index 000000000..27dcc4a97 Binary files /dev/null and b/docs/_static/images/agent-env-step1.png differ diff --git a/docs/_static/images/agent-env-step2.png b/docs/_static/images/agent-env-step2.png new file mode 100644 index 000000000..d14ce0b84 Binary files /dev/null and b/docs/_static/images/agent-env-step2.png differ diff --git a/docs/_static/images/agent-env-step3.png b/docs/_static/images/agent-env-step3.png new file mode 100644 index 000000000..0869ea972 Binary files /dev/null and b/docs/_static/images/agent-env-step3.png differ diff --git a/docs/_static/images/agent-env-step4.png b/docs/_static/images/agent-env-step4.png new file mode 100644 index 000000000..7b70b6e1f Binary files /dev/null and b/docs/_static/images/agent-env-step4.png differ diff --git a/docs/index.rst b/docs/index.rst index 4bfbbdd17..f304f3cf1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,81 +1,56 @@ -.. Tianshou documentation master file, created by - sphinx-quickstart on Sat Mar 28 15:58:19 2020. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - Welcome to Tianshou! ==================== **Tianshou** (`天授 `_) is a reinforcement learning platform based on pure PyTorch. 
Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include: -* :class:`~tianshou.algorithm.DQNPolicy` `Deep Q-Network `_ -* :class:`~tianshou.algorithm.DQNPolicy` `Double DQN `_ -* :class:`~tianshou.algorithm.DQNPolicy` `Dueling DQN `_ -* :class:`~tianshou.algorithm.BranchingDQNPolicy` `Branching DQN `_ -* :class:`~tianshou.algorithm.C51Policy` `Categorical DQN `_ -* :class:`~tianshou.algorithm.RainbowPolicy` `Rainbow DQN `_ -* :class:`~tianshou.algorithm.QRDQNPolicy` `Quantile Regression DQN `_ -* :class:`~tianshou.algorithm.IQNPolicy` `Implicit Quantile Network `_ -* :class:`~tianshou.algorithm.FQFPolicy` `Fully-parameterized Quantile Function `_ -* :class:`~tianshou.algorithm.PGPolicy` `Policy Gradient `_ -* :class:`~tianshou.algorithm.NPGPolicy` `Natural Policy Gradient `_ -* :class:`~tianshou.algorithm.A2CPolicy` `Advantage Actor-Critic `_ -* :class:`~tianshou.algorithm.TRPOPolicy` `Trust Region Policy Optimization `_ -* :class:`~tianshou.algorithm.PPOPolicy` `Proximal Policy Optimization `_ -* :class:`~tianshou.algorithm.DDPGPolicy` `Deep Deterministic Policy Gradient `_ -* :class:`~tianshou.algorithm.TD3Policy` `Twin Delayed DDPG `_ -* :class:`~tianshou.algorithm.SACPolicy` `Soft Actor-Critic `_ -* :class:`~tianshou.algorithm.REDQPolicy` `Randomized Ensembled Double Q-Learning `_ -* :class:`~tianshou.algorithm.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ -* :class:`~tianshou.algorithm.ImitationPolicy` Imitation Learning -* :class:`~tianshou.algorithm.BCQPolicy` `Batch-Constrained deep Q-Learning `_ -* :class:`~tianshou.algorithm.CQLPolicy` `Conservative Q-Learning `_ -* :class:`~tianshou.algorithm.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning `_ -* :class:`~tianshou.algorithm.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning `_ -* :class:`~tianshou.algorithm.DiscreteCQLPolicy` `Discrete Conservative Q-Learning `_ -* :class:`~tianshou.algorithm.DiscreteCRRPolicy` `Critic Regularized Regression `_ -* :class:`~tianshou.algorithm.GAILPolicy` `Generative Adversarial Imitation Learning `_ -* :class:`~tianshou.algorithm.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ -* :class:`~tianshou.algorithm.ICMPolicy` `Intrinsic Curiosity Module `_ -* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ -* :meth:`~tianshou.algorithm.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ -* :class:`~tianshou.data.HERReplayBuffer` `Hindsight Experience Replay `_ - -Here is Tianshou's other features: - -* Elegant framework, using only ~3000 lines of code -* State-of-the-art `MuJoCo benchmark `_ -* Support vectorized environment (synchronous or asynchronous) for all algorithms: :ref:`parallel_sampling` -* Support super-fast vectorized environment `EnvPool `_ for all algorithms: :ref:`envpool_integration` -* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training` -* Support any type of environment state/action (e.g. 
a dict, a self-defined class, ...): :ref:`self_defined_env` -* Support :ref:`customize_training` -* Support n-step returns estimation :meth:`~tianshou.algorithm.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation -* Support :doc:`/01_tutorials/04_tictactoe` -* Support both `TensorBoard `_ and `W&B `_ log tools -* Support multi-GPU training :ref:`multi_gpu` -* Comprehensive `unit tests `_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking +* :class:`~tianshou.algorithm.modelfree.dqn.DQN` `Deep Q-Network `_ +* :class:`~tianshou.algorithm.modelfree.dqn.DQN` `Double DQN `_ +* :class:`~tianshou.algorithm.modelfree.dqn.DQN` `Dueling DQN `_ +* :class:`~tianshou.algorithm.modelfree.bdqn.BDQN` `Branching DQN `_ +* :class:`~tianshou.algorithm.modelfree.c51.C51` `Categorical DQN `_ +* :class:`~tianshou.algorithm.modelfree.rainbow.RainbowDQN` `Rainbow DQN `_ +* :class:`~tianshou.algorithm.modelfree.qrdqn.QRDQN` `Quantile Regression DQN `_ +* :class:`~tianshou.algorithm.modelfree.iqn.IQN` `Implicit Quantile Network `_ +* :class:`~tianshou.algorithm.modelfree.fqf.FQF` `Fully-parameterized Quantile Function `_ +* :class:`~tianshou.algorithm.modelfree.reinforce.Reinforce` `Reinforce/Vanilla Policy Gradients `_ +* :class:`~tianshou.algorithm.modelfree.npg.NPG` `Natural Policy Gradient `_ +* :class:`~tianshou.algorithm.modelfree.a2c.A2C` `Advantage Actor-Critic `_ +* :class:`~tianshou.algorithm.modelfree.trpo.TRPO` `Trust Region Policy Optimization `_ +* :class:`~tianshou.algorithm.modelfree.ppo.PPO` `Proximal Policy Optimization `_ +* :class:`~tianshou.algorithm.modelfree.ddpg.DDPG` `Deep Deterministic Policy Gradient `_ +* :class:`~tianshou.algorithm.modelfree.td3.TD3` `Twin Delayed DDPG `_ +* :class:`~tianshou.algorithm.modelfree.sac.SAC` `Soft Actor-Critic `_ +* :class:`~tianshou.algorithm.modelfree.redq.REDQ` `Randomized Ensembled Double Q-Learning `_ +* :class:`~tianshou.algorithm.modelfree.discrete_sac.DiscreteSAC` `Discrete Soft Actor-Critic `_ +* :class:`~tianshou.algorithm.imitation.imitation_base.ImitationPolicy` Imitation Learning +* :class:`~tianshou.algorithm.imitation.bcq.BCQ` `Batch-Constrained deep Q-Learning `_ +* :class:`~tianshou.algorithm.imitation.cql.CQL` `Conservative Q-Learning `_ +* :class:`~tianshou.algorithm.imitation.td3_bc.TD3BC` `Twin Delayed DDPG with Behavior Cloning `_ +* :class:`~tianshou.algorithm.imitation.discrete_cql.DiscreteCQL` `Discrete Conservative Q-Learning `_ +* :class:`~tianshou.algorithm.imitation.discrete_bcq.DiscreteBCQ` `Discrete Batch-Constrained deep Q-Learning `_ +* :class:`~tianshou.algorithm.imitation.discrete_crr.DiscreteCRR` `Critic Regularized Regression `_ +* :class:`~tianshou.algorithm.imitation.gail.GAIL` `Generative Adversarial Imitation Learning `_ +* :class:`~tianshou.algorithm.modelbased.psrl.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ +* :class:`~tianshou.algorithm.modelbased.icm.ICMOffPolicyWrapper`, :class:`~tianshou.algorithm.modelbased.icm.ICMOnPolicyWrapper` `Intrinsic Curiosity Module `_ +* :class:`~tianshou.data.buffer.prio.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ +* :meth:`~tianshou.algorithm.algorithm_base.Algorithm.compute_episodic_return` `Generalized Advantage Estimator `_ +* :class:`~tianshou.data.buffer.her.HERReplayBuffer` `Hindsight 
Experience Replay `_ + Installation ------------ -Tianshou is currently hosted on `PyPI `_ and `conda-forge `_. New releases -(and the current state of the master branch) will require Python >= 3.11. +Tianshou is available through `PyPI `_. +New releases require Python >= 3.11. -You can simply install Tianshou from PyPI with the following command: +Install Tianshou with the following command: .. code-block:: bash $ pip install tianshou -If you use Anaconda or Miniconda, you can install Tianshou from conda-forge through the following command: - -.. code-block:: bash - - $ conda install tianshou -c conda-forge - -You can also install with the newest version through GitHub: +Alternatively, install the current version on GitHub: .. code-block:: bash @@ -89,7 +64,6 @@ After installation, open your python console and type If no error occurs, you have successfully installed Tianshou. -Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ `_. Indices and tables ------------------ diff --git a/docs/nbstripout.py b/docs/nbstripout.py index 95d4193ad..1302a35fb 100644 --- a/docs/nbstripout.py +++ b/docs/nbstripout.py @@ -1,4 +1,5 @@ """Implements a platform-independent way of calling nbstripout (used in pyproject.toml).""" + import glob import os from pathlib import Path diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 22d35c4b7..f25c768de 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -247,7 +247,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") - log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail' + log_file = f"seed_{args.seed}_{t0}-{args.task.replace('-', '_')}_gail" log_path = os.path.join(args.logdir, args.task, "gail", log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) diff --git a/examples/offline/convert_rl_unplugged_atari.py b/examples/offline/convert_rl_unplugged_atari.py index d999ad330..04e5e706a 100755 --- a/examples/offline/convert_rl_unplugged_atari.py +++ b/examples/offline/convert_rl_unplugged_atari.py @@ -27,6 +27,7 @@ * episode_return: Total episode return computed using per-step [-1, 1] clipping. """ + import os from argparse import ArgumentParser, Namespace @@ -179,13 +180,16 @@ def download(url: str, fname: str, chunk_size: int | None = 1024) -> None: if os.path.exists(fname): print(f"Found cached file at {fname}.") return - with open(fname, "wb") as ofile, tqdm( - desc=fname, - total=total, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as bar: + with ( + open(fname, "wb") as ofile, + tqdm( + desc=fname, + total=total, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar, + ): for data in resp.iter_content(chunk_size=chunk_size): size = ofile.write(data) bar.update(size) diff --git a/poetry.lock b/poetry.lock index c009dd5ed..09adb4a4b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -27,22 +27,6 @@ files = [ [package.dependencies] pygments = ">=1.5" -[[package]] -name = "aiosignal" -version = "1.3.1" -description = "aiosignal: a list of registered asynchronous callbacks" -optional = false -python-versions = ">=3.7" -groups = ["dev"] -markers = "sys_platform != \"win32\"" -files = [ - {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, - {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, -] - -[package.dependencies] -frozenlist = ">=1.1.0" - [[package]] name = "alabaster" version = "0.7.13" @@ -92,6 +76,18 @@ numpy = "*" [package.extras] test = ["gym (>=0.23,<1.0)", "pytest (>=7.0)"] +[[package]] +name = "annotated-types" +version = "0.7.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] + [[package]] name = "anyio" version = "4.0.0" @@ -442,49 +438,6 @@ soupsieve = ">1.2" html5lib = ["html5lib"] lxml = ["lxml"] -[[package]] -name = "black" -version = "23.11.0" -description = "The uncompromising code formatter." -optional = false -python-versions = ">=3.8" -groups = ["dev"] -files = [ - {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"}, - {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"}, - {file = "black-23.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d136ef5b418c81660ad847efe0e55c58c8208b77a57a28a503a5f345ccf01394"}, - {file = "black-23.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:6c1cac07e64433f646a9a838cdc00c9768b3c362805afc3fce341af0e6a9ae9f"}, - {file = "black-23.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cf57719e581cfd48c4efe28543fea3d139c6b6f1238b3f0102a9c73992cbb479"}, - {file = "black-23.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:698c1e0d5c43354ec5d6f4d914d0d553a9ada56c85415700b81dc90125aac244"}, - {file = "black-23.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:760415ccc20f9e8747084169110ef75d545f3b0932ee21368f63ac0fee86b221"}, - {file = "black-23.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:58e5f4d08a205b11800332920e285bd25e1a75c54953e05502052738fe16b3b5"}, - {file = "black-23.11.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:45aa1d4675964946e53ab81aeec7a37613c1cb71647b5394779e6efb79d6d187"}, - {file = "black-23.11.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c44b7211a3a0570cc097e81135faa5f261264f4dfaa22bd5ee2875a4e773bd6"}, - {file = "black-23.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a9acad1451632021ee0d146c8765782a0c3846e0e0ea46659d7c4f89d9b212b"}, - {file = "black-23.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:fc7f6a44d52747e65a02558e1d807c82df1d66ffa80a601862040a43ec2e3142"}, - {file = "black-23.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7f622b6822f02bfaf2a5cd31fdb7cd86fcf33dab6ced5185c35f5db98260b055"}, - {file = "black-23.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:250d7e60f323fcfc8ea6c800d5eba12f7967400eb6c2d21ae85ad31c204fb1f4"}, - {file = "black-23.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5133f5507007ba08d8b7b263c7aa0f931af5ba88a29beacc4b2dc23fcefe9c06"}, - {file = "black-23.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:421f3e44aa67138ab1b9bfbc22ee3780b22fa5b291e4db8ab7eee95200726b07"}, - {file = "black-23.11.0-py3-none-any.whl", hash = "sha256:54caaa703227c6e0c87b76326d0862184729a69b73d3b7305b6288e1d830067e"}, - {file = "black-23.11.0.tar.gz", hash = "sha256:4c68855825ff432d197229846f971bc4d6666ce90492e5b02013bcaca4d9ab05"}, -] - -[package.dependencies] -click = ">=8.0.0" -ipython = {version = ">=7.8.0", optional = true, markers = "extra == \"jupyter\""} -mypy-extensions = ">=0.4.3" -packaging = ">=22.0" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tokenize-rt = {version = ">=3.2.0", optional = true, markers = "extra == \"jupyter\""} - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "bleach" version = "6.1.0" @@ -1205,21 +1158,6 @@ files = [ {file = "dm_tree-0.1.8-cp39-cp39-win_amd64.whl", hash = "sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368"}, ] -[[package]] -name = "docker-pycreds" -version = "0.4.0" -description = "Python bindings for the docker credentials store API" -optional = false -python-versions = "*" -groups = ["dev"] -files = [ - {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"}, - {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"}, -] - -[package.dependencies] -six = ">=1.4.0" - [[package]] name = "docstring-parser" version = "0.15" @@ -1428,78 +1366,6 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] -[[package]] -name = "frozenlist" -version = "1.4.0" -description = "A list-like structure which implements collections.abc.MutableSequence" -optional = false -python-versions = ">=3.8" -groups = ["dev"] -markers = "sys_platform != \"win32\"" -files = [ - {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"}, - {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"}, - {file = "frozenlist-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9ac08e601308e41eb533f232dbf6b7e4cea762f9f84f6357136eed926c15d12c"}, - {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d081f13b095d74b67d550de04df1c756831f3b83dc9881c38985834387487f1b"}, - {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:71932b597f9895f011f47f17d6428252fc728ba2ae6024e13c3398a087c2cdea"}, - {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:981b9ab5a0a3178ff413bca62526bb784249421c24ad7381e39d67981be2c326"}, - {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e41f3de4df3e80de75845d3e743b3f1c4c8613c3997a912dbf0229fc61a8b963"}, - {file = 
"frozenlist-1.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6918d49b1f90821e93069682c06ffde41829c346c66b721e65a5c62b4bab0300"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e5c8764c7829343d919cc2dfc587a8db01c4f70a4ebbc49abde5d4b158b007b"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8d0edd6b1c7fb94922bf569c9b092ee187a83f03fb1a63076e7774b60f9481a8"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e29cda763f752553fa14c68fb2195150bfab22b352572cb36c43c47bedba70eb"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:0c7c1b47859ee2cac3846fde1c1dc0f15da6cec5a0e5c72d101e0f83dcb67ff9"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:901289d524fdd571be1c7be054f48b1f88ce8dddcbdf1ec698b27d4b8b9e5d62"}, - {file = "frozenlist-1.4.0-cp310-cp310-win32.whl", hash = "sha256:1a0848b52815006ea6596c395f87449f693dc419061cc21e970f139d466dc0a0"}, - {file = "frozenlist-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:b206646d176a007466358aa21d85cd8600a415c67c9bd15403336c331a10d956"}, - {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de343e75f40e972bae1ef6090267f8260c1446a1695e77096db6cfa25e759a95"}, - {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad2a9eb6d9839ae241701d0918f54c51365a51407fd80f6b8289e2dfca977cc3"}, - {file = "frozenlist-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bd7bd3b3830247580de99c99ea2a01416dfc3c34471ca1298bccabf86d0ff4dc"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdf1847068c362f16b353163391210269e4f0569a3c166bc6a9f74ccbfc7e839"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38461d02d66de17455072c9ba981d35f1d2a73024bee7790ac2f9e361ef1cd0c"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5a32087d720c608f42caed0ef36d2b3ea61a9d09ee59a5142d6070da9041b8f"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd65632acaf0d47608190a71bfe46b209719bf2beb59507db08ccdbe712f969b"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261b9f5d17cac914531331ff1b1d452125bf5daa05faf73b71d935485b0c510b"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b89ac9768b82205936771f8d2eb3ce88503b1556324c9f903e7156669f521472"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:008eb8b31b3ea6896da16c38c1b136cb9fec9e249e77f6211d479db79a4eaf01"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e74b0506fa5aa5598ac6a975a12aa8928cbb58e1f5ac8360792ef15de1aa848f"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:490132667476f6781b4c9458298b0c1cddf237488abd228b0b3650e5ecba7467"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:76d4711f6f6d08551a7e9ef28c722f4a50dd0fc204c56b4bcd95c6cc05ce6fbb"}, - {file = "frozenlist-1.4.0-cp311-cp311-win32.whl", hash = "sha256:a02eb8ab2b8f200179b5f62b59757685ae9987996ae549ccf30f983f40602431"}, - {file = "frozenlist-1.4.0-cp311-cp311-win_amd64.whl", hash = 
"sha256:515e1abc578dd3b275d6a5114030b1330ba044ffba03f94091842852f806f1c1"}, - {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0ed05f5079c708fe74bf9027e95125334b6978bf07fd5ab923e9e55e5fbb9d3"}, - {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ca265542ca427bf97aed183c1676e2a9c66942e822b14dc6e5f42e038f92a503"}, - {file = "frozenlist-1.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:491e014f5c43656da08958808588cc6c016847b4360e327a62cb308c791bd2d9"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ae5cd0f333f94f2e03aaf140bb762c64783935cc764ff9c82dff626089bebf"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e78fb68cf9c1a6aa4a9a12e960a5c9dfbdb89b3695197aa7064705662515de2"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5655a942f5f5d2c9ed93d72148226d75369b4f6952680211972a33e59b1dfdc"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11b0746f5d946fecf750428a95f3e9ebe792c1ee3b1e96eeba145dc631a9672"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e66d2a64d44d50d2543405fb183a21f76b3b5fd16f130f5c99187c3fb4e64919"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:88f7bc0fcca81f985f78dd0fa68d2c75abf8272b1f5c323ea4a01a4d7a614efc"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5833593c25ac59ede40ed4de6d67eb42928cca97f26feea219f21d0ed0959b79"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:fec520865f42e5c7f050c2a79038897b1c7d1595e907a9e08e3353293ffc948e"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:b826d97e4276750beca7c8f0f1a4938892697a6bcd8ec8217b3312dad6982781"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ceb6ec0a10c65540421e20ebd29083c50e6d1143278746a4ef6bcf6153171eb8"}, - {file = "frozenlist-1.4.0-cp38-cp38-win32.whl", hash = "sha256:2b8bcf994563466db019fab287ff390fffbfdb4f905fc77bc1c1d604b1c689cc"}, - {file = "frozenlist-1.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:a6c8097e01886188e5be3e6b14e94ab365f384736aa1fca6a0b9e35bd4a30bc7"}, - {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6c38721585f285203e4b4132a352eb3daa19121a035f3182e08e437cface44bf"}, - {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0c6da9aee33ff0b1a451e867da0c1f47408112b3391dd43133838339e410963"}, - {file = "frozenlist-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93ea75c050c5bb3d98016b4ba2497851eadf0ac154d88a67d7a6816206f6fa7f"}, - {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f61e2dc5ad442c52b4887f1fdc112f97caeff4d9e6ebe78879364ac59f1663e1"}, - {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa384489fefeb62321b238e64c07ef48398fe80f9e1e6afeff22e140e0850eef"}, - {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10ff5faaa22786315ef57097a279b833ecab1a0bfb07d604c9cbb1c4cdc2ed87"}, - {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:007df07a6e3eb3e33e9a1fe6a9db7af152bbd8a185f9aaa6ece10a3529e3e1c6"}, - {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f4f399d28478d1f604c2ff9119907af9726aed73680e5ed1ca634d377abb087"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c5374b80521d3d3f2ec5572e05adc94601985cc526fb276d0c8574a6d749f1b3"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ce31ae3e19f3c902de379cf1323d90c649425b86de7bbdf82871b8a2a0615f3d"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7211ef110a9194b6042449431e08c4d80c0481e5891e58d429df5899690511c2"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:556de4430ce324c836789fa4560ca62d1591d2538b8ceb0b4f68fb7b2384a27a"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7645a8e814a3ee34a89c4a372011dcd817964ce8cb273c8ed6119d706e9613e3"}, - {file = "frozenlist-1.4.0-cp39-cp39-win32.whl", hash = "sha256:19488c57c12d4e8095a922f328df3f179c820c212940a498623ed39160bc3c2f"}, - {file = "frozenlist-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:6221d84d463fb110bdd7619b69cb43878a11d51cbb9394ae3105d082d5199167"}, - {file = "frozenlist-1.4.0.tar.gz", hash = "sha256:09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"}, -] - [[package]] name = "fsspec" version = "2023.10.0" @@ -4079,29 +3945,6 @@ files = [ {file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"}, ] -[[package]] -name = "pathspec" -version = "0.11.2" -description = "Utility library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.7" -groups = ["dev"] -files = [ - {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, - {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, -] - -[[package]] -name = "pathtools" -version = "0.1.2" -description = "File system general utilities" -optional = false -python-versions = "*" -groups = ["dev"] -files = [ - {file = "pathtools-0.1.2.tar.gz", hash = "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"}, -] - [[package]] name = "patsy" version = "0.5.6" @@ -4353,23 +4196,6 @@ files = [ [package.extras] twisted = ["twisted"] -[[package]] -name = "promise" -version = "2.3" -description = "Promises/A+ implementation for Python" -optional = false -python-versions = "*" -groups = ["dev"] -files = [ - {file = "promise-2.3.tar.gz", hash = "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0"}, -] - -[package.dependencies] -six = "*" - -[package.extras] -test = ["coveralls", "futures", "mock", "pytest (>=2.7.3)", "pytest-benchmark", "pytest-cov"] - [[package]] name = "prompt-toolkit" version = "3.0.41" @@ -4593,6 +4419,132 @@ files = [ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +[[package]] +name = "pydantic" +version = "2.9.2" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"}, + {file = "pydantic-2.9.2.tar.gz", hash = 
"sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"}, +] + +[package.dependencies] +annotated-types = ">=0.6.0" +pydantic-core = "2.23.4" +typing-extensions = [ + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, +] + +[package.extras] +email = ["email-validator (>=2.0.0)"] +timezone = ["tzdata ; python_version >= \"3.9\" and sys_platform == \"win32\""] + +[[package]] +name = "pydantic-core" +version = "2.23.4" +description = "Core functionality for Pydantic validation and serialization" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"}, + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071"}, + {file = "pydantic_core-2.23.4-cp310-none-win32.whl", hash = "sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119"}, + {file = "pydantic_core-2.23.4-cp310-none-win_amd64.whl", hash = "sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b"}, + {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0"}, + {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64"}, + {file = "pydantic_core-2.23.4-cp311-none-win32.whl", hash = "sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f"}, + {file = "pydantic_core-2.23.4-cp311-none-win_amd64.whl", hash = "sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24"}, + {file = "pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84"}, + {file = "pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f"}, + {file = "pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769"}, + {file = "pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5"}, + {file = "pydantic_core-2.23.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555"}, + {file = "pydantic_core-2.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12"}, + {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2"}, + {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb"}, + {file = "pydantic_core-2.23.4-cp38-none-win32.whl", hash = "sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6"}, + {file = "pydantic_core-2.23.4-cp38-none-win_amd64.whl", hash = "sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556"}, + {file = "pydantic_core-2.23.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a"}, + {file = "pydantic_core-2.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55"}, + {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040"}, + {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605"}, + {file = "pydantic_core-2.23.4-cp39-none-win32.whl", hash = "sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6"}, + {file = "pydantic_core-2.23.4-cp39-none-win_amd64.whl", hash = "sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59"}, + {file = 
"pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e"}, + {file = "pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + [[package]] name = "pydata-sphinx-theme" version = "0.14.3" @@ -5153,61 +5105,62 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "ray" -version = "2.8.0" +version = "2.50.1" description = "Ray provides a simple, universal API for building distributed applications." optional = false -python-versions = "*" +python-versions = ">=3.9" groups = ["dev"] markers = "sys_platform != \"win32\"" files = [ - {file = "ray-2.8.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:34e0676a0dfa277efa688bccd83ecb7a799bc03078e5b1f1aa747fe9263175a8"}, - {file = "ray-2.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:72c696c1b784c55f0ad107d55bb58ecef5d368176765cf44fed87e714538d708"}, - {file = "ray-2.8.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:44dc3179d90f5c70954ac4a34760bab472efe2854add4905e6d5809e4d37d1f8"}, - {file = "ray-2.8.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d3005c7c308624aaf191be8cdcc9056cd88921fe9aa2d84d7d579c193f87d2af"}, - {file = "ray-2.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:21189b49b7c74c2f98f0ecbd5cd01ef664fb05bd71784a9538680d2f26213b32"}, - {file = "ray-2.8.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:490768aeb9bfe137ea4e3605ef0f0dbe6e77fff78c8a65bafc89b784fb2839a4"}, - {file = "ray-2.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:370a2e2d9e49eab588a6bf26e20e5859a8a0cfcf0e4633bad7f2a0231a7431cf"}, - {file = "ray-2.8.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c1ba046faf4f3bb7d58d99923ec042a82d86b243227642e864349d5ad5639a1e"}, - {file = "ray-2.8.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bd9b4ed49cb89a6715e6ba4ec8bdebef7664b64fd05fb48fd0a913832784b746"}, - {file = "ray-2.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:60fd55aa5c11550bbbe7fd97cf51ec33e6b8cde7ee136e0182bd4298a1732421"}, - {file = "ray-2.8.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:19311606bc1eccbf07a4f420f2a9eb6548c463506868df1e1e81bb568add3c14"}, - {file = "ray-2.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1ed34be0ec7a290d5ceae651c3e49c7e42b46c0f5e755afca28b15df385f774d"}, - {file = "ray-2.8.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:df504921c4bea3588d45ac9aef13d974d6e1a23b107a89a60fa1e2f6005e0e08"}, - {file = "ray-2.8.0-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:cca3a73323328bf8d72c72a358efe7199bd5ef0fa8bb9a5b0565727e5fade141"}, - {file = "ray-2.8.0-cp38-cp38-win_amd64.whl", hash = 
"sha256:300595421c7eecd98aa251fa8a0d6487c902fcde260d8e5dcb0ddd843050773b"}, - {file = "ray-2.8.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:ea544d77e5870c17d61ce0a8a86b1a69a64ee95c2618121c7db15525fcb8a46b"}, - {file = "ray-2.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fef1ad1bdd7b168f5fc667ba671b8e3bcc19cda1b979faa27cd2985cda76402f"}, - {file = "ray-2.8.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:9cc6176013909d3cd9facf05d12222c2b7503e862b29738070a552e0eea0b48a"}, - {file = "ray-2.8.0-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:7058bfd14f8c6458721225f0d30983b41e6142fb9f22d8b46f0b8b1776d99593"}, - {file = "ray-2.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:89280325b01455fad14b67a7061e97a84eb8b7d0774daba3e5101d05d75f5022"}, -] - -[package.dependencies] -aiosignal = "*" -click = ">=7.0" + {file = "ray-2.50.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0ee61b69b06acb7754f6cd08716084ce495fbbd963aeb72cfb4d14525d0e0969"}, + {file = "ray-2.50.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:b061816f8aed4bca0b81174dbffe717c69fb7bca20efd8f75560d3a6f8ccb280"}, + {file = "ray-2.50.1-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:dea9cc60e92dd089689156756e6d99c99fe3e49134180f715436b2e352b5df86"}, + {file = "ray-2.50.1-cp310-cp310-win_amd64.whl", hash = "sha256:fa0e2f4021ea5cefebf742ce0b7abff0c062863f082b03330a2f40d40591eb1f"}, + {file = "ray-2.50.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:bb33fd81684fead4aab706bd88b0aa27b94684906bfb511e59cc5756885a3f6f"}, + {file = "ray-2.50.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:76e1eaa627e19b706fa21e489cba692c7143d7d35843373ba71941def07ab58b"}, + {file = "ray-2.50.1-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:85f476bb4e667daad65318f29a35b13d6faa8e0530079c667d548c00c2d925e8"}, + {file = "ray-2.50.1-cp311-cp311-win_amd64.whl", hash = "sha256:4b32bc93aa67399bde65a220160c20f7959354d9fc6824b46a99ba3611381337"}, + {file = "ray-2.50.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:723e56c8193f8adde3ec18817ab437ad1cc9d4e72df1263e85697be282cfc526"}, + {file = "ray-2.50.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a8424fd3a4a1ef314a85f80c361f22e9bd949c7a63e238cf4172dc1955b12c7c"}, + {file = "ray-2.50.1-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:75c884e31d4dc0c384d4a4b68e9611175b6acba8622352bcabb73190cb9f8c3f"}, + {file = "ray-2.50.1-cp312-cp312-win_amd64.whl", hash = "sha256:a571529b74e959e1e088f6e0f320a612f351cdd309e17696f41327d9c9d42ce7"}, + {file = "ray-2.50.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:254a257dc2ba4349a4784af1f204c4d8169908ea779a2e5d4de87311ab5f525f"}, + {file = "ray-2.50.1-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:40cb56cb82a2779d5b2676b7bcd911d0f0a78d2234a15abb4f982415b651cfca"}, + {file = "ray-2.50.1-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:7a52554bd55f2a6188af56ffe5c7bd977e40eb97b7b6282d827a8d3a73f0789a"}, + {file = "ray-2.50.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:1c1d766ff8ebcb9c22c6afaf84bcebaafe1a3ba87b86d6ed3219aa8c2fdb1046"}, + {file = "ray-2.50.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:f2e1bd5e2b2b00e46b1bc0a02626f33fe30c042213d358149de535767dd61c39"}, + {file = "ray-2.50.1-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:b91594e97a94780c6aaa570e154e6f9f74d3afe2e6c26964f673f71c7c0a8f9a"}, + {file = "ray-2.50.1-cp39-cp39-win_amd64.whl", hash = "sha256:3b96077511aef9efa6210682bc6df0fec2db7353952297789bd13360d3baf939"}, +] + +[package.dependencies] +click = 
">=7.0,<8.3.0 || >8.3.0" filelock = "*" -frozenlist = "*" jsonschema = "*" msgpack = ">=1.0.0,<2.0.0" -numpy = {version = ">=1.19.3", markers = "python_version >= \"3.9\""} packaging = "*" -protobuf = ">=3.15.3,<3.19.5 || >3.19.5" +protobuf = ">=3.20.3" pyyaml = "*" requests = "*" [package.extras] -air = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -all = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "dm-tree", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (!=1.56.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "gymnasium (==0.28.1)", "lz4", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml", "ray-cpp (==2.8.0)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -client = ["grpcio (!=1.56.0)"] -cpp = ["ray-cpp (==2.8.0)"] -data = ["fsspec", "numpy (>=1.20)", "pandas (>=1.3)", "pyarrow (>=6.0.1)"] -default = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "virtualenv (>=20.0.24,<20.21.1)"] -observability = ["opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk"] -rllib = ["dm-tree", "fsspec", "gymnasium (==0.28.1)", "lz4", "pandas", "pyarrow (>=6.0.1)", "pyyaml", "requests", "rich", "scikit-image", "scipy", "tensorboardX (>=1.9)", "typer"] -serve = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -train = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] -tune = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] +adag = ["cupy-cuda12x ; sys_platform != \"darwin\""] +air = ["aiohttp (>=3.7)", "aiohttp_cors", "colorful", "fastapi", "fsspec", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "numpy (>=1.20)", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "pandas", "pandas 
(>=1.3)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart_open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +all = ["aiohttp (>=3.7)", "aiohttp_cors", "celery", "colorful", "cupy-cuda12x ; sys_platform != \"darwin\"", "dm_tree", "fastapi", "fsspec", "grpcio", "grpcio (!=1.56.0) ; sys_platform == \"darwin\"", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "gymnasium (==1.1.1)", "lz4", "memray ; sys_platform != \"win32\"", "numpy (>=1.20)", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "ormsgpack (==1.7.0)", "pandas", "pandas (>=1.3)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyOpenSSL", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "pyyaml", "requests", "scipy", "smart_open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +all-cpp = ["aiohttp (>=3.7)", "aiohttp_cors", "celery", "colorful", "cupy-cuda12x ; sys_platform != \"darwin\"", "dm_tree", "fastapi", "fsspec", "grpcio", "grpcio (!=1.56.0) ; sys_platform == \"darwin\"", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "gymnasium (==1.1.1)", "lz4", "memray ; sys_platform != \"win32\"", "numpy (>=1.20)", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "ormsgpack (==1.7.0)", "pandas", "pandas (>=1.3)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyOpenSSL", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "pyyaml", "ray-cpp (==2.50.1)", "requests", "scipy", "smart_open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +cgraph = ["cupy-cuda12x ; sys_platform != \"darwin\""] +client = ["grpcio", "grpcio (!=1.56.0) ; sys_platform == \"darwin\""] +cpp = ["ray-cpp (==2.50.1)"] +data = ["fsspec", "numpy (>=1.20)", "pandas (>=1.3)", "pyarrow (>=9.0.0)"] +default = ["aiohttp (>=3.7)", "aiohttp_cors", "colorful", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart_open", "virtualenv (>=20.0.24,!=20.21.1)"] +llm = ["aiohttp (>=3.7)", "aiohttp_cors", "async-timeout ; python_version < \"3.11\"", "colorful", "fastapi", "fsspec", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "hf_transfer", "jsonref (>=1.1.0)", "jsonschema", "ninja", "numpy (>=1.20)", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "pandas (>=1.3)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart_open", "starlette", "typer", 
"uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "vllm (>=0.10.2)", "watchfiles"] +observability = ["memray ; sys_platform != \"win32\""] +rllib = ["dm_tree", "fsspec", "gymnasium (==1.1.1)", "lz4", "ormsgpack (==1.7.0)", "pandas", "pyarrow (>=9.0.0)", "pyyaml", "requests", "scipy", "tensorboardX (>=1.9)"] +serve = ["aiohttp (>=3.7)", "aiohttp_cors", "colorful", "fastapi", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart_open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +serve-async-inference = ["aiohttp (>=3.7)", "aiohttp_cors", "celery", "colorful", "fastapi", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart_open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +serve-grpc = ["aiohttp (>=3.7)", "aiohttp_cors", "colorful", "fastapi", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyOpenSSL", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart_open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +train = ["fsspec", "pandas", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "tensorboardX (>=1.9)"] +tune = ["fsspec", "pandas", "pyarrow (>=9.0.0)", "requests", "tensorboardX (>=1.9)"] [[package]] name = "referencing" @@ -5463,7 +5416,7 @@ description = "C version of reader, parser and emitter for ruamel.yaml derived f optional = false python-versions = ">=3.6" groups = ["dev"] -markers = "platform_python_implementation == \"CPython\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_python_implementation == \"CPython\"" files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, @@ -5519,29 +5472,31 @@ files = [ [[package]] name = "ruff" -version = "0.0.285" -description = "An extremely fast Python linter, written in Rust." +version = "0.14.1" +description = "An extremely fast Python linter and code formatter, written in Rust." 
 optional = false
 python-versions = ">=3.7"
 groups = ["dev"]
 files = [
-    {file = "ruff-0.0.285-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:72a3a0936369b986b0e959f9090206ed3c18f9e5e439ea5b8e6867c6707aded5"},
-    {file = "ruff-0.0.285-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0d9ab6ad16742eb78919e0fba09f914f042409df40ad63423c34bb20d350162a"},
-    {file = "ruff-0.0.285-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c48926156288b8ac005eb1db5e77c15e8a37309ae49d9fb6771d5cf5f777590"},
-    {file = "ruff-0.0.285-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1d2a60c102e7a5e147b58fc2cbea12a563c565383effc527c987ea2086a05742"},
-    {file = "ruff-0.0.285-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b02aae62f922d088bb01943e1dbd861688ada13d735b78b8348a7d90121fd292"},
-    {file = "ruff-0.0.285-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f572c4296d8c7ddd22c3204de4031965be524fdd1fdaaef273945932912b28c5"},
-    {file = "ruff-0.0.285-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80effdf4fe69763d69eb4ab9443e186fd09e668b59fe70ba4b49f4c077d15a1b"},
-    {file = "ruff-0.0.285-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5977ce304da35c263f5e082901bd7ac0bd2be845a8fcfd1a29e4d6680cddb307"},
-    {file = "ruff-0.0.285-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72a087712d474fa17b915d7cb9ef807e1256182b12ddfafb105eb00aeee48d1a"},
-    {file = "ruff-0.0.285-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7ce67736cd8dfe97162d1e7adfc2d9a1bac0efb9aaaff32e4042c7cde079f54b"},
-    {file = "ruff-0.0.285-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:5473a4c6cac34f583bff08c5f63b8def5599a0ea4dc96c0302fbd2cc0b3ecbad"},
-    {file = "ruff-0.0.285-py3-none-musllinux_1_2_i686.whl", hash = "sha256:e6b1c961d608d373a032f047a20bf3c55ad05f56c32e7b96dcca0830a2a72348"},
-    {file = "ruff-0.0.285-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2933cc9631f453305399c7b8fb72b113ad76b49ae1d7103cc4afd3a423bed164"},
-    {file = "ruff-0.0.285-py3-none-win32.whl", hash = "sha256:770c5eb6376de024111443022cda534fb28980a9dd3b4abc83992a8770167ba6"},
-    {file = "ruff-0.0.285-py3-none-win_amd64.whl", hash = "sha256:a8c6ad6b9cd77489bf6d1510950cbbe47a843aa234adff0960bae64bd06c3b6d"},
-    {file = "ruff-0.0.285-py3-none-win_arm64.whl", hash = "sha256:de44fbc6c3b25fccee473ddf851416fd4e246fc6027b2197c395b1b3b3897921"},
-    {file = "ruff-0.0.285.tar.gz", hash = "sha256:45866048d1dcdcc80855998cb26c4b2b05881f9e043d2e3bfe1aa36d9a2e8f28"},
+    {file = "ruff-0.14.1-py3-none-linux_armv6l.whl", hash = "sha256:083bfc1f30f4a391ae09c6f4f99d83074416b471775b59288956f5bc18e82f8b"},
+    {file = "ruff-0.14.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f6fa757cd717f791009f7669fefb09121cc5f7d9bd0ef211371fad68c2b8b224"},
+    {file = "ruff-0.14.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d6191903d39ac156921398e9c86b7354d15e3c93772e7dbf26c9fcae59ceccd5"},
+    {file = "ruff-0.14.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed04f0e04f7a4587244e5c9d7df50e6b5bf2705d75059f409a6421c593a35896"},
+    {file = "ruff-0.14.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5c9e6cf6cd4acae0febbce29497accd3632fe2025c0c583c8b87e8dbdeae5f61"},
+    {file = "ruff-0.14.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6fa2458527794ecdfbe45f654e42c61f2503a230545a91af839653a0a93dbc6"},
+    {file = "ruff-0.14.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:39f1c392244e338b21d42ab29b8a6392a722c5090032eb49bb4d6defcdb34345"},
"ruff-0.14.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:39f1c392244e338b21d42ab29b8a6392a722c5090032eb49bb4d6defcdb34345"}, + {file = "ruff-0.14.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7382fa12a26cce1f95070ce450946bec357727aaa428983036362579eadcc5cf"}, + {file = "ruff-0.14.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd0bf2be3ae8521e1093a487c4aa3b455882f139787770698530d28ed3fbb37c"}, + {file = "ruff-0.14.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabcaa9ccf8089fb4fdb78d17cc0e28241520f50f4c2e88cb6261ed083d85151"}, + {file = "ruff-0.14.1-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:747d583400f6125ec11a4c14d1c8474bf75d8b419ad22a111a537ec1a952d192"}, + {file = "ruff-0.14.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5a6e74c0efd78515a1d13acbfe6c90f0f5bd822aa56b4a6d43a9ffb2ae6e56cd"}, + {file = "ruff-0.14.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0ea6a864d2fb41a4b6d5b456ed164302a0d96f4daac630aeba829abfb059d020"}, + {file = "ruff-0.14.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:0826b8764f94229604fa255918d1cc45e583e38c21c203248b0bfc9a0e930be5"}, + {file = "ruff-0.14.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cbc52160465913a1a3f424c81c62ac8096b6a491468e7d872cb9444a860bc33d"}, + {file = "ruff-0.14.1-py3-none-win32.whl", hash = "sha256:e037ea374aaaff4103240ae79168c0945ae3d5ae8db190603de3b4012bd1def6"}, + {file = "ruff-0.14.1-py3-none-win_amd64.whl", hash = "sha256:59d599cdff9c7f925a017f6f2c256c908b094e55967f93f2821b1439928746a1"}, + {file = "ruff-0.14.1-py3-none-win_arm64.whl", hash = "sha256:e3b443c4c9f16ae850906b8d0a707b2a4c16f8d2f0a7fe65c475c5886665ce44"}, + {file = "ruff-0.14.1.tar.gz", hash = "sha256:1dd86253060c4772867c61791588627320abcb6ed1577a90ef432ee319729b69"}, ] [[package]] @@ -5629,14 +5584,14 @@ win32 = ["pywin32 ; sys_platform == \"win32\""] [[package]] name = "sensai-utils" -version = "1.4.0" +version = "1.6.0" description = "Utilities from sensAI, the Python library for sensible AI" optional = false python-versions = "*" groups = ["main"] files = [ - {file = "sensai_utils-1.4.0-py3-none-any.whl", hash = "sha256:ed6fc57552620e43b33cf364ea0bc0fd7df39391069dd7b621b113ef55547507"}, - {file = "sensai_utils-1.4.0.tar.gz", hash = "sha256:2d32bdcc91fd1428c5cae0181e98623142d2d5f7e115e23d585a842dd9dc59ba"}, + {file = "sensai_utils-1.6.0-py3-none-any.whl", hash = "sha256:3298d4d21bdf7a1b91873213614c748ceb17014c7d72c9492bd5bf82ba893f9d"}, + {file = "sensai_utils-1.6.0.tar.gz", hash = "sha256:e50ae6bbd7c62a961f25b98e55b29029450efd66444678931b3b9c43e9bf9e95"}, ] [package.dependencies] @@ -5693,107 +5648,6 @@ starlette = ["starlette (>=0.19.1)"] starlite = ["starlite (>=1.48)"] tornado = ["tornado (>=6)"] -[[package]] -name = "setproctitle" -version = "1.3.3" -description = "A Python module to customize the process title" -optional = false -python-versions = ">=3.7" -groups = ["dev"] -files = [ - {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:897a73208da48db41e687225f355ce993167079eda1260ba5e13c4e53be7f754"}, - {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c331e91a14ba4076f88c29c777ad6b58639530ed5b24b5564b5ed2fd7a95452"}, - {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbbd6c7de0771c84b4aa30e70b409565eb1fc13627a723ca6be774ed6b9d9fa3"}, - {file = 
"setproctitle-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c05ac48ef16ee013b8a326c63e4610e2430dbec037ec5c5b58fcced550382b74"}, - {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1342f4fdb37f89d3e3c1c0a59d6ddbedbde838fff5c51178a7982993d238fe4f"}, - {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc74e84fdfa96821580fb5e9c0b0777c1c4779434ce16d3d62a9c4d8c710df39"}, - {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9617b676b95adb412bb69645d5b077d664b6882bb0d37bfdafbbb1b999568d85"}, - {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6a249415f5bb88b5e9e8c4db47f609e0bf0e20a75e8d744ea787f3092ba1f2d0"}, - {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:38da436a0aaace9add67b999eb6abe4b84397edf4a78ec28f264e5b4c9d53cd5"}, - {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:da0d57edd4c95bf221b2ebbaa061e65b1788f1544977288bdf95831b6e44e44d"}, - {file = "setproctitle-1.3.3-cp310-cp310-win32.whl", hash = "sha256:a1fcac43918b836ace25f69b1dca8c9395253ad8152b625064415b1d2f9be4fb"}, - {file = "setproctitle-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:200620c3b15388d7f3f97e0ae26599c0c378fdf07ae9ac5a13616e933cbd2086"}, - {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:334f7ed39895d692f753a443102dd5fed180c571eb6a48b2a5b7f5b3564908c8"}, - {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:950f6476d56ff7817a8fed4ab207727fc5260af83481b2a4b125f32844df513a"}, - {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:195c961f54a09eb2acabbfc90c413955cf16c6e2f8caa2adbf2237d1019c7dd8"}, - {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f05e66746bf9fe6a3397ec246fe481096664a9c97eb3fea6004735a4daf867fd"}, - {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5901a31012a40ec913265b64e48c2a4059278d9f4e6be628441482dd13fb8b5"}, - {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64286f8a995f2cd934082b398fc63fca7d5ffe31f0e27e75b3ca6b4efda4e353"}, - {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:184239903bbc6b813b1a8fc86394dc6ca7d20e2ebe6f69f716bec301e4b0199d"}, - {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:664698ae0013f986118064b6676d7dcd28fefd0d7d5a5ae9497cbc10cba48fa5"}, - {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e5119a211c2e98ff18b9908ba62a3bd0e3fabb02a29277a7232a6fb4b2560aa0"}, - {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:417de6b2e214e837827067048f61841f5d7fc27926f2e43954567094051aff18"}, - {file = "setproctitle-1.3.3-cp311-cp311-win32.whl", hash = "sha256:6a143b31d758296dc2f440175f6c8e0b5301ced3b0f477b84ca43cdcf7f2f476"}, - {file = "setproctitle-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a680d62c399fa4b44899094027ec9a1bdaf6f31c650e44183b50d4c4d0ccc085"}, - {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_universal2.whl", hash = 
"sha256:d4460795a8a7a391e3567b902ec5bdf6c60a47d791c3b1d27080fc203d11c9dc"}, - {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bdfd7254745bb737ca1384dee57e6523651892f0ea2a7344490e9caefcc35e64"}, - {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:477d3da48e216d7fc04bddab67b0dcde633e19f484a146fd2a34bb0e9dbb4a1e"}, - {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ab2900d111e93aff5df9fddc64cf51ca4ef2c9f98702ce26524f1acc5a786ae7"}, - {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:088b9efc62d5aa5d6edf6cba1cf0c81f4488b5ce1c0342a8b67ae39d64001120"}, - {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6d50252377db62d6a0bb82cc898089916457f2db2041e1d03ce7fadd4a07381"}, - {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:87e668f9561fd3a457ba189edfc9e37709261287b52293c115ae3487a24b92f6"}, - {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:287490eb90e7a0ddd22e74c89a92cc922389daa95babc833c08cf80c84c4df0a"}, - {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:4fe1c49486109f72d502f8be569972e27f385fe632bd8895f4730df3c87d5ac8"}, - {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4a6ba2494a6449b1f477bd3e67935c2b7b0274f2f6dcd0f7c6aceae10c6c6ba3"}, - {file = "setproctitle-1.3.3-cp312-cp312-win32.whl", hash = "sha256:2df2b67e4b1d7498632e18c56722851ba4db5d6a0c91aaf0fd395111e51cdcf4"}, - {file = "setproctitle-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:f38d48abc121263f3b62943f84cbaede05749047e428409c2c199664feb6abc7"}, - {file = "setproctitle-1.3.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:816330675e3504ae4d9a2185c46b573105d2310c20b19ea2b4596a9460a4f674"}, - {file = "setproctitle-1.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68f960bc22d8d8e4ac886d1e2e21ccbd283adcf3c43136161c1ba0fa509088e0"}, - {file = "setproctitle-1.3.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00e6e7adff74796ef12753ff399491b8827f84f6c77659d71bd0b35870a17d8f"}, - {file = "setproctitle-1.3.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:53bc0d2358507596c22b02db079618451f3bd720755d88e3cccd840bafb4c41c"}, - {file = "setproctitle-1.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad6d20f9541f5f6ac63df553b6d7a04f313947f550eab6a61aa758b45f0d5657"}, - {file = "setproctitle-1.3.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c1c84beab776b0becaa368254801e57692ed749d935469ac10e2b9b825dbdd8e"}, - {file = "setproctitle-1.3.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:507e8dc2891021350eaea40a44ddd887c9f006e6b599af8d64a505c0f718f170"}, - {file = "setproctitle-1.3.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:b1067647ac7aba0b44b591936118a22847bda3c507b0a42d74272256a7a798e9"}, - {file = "setproctitle-1.3.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2e71f6365744bf53714e8bd2522b3c9c1d83f52ffa6324bd7cbb4da707312cd8"}, - {file = "setproctitle-1.3.3-cp37-cp37m-win32.whl", hash = "sha256:7f1d36a1e15a46e8ede4e953abb104fdbc0845a266ec0e99cc0492a4364f8c44"}, - {file = 
"setproctitle-1.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:c9a402881ec269d0cc9c354b149fc29f9ec1a1939a777f1c858cdb09c7a261df"}, - {file = "setproctitle-1.3.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ff814dea1e5c492a4980e3e7d094286077054e7ea116cbeda138819db194b2cd"}, - {file = "setproctitle-1.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:accb66d7b3ccb00d5cd11d8c6e07055a4568a24c95cf86109894dcc0c134cc89"}, - {file = "setproctitle-1.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:554eae5a5b28f02705b83a230e9d163d645c9a08914c0ad921df363a07cf39b1"}, - {file = "setproctitle-1.3.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a911b26264dbe9e8066c7531c0591cfab27b464459c74385b276fe487ca91c12"}, - {file = "setproctitle-1.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2982efe7640c4835f7355fdb4da313ad37fb3b40f5c69069912f8048f77b28c8"}, - {file = "setproctitle-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df3f4274b80709d8bcab2f9a862973d453b308b97a0b423a501bcd93582852e3"}, - {file = "setproctitle-1.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:af2c67ae4c795d1674a8d3ac1988676fa306bcfa1e23fddb5e0bd5f5635309ca"}, - {file = "setproctitle-1.3.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:af4061f67fd7ec01624c5e3c21f6b7af2ef0e6bab7fbb43f209e6506c9ce0092"}, - {file = "setproctitle-1.3.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:37a62cbe16d4c6294e84670b59cf7adcc73faafe6af07f8cb9adaf1f0e775b19"}, - {file = "setproctitle-1.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a83ca086fbb017f0d87f240a8f9bbcf0809f3b754ee01cec928fff926542c450"}, - {file = "setproctitle-1.3.3-cp38-cp38-win32.whl", hash = "sha256:059f4ce86f8cc92e5860abfc43a1dceb21137b26a02373618d88f6b4b86ba9b2"}, - {file = "setproctitle-1.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:ab92e51cd4a218208efee4c6d37db7368fdf182f6e7ff148fb295ecddf264287"}, - {file = "setproctitle-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c7951820b77abe03d88b114b998867c0f99da03859e5ab2623d94690848d3e45"}, - {file = "setproctitle-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5bc94cf128676e8fac6503b37763adb378e2b6be1249d207630f83fc325d9b11"}, - {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f5d9027eeda64d353cf21a3ceb74bb1760bd534526c9214e19f052424b37e42"}, - {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e4a8104db15d3462e29d9946f26bed817a5b1d7a47eabca2d9dc2b995991503"}, - {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c32c41ace41f344d317399efff4cffb133e709cec2ef09c99e7a13e9f3b9483c"}, - {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbf16381c7bf7f963b58fb4daaa65684e10966ee14d26f5cc90f07049bfd8c1e"}, - {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e18b7bd0898398cc97ce2dfc83bb192a13a087ef6b2d5a8a36460311cb09e775"}, - {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:69d565d20efe527bd8a9b92e7f299ae5e73b6c0470f3719bd66f3cd821e0d5bd"}, - {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = 
"sha256:ddedd300cd690a3b06e7eac90ed4452348b1348635777ce23d460d913b5b63c3"}, - {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:415bfcfd01d1fbf5cbd75004599ef167a533395955305f42220a585f64036081"}, - {file = "setproctitle-1.3.3-cp39-cp39-win32.whl", hash = "sha256:21112fcd2195d48f25760f0eafa7a76510871bbb3b750219310cf88b04456ae3"}, - {file = "setproctitle-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:5a740f05d0968a5a17da3d676ce6afefebeeeb5ce137510901bf6306ba8ee002"}, - {file = "setproctitle-1.3.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6b9e62ddb3db4b5205c0321dd69a406d8af9ee1693529d144e86bd43bcb4b6c0"}, - {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e3b99b338598de0bd6b2643bf8c343cf5ff70db3627af3ca427a5e1a1a90dd9"}, - {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ae9a02766dad331deb06855fb7a6ca15daea333b3967e214de12cfae8f0ef5"}, - {file = "setproctitle-1.3.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:200ede6fd11233085ba9b764eb055a2a191fb4ffb950c68675ac53c874c22e20"}, - {file = "setproctitle-1.3.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0d3a953c50776751e80fe755a380a64cb14d61e8762bd43041ab3f8cc436092f"}, - {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5e08e232b78ba3ac6bc0d23ce9e2bee8fad2be391b7e2da834fc9a45129eb87"}, - {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1da82c3e11284da4fcbf54957dafbf0655d2389cd3d54e4eaba636faf6d117a"}, - {file = "setproctitle-1.3.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:aeaa71fb9568ebe9b911ddb490c644fbd2006e8c940f21cb9a1e9425bd709574"}, - {file = "setproctitle-1.3.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:59335d000c6250c35989394661eb6287187854e94ac79ea22315469ee4f4c244"}, - {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3ba57029c9c50ecaf0c92bb127224cc2ea9fda057b5d99d3f348c9ec2855ad3"}, - {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d876d355c53d975c2ef9c4f2487c8f83dad6aeaaee1b6571453cb0ee992f55f6"}, - {file = "setproctitle-1.3.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:224602f0939e6fb9d5dd881be1229d485f3257b540f8a900d4271a2c2aa4e5f4"}, - {file = "setproctitle-1.3.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d7f27e0268af2d7503386e0e6be87fb9b6657afd96f5726b733837121146750d"}, - {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5e7266498cd31a4572378c61920af9f6b4676a73c299fce8ba93afd694f8ae7"}, - {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33c5609ad51cd99d388e55651b19148ea99727516132fb44680e1f28dd0d1de9"}, - {file = "setproctitle-1.3.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:eae8988e78192fd1a3245a6f4f382390b61bce6cfcc93f3809726e4c885fa68d"}, - {file = "setproctitle-1.3.3.tar.gz", hash = "sha256:c913e151e7ea01567837ff037a23ca8740192880198b7fbb90b16d181607caae"}, -] - 
-[package.extras] -test = ["pytest"] - [[package]] name = "setuptools" version = "68.2.2" @@ -5837,18 +5691,6 @@ gym = ["gym (>=0.21)"] openspiel = ["open-spiel (>=1.2)", "pettingzoo (>=1.22)"] testing = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)", "pillow (>=9.3.0)", "pytest (==7.1.3)"] -[[package]] -name = "shortuuid" -version = "1.0.11" -description = "A generator library for concise, unambiguous and URL-safe UUIDs." -optional = false -python-versions = ">=3.5" -groups = ["dev"] -files = [ - {file = "shortuuid-1.0.11-py3-none-any.whl", hash = "sha256:27ea8f28b1bd0bf8f15057a3ece57275d2059d2b0bb02854f02189962c13b6aa"}, - {file = "shortuuid-1.0.11.tar.gz", hash = "sha256:fc75f2615914815a8e4cb1501b3a513745cb66ef0fd5fc6fb9f8c3fa3481f789"}, -] - [[package]] name = "six" version = "1.16.0" @@ -6899,14 +6741,14 @@ files = [ [[package]] name = "typing-extensions" -version = "4.8.0" -description = "Backported and Experimental Type Hints for Python 3.8+" +version = "4.15.0" +description = "Backported and Experimental Type Hints for Python 3.9+" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main", "dev"] files = [ - {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, - {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, + {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, + {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, ] [[package]] @@ -7052,42 +6894,48 @@ test = ["psutil", "pytest"] [[package]] name = "wandb" -version = "0.12.21" -description = "A CLI and library for interacting with the Weights and Biases API." +version = "0.22.2" +description = "A CLI and library for interacting with the Weights & Biases API." 
optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" groups = ["dev"] files = [ - {file = "wandb-0.12.21-py2.py3-none-any.whl", hash = "sha256:150842447d355d90dc7f368b824951a625e5b2d1be355a00e99b11b73728bc1f"}, - {file = "wandb-0.12.21.tar.gz", hash = "sha256:1975ff88c5024923c3321c93cfefb8d9b871543c0b009f34001bf0f31e444b04"}, + {file = "wandb-0.22.2-py3-none-macosx_12_0_arm64.whl", hash = "sha256:2e29c9fa4462b5411b2cd2175ae33eff4309c91de7c426bca6bc8e7abc7e5dec"}, + {file = "wandb-0.22.2-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:c42d594cd7a9da4fd39ecdb0abbc081b61f304123277b2b6c4ba84283956fd21"}, + {file = "wandb-0.22.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5188d84e66d3fd584f3b3ae4d2a70e78f29403c0528e6aecaa4188a1fcf54d8"}, + {file = "wandb-0.22.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88ccd484af9f21cfc127976793c3cf66cfe1acd75bd8cd650086a64e88bac4bf"}, + {file = "wandb-0.22.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:abf0ed175e791af64110e0a0b99ce02bbbbd1017722bc32d3bc328efb86450cd"}, + {file = "wandb-0.22.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:44e77c56403b90bf3473a7ca3bfc4d42c636b7c0e31a5fb9cd0382f08302f74b"}, + {file = "wandb-0.22.2-py3-none-win32.whl", hash = "sha256:44d12bd379dbe15be5ceed6bdf23803d42f648ba0dd111297b4c47a3c7be6dbd"}, + {file = "wandb-0.22.2-py3-none-win_amd64.whl", hash = "sha256:c95eb221bf316c0872f7ac55071856b9f25f95a2de983ada48acf653ce259386"}, + {file = "wandb-0.22.2-py3-none-win_arm64.whl", hash = "sha256:20d2ab9aa10445aab3d60914a980f002a4f66566e28b0cd156b1e462f0080a0d"}, + {file = "wandb-0.22.2.tar.gz", hash = "sha256:510f5a1ac30d16921c36c3b932da852f046641d4aee98a86a7f5ec03a6e95bda"}, ] [package.dependencies] -Click = ">=7.0,<8.0.0 || >8.0.0" -docker-pycreds = ">=0.4.0" -GitPython = ">=1.0.0" -pathtools = "*" -promise = ">=2.0,<3" -protobuf = ">=3.12.0,<4.0dev" -psutil = ">=5.0.0" -PyYAML = "*" +click = ">=8.0.1" +gitpython = ">=1.0.0,<3.1.29 || >3.1.29" +packaging = "*" +platformdirs = "*" +protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5.28.0 || >5.28.0,<7", markers = "python_version > \"3.9\" or sys_platform != \"linux\""} +pydantic = "<3" +pyyaml = "*" requests = ">=2.0.0,<3" -sentry-sdk = ">=1.0.0" -setproctitle = "*" -setuptools = "*" -shortuuid = ">=0.5.0" -six = ">=1.13.0" +sentry-sdk = ">=2.0.0" +typing-extensions = ">=4.8,<5" [package.extras] -aws = ["boto3"] -azure = ["azure-storage-blob"] +aws = ["boto3", "botocore (>=1.5.76)"] +azure = ["azure-identity", "azure-storage-blob"] gcp = ["google-cloud-storage"] -grpc = ["grpcio (>=1.27.2)"] +importers = ["filelock", "mlflow", "polars (<=1.2.1)", "rich", "tenacity"] kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] -launch = ["boto3", "chardet", "google-cloud-storage", "iso8601", "kubernetes", "nbconvert", "nbformat", "typing-extensions"] -media = ["bokeh", "moviepy", "numpy", "pillow", "plotly", "rdkit-pypi", "soundfile"] +launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore (>=1.5.76)", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "jsonschema", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "pyyaml (>=6.0.0)", "tomli", "tornado (>=6.5.0) ; python_version >= \"3.9\"", "typing-extensions"] +media = ["bokeh", "imageio (>=2.28.1)", "moviepy (>=1.0.0)", "numpy", "pillow", "plotly 
(>=5.18.0)", "rdkit", "soundfile"] models = ["cloudpickle"] -sweeps = ["sweeps (>=0.1.0)"] +perf = ["orjson"] +sweeps = ["sweeps (>=0.2.0)"] +workspaces = ["wandb-workspaces"] [[package]] name = "wcwidth" @@ -7222,4 +7070,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "575f58bac92d215908d074f946b8593cbefaf83f965beed396253d8d3f38eea7" +content-hash = "6a5ae8b5b701f0daee90e241187c1628477b6ac96394a3cb15f2921659e80e34" diff --git a/pyproject.toml b/pyproject.toml index 2b76846b1..371159e92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,16 +26,16 @@ exclude = ["test/*", "examples/*", "docs/*"] [tool.poetry.dependencies] python = "^3.11" deepdiff = "^7.0.1" -gymnasium = "^0.28.0" +gymnasium = ">=0.28.0" h5py = "^3.9.0" matplotlib = ">=3.0.0" numba = ">=0.60.0" -numpy = "^1" +numpy = ">=1.24.4" overrides = "^7.4.0" packaging = "*" pandas = ">=2.0.0" pettingzoo = "^1.22" -sensai-utils = "^1.4.0" +sensai-utils = ">=1.6.0" tensorboard = "^2.5.0" # Torch 2.0.1 causes problems, see https://github.com/pytorch/pytorch/issues/100974 torch = "^2.0.0, !=2.0.1, !=2.1.0" @@ -87,8 +87,8 @@ eval = ["rliable", "joblib", "scipy", "jsonargparse", "docstring-parser"] [tool.poetry.group.dev] optional = true + [tool.poetry.group.dev.dependencies] -black = { version = ">=23.7,<25.0", extras = ["jupyter"] } docutils = "0.20.1" jinja2 = "*" jupyter = "^1.0.0" @@ -106,8 +106,8 @@ pytest = "*" pytest-cov = "*" # Ray currently causes issues when installed on windows server 2022 in CI # If users want to use ray, they should install it manually. -ray = { version = "^2", markers = "sys_platform != 'win32'" } -ruff = "^0.0.285" +ray = { version = ">=2.10, <3", markers = "sys_platform != 'win32'" } +ruff = "0.14.1" scipy = "*" sphinx = "^7" sphinx-book-theme = "^1.0.1" @@ -120,7 +120,9 @@ sphinxcontrib-bibtex = "*" sphinxcontrib-spelling = "^8.0.0" types-requests = "^2.31.0.20240311" types-tabulate = "^0.9.0.20240106" -wandb = "^0.12.0" +# this is needed for wandb only (undisclosed dependency) +typing-extensions = ">=4.10" +wandb = ">=0.16.0" [tool.mypy] allow_redefinition = true @@ -145,21 +147,22 @@ exclude = "^build/|^docs/" [tool.doc8] max-line-length = 1000 -[tool.black] -line-length = 100 -target-version = ["py311"] [tool.nbqa.exclude] ruff = "\\.jupyter_cache|jupyter_execute" mypy = "\\.jupyter_cache|jupyter_execute" [tool.ruff] +target-version = "py311" +line-length = 100 + +[tool.ruff.lint] select = [ "ASYNC", "B", "C4", "C90", "COM", "D", "DTZ", "E", "F", "FLY", "G", "I", "ISC", "PIE", "PLC", "PLE", "PLW", "RET", "RUF", "RSE", "SIM", "TID", "UP", "W", "YTT", ] ignore = [ "SIM118", # Needed b/c iter(batch) != iter(batch.keys()). See https://github.com/thu-ml/tianshou/issues/922 - "E501", # line too long. black does a good enough job + "E501", # line too long. ruff does a good enough job "E741", # variable names like "l". this isn't a huge problem "B008", # do not perform function calls in argument defaults. we do this sometimes "B011", # assert false. 
we don't use python -O @@ -182,6 +185,14 @@ ignore = [ "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx "COM812", # missing trailing comma: With this enabled, re-application of "poe format" chain can cause additional commas and subsequent reformatting "B023", # forbids function using loop variable without explicit binding + "RUF059", # unused name after unpacking + "RUF005", # concatenation + "PLC0415", # local imports + "SIM108", # if else is fine instead of ternary + "PLW1641", # weird thing requiring __hash__ for Protocol + "PLC0206", # extracting value from dictionary without calling `.items()` + "SIM103", # forces returning of conditions instead of booleans + "E721", # forbids use of equality for type checks ] unfixable = [ "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all @@ -194,12 +205,10 @@ extend-fixable = [ "B905", # bugbear ] -target-version = "py311" - -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 20 -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "test/**" = ["D103"] "docs/**" = ["D103"] "examples/**" = ["D103"] @@ -213,18 +222,15 @@ PYDEVD_DISABLE_FILE_VALIDATION="1" # keep relevant parts in sync with pre-commit [tool.poe.tasks] # https://github.com/nat-n/poethepoet test = "pytest test" +test-nocov = "pytest -p no:cov test" test-reduced = "pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes" -_black_check = "black --check ." -_ruff_check = "ruff check ." -_ruff_check_nb = "nbqa ruff docs" -_black_format = "black ." -_ruff_format = "ruff --fix ." -_ruff_format_nb = "nbqa ruff --fix docs" -lint = ["_black_check", "_ruff_check", "_ruff_check_nb"] -_poetry_install_sort_plugin = "poetry self add poetry-plugin-sort" -_poetry_sort = "poetry sort" +_ruff_fix = "ruff check --fix ." +_ruff_fix_check = "ruff check ." +_ruff_format = "ruff format ." +_ruff_format_check = "ruff format --check ." 
+lint = ["_ruff_format_check", "_ruff_fix_check"] clean-nbs = "python docs/nbstripout.py" -format = ["_ruff_format", "_ruff_format_nb", "_black_format", "_poetry_install_sort_plugin", "_poetry_sort"] +format = ["_ruff_fix", "_ruff_format"] _autogen_rst = "python docs/autogen_rst.py" _sphinx_build = "sphinx-build -b html docs docs/_build -W --keep-going" _jb_generate_toc = "python docs/create_toc.py" diff --git a/test/base/env.py b/test/base/env.py index 618252eb0..89c27ccf6 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -34,9 +34,9 @@ def __init__( random_sleep: bool = False, array_state: bool = False, ) -> None: - assert ( - dict_state + recurse_state + array_state <= 1 - ), "dict_state / recurse_state / array_state can be only one true" + assert dict_state + recurse_state + array_state <= 1, ( + "dict_state / recurse_state / array_state can be only one true" + ) self.size = size self.sleep = sleep self.random_sleep = random_sleep @@ -208,9 +208,9 @@ def step( class MyGoalEnv(MoveToRightEnv): def __init__(self, *args: Any, **kwargs: Any) -> None: - assert ( - kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0 - ), "dict_state / recurse_state not supported" + assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, ( + "dict_state / recurse_state not supported" + ) super().__init__(*args, **kwargs) super().reset(options={"state": 0}) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index f8af8c521..4d4b56338 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -1,6 +1,5 @@ import copy import pickle -import sys from itertools import starmap from typing import Any, cast @@ -402,10 +401,7 @@ def test_utils_to_torch_numpy() -> None: assert to_numpy(to_numpy).item() == to_numpy # additional test for to_torch, for code-coverage assert isinstance(to_torch(1), torch.Tensor) - if sys.platform in ["win32", "cygwin"]: # windows - assert to_torch(1).dtype == torch.int32 - else: - assert to_torch(1).dtype == torch.int64 + assert to_torch(1).dtype in (torch.int64, torch.int32) assert to_torch(1.0).dtype == torch.float64 assert isinstance(to_torch({"a": [1]})["a"], torch.Tensor) with pytest.raises(TypeError): @@ -753,10 +749,7 @@ def test_batch_over_batch_to_torch() -> None: assert batch.a.dtype == torch.float64 assert batch.b.c.dtype == torch.float32 assert batch.b.d.dtype == torch.float64 - if sys.platform in ["win32", "cygwin"]: # windows - assert batch.b.e.dtype == torch.int32 - else: - assert batch.b.e.dtype == torch.int64 + assert batch.b.e.dtype in (torch.int64, torch.int32) batch.to_torch_(dtype=torch.float32) assert batch.a.dtype == torch.float32 assert batch.b.c.dtype == torch.float32 diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 2d101fac9..0373bd932 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,7 +1,6 @@ import os import pickle import tempfile -from test.base.env import MoveToRightEnv, MyGoalEnv from typing import cast import h5py @@ -10,6 +9,7 @@ import pytest import torch +from test.base.env import MoveToRightEnv, MyGoalEnv from tianshou.data import ( Batch, CachedReplayBuffer, @@ -1451,10 +1451,7 @@ def test_custom_key() -> None: # Check if they have the same keys assert set(batch.get_keys()) == set( sampled_batch.get_keys(), - ), "Batches have different keys: {} and {}".format( - set(batch.get_keys()), - set(sampled_batch.get_keys()), - ) + ), f"Batches have different keys: {set(batch.get_keys())} and {set(sampled_batch.get_keys())}" # Compare the values for each 
key for key in batch.get_keys(): if isinstance(batch.__dict__[key], np.ndarray) and isinstance( diff --git a/test/base/test_collector.py b/test/base/test_collector.py index b88aa1fca..44264621f 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,5 +1,4 @@ from collections.abc import Callable, Sequence -from test.base.env import MoveToRightEnv, NXEnv from typing import Any import gymnasium as gym @@ -7,6 +6,7 @@ import pytest import tqdm +from test.base.env import MoveToRightEnv, NXEnv from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go from tianshou.data import ( AsyncCollector, @@ -410,11 +410,8 @@ def test_collector_with_dict_state() -> None: result = c1.collect(n_episode=8) assert result.n_collected_episodes == 8 lens = np.bincount(result.lens) - assert ( - result.n_collected_steps == 21 - and np.all(lens == [0, 0, 2, 2, 2, 2]) - or result.n_collected_steps == 20 - and np.all(lens == [0, 0, 3, 1, 2, 2]) + assert (result.n_collected_steps == 21 and np.all(lens == [0, 0, 2, 2, 2, 2])) or ( + result.n_collected_steps == 20 and np.all(lens == [0, 0, 3, 1, 2, 2]) ) batch, _ = c1.buffer.sample(10) c0.buffer.update(c1.buffer) diff --git a/test/base/test_env.py b/test/base/test_env.py index 1a33e861c..2fe13b9b5 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,7 +1,6 @@ import sys import time from collections.abc import Callable -from test.base.env import MoveToRightEnv, NXEnv from typing import Any, Literal import gymnasium as gym @@ -9,6 +8,7 @@ import pytest from gymnasium.spaces.discrete import Discrete +from test.base.env import MoveToRightEnv, NXEnv from tianshou.data import Batch from tianshou.env import ( ContinuousToDiscrete, @@ -213,8 +213,6 @@ def test_attr_unwrapped() -> None: train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")]) train_envs.set_env_attr("test_attribute", 1337) assert train_envs.get_env_attr("test_attribute") == [1337] - # mypy doesn't know but BaseVectorEnv takes the reserved keys in gym.Env (one of which is env) - assert hasattr(train_envs.workers[0].env, "test_attribute") # type: ignore assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute") # type: ignore diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 325b4ce23..e6367f96c 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -1,12 +1,12 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import DDPG from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index e2ce35cd0..dd85c0cab 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -1,6 +1,5 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -9,6 +8,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import NPG from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy diff --git a/test/continuous/test_ppo.py 
b/test/continuous/test_ppo.py index 686fa86d7..fcdd150aa 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -1,6 +1,5 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -8,6 +7,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 5a2538033..9e60e5d43 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -1,6 +1,5 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -8,6 +7,7 @@ import torch.nn as nn from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import REDQ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.redq import REDQPolicy diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 5f68630fc..3609f9d4e 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,12 +1,12 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import SAC, OffPolicyImitationLearning from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.imitation_base import ImitationPolicy diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index df532aed4..8153c8cf0 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -1,12 +1,12 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import TD3 from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index b5e24ad30..b06d761ea 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -1,6 +1,5 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -9,6 +8,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import TRPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index e865b0c7c..a6a69183b 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -1,6 +1,5 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np 
@@ -8,6 +7,7 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import A2C, Algorithm, OffPolicyImitationLearning from tianshou.algorithm.imitation.imitation_base import ImitationPolicy from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index ef766707c..ffed70d43 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -1,10 +1,10 @@ import argparse -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import BDQN from tianshou.algorithm.modelfree.bdqn import BDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 5e0977e6d..6d3fba184 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -1,13 +1,13 @@ import argparse import os import pickle -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import C51 from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index ffd59afa5..6f5ddc5f2 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -1,12 +1,12 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import DiscreteSAC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.discrete_sac import ( diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 092763e5b..e58c52ac8 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,12 +1,12 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import DQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 89b6185f8..48d91bb7a 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,12 +1,12 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import DQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 19fb5768b..20192a06b 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,12 +1,12 @@ import argparse import os -from 
test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import FQF from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.fqf import FQFPolicy diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 0dadeb6bd..e760de975 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,12 +1,12 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import IQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.iqn import IQNPolicy diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index cb1e31c9f..634c1bd09 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -1,12 +1,12 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index f44562b2b..fd829cd88 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,12 +1,12 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import QRDQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 4666d2299..079e594fe 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -1,13 +1,13 @@ import argparse import os import pickle -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import RainbowDQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy diff --git a/test/discrete/test_reinforce.py b/test/discrete/test_reinforce.py index df0e1cf53..93981a557 100644 --- a/test/discrete/test_reinforce.py +++ b/test/discrete/test_reinforce.py @@ -1,6 +1,5 @@ import argparse import os -from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -8,6 +7,7 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest from tianshou.algorithm import Reinforce from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy diff --git a/test/highlevel/test_experiment_builder.py 
b/test/highlevel/test_experiment_builder.py index b02438787..9690c69c3 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -1,7 +1,6 @@ -from test.highlevel.env_factory import ContinuousTestEnvFactory, DiscreteTestEnvFactory - import pytest +from test.highlevel.env_factory import ContinuousTestEnvFactory, DiscreteTestEnvFactory from tianshou.highlevel.config import ( OffPolicyTrainingConfig, OnPolicyTrainingConfig, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index a33442a49..136b7f376 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -2,14 +2,14 @@ import datetime import os import pickle -from test.determinism_test import AlgorithmDeterminismTest -from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest +from test.offline.gather_pendulum_data import expert_file_name, gather_data from tianshou.algorithm import BCQ, Algorithm from tianshou.algorithm.imitation.bcq import BCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory @@ -170,7 +170,7 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # logger t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") - log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' + log_file = f"seed_{args.seed}_{t0}-{args.task.replace('-', '_')}_bcq" log_path = os.path.join(args.logdir, args.task, "bcq", log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index d320b2bfb..18c05ef6d 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -2,14 +2,14 @@ import datetime import os import pickle -from test.determinism_test import AlgorithmDeterminismTest -from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest +from test.offline.gather_pendulum_data import expert_file_name, gather_data from tianshou.algorithm import CQL, Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory @@ -169,7 +169,7 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") - log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql' + log_file = f"seed_{args.seed}_{t0}-{args.task.replace('-', '_')}_cql" log_path = os.path.join(args.logdir, args.task, "cql", log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 5c48ea017..64eeb4247 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -1,14 +1,14 @@ import argparse import os import pickle -from test.determinism_test import AlgorithmDeterminismTest -from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest +from test.offline.gather_cartpole_data import 
expert_file_name, gather_data from tianshou.algorithm import Algorithm, DiscreteBCQ from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 12c0af017..f13cfdb3a 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -1,14 +1,14 @@ import argparse import os import pickle -from test.determinism_test import AlgorithmDeterminismTest -from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest +from test.offline.gather_cartpole_data import expert_file_name, gather_data from tianshou.algorithm import Algorithm, DiscreteCQL from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index b547cc3d5..cdc618cc0 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -1,14 +1,14 @@ import argparse import os import pickle -from test.determinism_test import AlgorithmDeterminismTest -from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest +from test.offline.gather_cartpole_data import expert_file_name, gather_data from tianshou.algorithm import Algorithm, DiscreteCRR from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index a54fb5d3d..296067826 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -1,8 +1,6 @@ import argparse import os import pickle -from test.determinism_test import AlgorithmDeterminismTest -from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np @@ -10,6 +8,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest +from test.offline.gather_pendulum_data import expert_file_name, gather_data from tianshou.algorithm import GAIL, Algorithm from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index dfe4d6b70..528c71b91 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -2,14 +2,14 @@ import datetime import os import pickle -from test.determinism_test import AlgorithmDeterminismTest -from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter +from test.determinism_test import AlgorithmDeterminismTest +from test.offline.gather_pendulum_data import expert_file_name, gather_data from tianshou.algorithm import TD3BC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy @@ -159,7 +159,7 @@ def test_td3_bc(args: argparse.Namespace = get_args(), 
enable_assertions: bool = # logger t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") - log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3_bc' + log_file = f"seed_{args.seed}_{t0}-{args.task.replace('-', '_')}_td3_bc" log_path = os.path.join(args.logdir, args.task, "td3_bc", log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 5f0fb5839..b468c49ab 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -27,11 +27,7 @@ class DQNet(ModuleWithVectorOutput): - """Reference: Human-level control through deep reinforcement learning. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - """ + """Reference: Human-level control through deep reinforcement learning.""" def __init__( self, diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 3c719b1eb..6870c5a5f 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -18,10 +18,10 @@ def configure() -> None: __all__ = [ - "env", - "data", - "utils", "algorithm", - "trainer", + "data", + "env", "exploration", + "trainer", + "utils", ] diff --git a/tianshou/algorithm/algorithm_base.py b/tianshou/algorithm/algorithm_base.py index 50884d646..efeafbbc6 100644 --- a/tianshou/algorithm/algorithm_base.py +++ b/tianshou/algorithm/algorithm_base.py @@ -206,8 +206,7 @@ def __init__( ) if action_scaling and not isinstance(action_space, Box): raise ValueError( - f"action_scaling can only be True when action_space is Box but " - f"got: {action_space}", + f"action_scaling can only be True when action_space is Box but got: {action_space}", ) super().__init__() self.observation_space = observation_space @@ -280,9 +279,9 @@ def map_action( elif self.action_bound_method == "tanh": act = np.tanh(act) if self.action_scaling: - assert ( - np.min(act) >= -1.0 and np.max(act) <= 1.0 - ), f"action scaling only accepts raw action range = [-1, 1], but got: {act}" + assert np.min(act) >= -1.0 and np.max(act) <= 1.0, ( + f"action scaling only accepts raw action range = [-1, 1], but got: {act}" + ) low, high = self.action_space.low, self.action_space.high act = low + (high - low) * (act + 1.0) / 2.0 return act @@ -452,7 +451,7 @@ def __init__( super().__init__() self.policy: TPolicy = policy self.lr_schedulers: list[LRScheduler] = [] - self._optimizers: list["Algorithm.Optimizer"] = [] + self._optimizers: list[Algorithm.Optimizer] = [] """ list of optimizers associated with the algorithm (created via `_create_optimizer`), whose states will be returned when calling `state_dict` and which will be restored diff --git a/tianshou/algorithm/imitation/discrete_bcq.py b/tianshou/algorithm/imitation/discrete_bcq.py index 4a8824e81..7e9698dd0 100644 --- a/tianshou/algorithm/imitation/discrete_bcq.py +++ b/tianshou/algorithm/imitation/discrete_bcq.py @@ -90,12 +90,12 @@ def __init__( eps_inference=eps_inference, ) self.imitator = imitator - assert ( - target_update_freq > 0 - ), f"BCQ needs target_update_freq>0 but got: {target_update_freq}." - assert ( - 0.0 <= unlikely_action_threshold < 1.0 - ), f"unlikely_action_threshold should be in [0, 1) but got: {unlikely_action_threshold}" + assert target_update_freq > 0, ( + f"BCQ needs target_update_freq>0 but got: {target_update_freq}." 
+ ) + assert 0.0 <= unlikely_action_threshold < 1.0, ( + f"unlikely_action_threshold should be in [0, 1) but got: {unlikely_action_threshold}" + ) if unlikely_action_threshold > 0: self._log_tau = math.log(unlikely_action_threshold) else: @@ -199,9 +199,9 @@ def __init__( self.optim = self._create_optimizer(self.policy, optim) assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" self.gamma = gamma - assert ( - n_step_return_horizon > 0 - ), f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" + assert n_step_return_horizon > 0, ( + f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" + ) self.n_step = n_step_return_horizon self._target = target_update_freq > 0 self.freq = target_update_freq diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index 530de6ef0..57ef0645f 100644 --- a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -232,9 +232,9 @@ def __init__( LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" self.gamma = gamma - assert ( - n_step_return_horizon > 0 - ), f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" + assert n_step_return_horizon > 0, ( + f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" + ) self.n_step = n_step_return_horizon self.target_update_freq = target_update_freq # TODO: 1 would be a more reasonable initialization given how it is incremented diff --git a/tianshou/algorithm/modelfree/iqn.py b/tianshou/algorithm/modelfree/iqn.py index 4cd69b80e..0ca8c263b 100644 --- a/tianshou/algorithm/modelfree/iqn.py +++ b/tianshou/algorithm/modelfree/iqn.py @@ -52,12 +52,12 @@ def __init__( """ assert isinstance(action_space, gym.spaces.Discrete) assert sample_size > 1, f"sample_size should be greater than 1 but got: {sample_size}" - assert ( - online_sample_size > 1 - ), f"online_sample_size should be greater than 1 but got: {online_sample_size}" - assert ( - target_sample_size > 1 - ), f"target_sample_size should be greater than 1 but got: {target_sample_size}" + assert online_sample_size > 1, ( + f"online_sample_size should be greater than 1 but got: {online_sample_size}" + ) + assert target_sample_size > 1, ( + f"target_sample_size should be greater than 1 but got: {target_sample_size}" + ) super().__init__( model=model, action_space=action_space, diff --git a/tianshou/algorithm/modelfree/ppo.py b/tianshou/algorithm/modelfree/ppo.py index ede1d3418..8749a86a0 100644 --- a/tianshou/algorithm/modelfree/ppo.py +++ b/tianshou/algorithm/modelfree/ppo.py @@ -121,9 +121,9 @@ def __init__( Best used in environments where the relative ordering of actions is more important than the absolute scale of returns. """ - assert ( - dual_clip is None or dual_clip > 1.0 - ), f"Dual-clip PPO parameter should greater than 1.0 but got {dual_clip}" + assert dual_clip is None or dual_clip > 1.0, ( + f"Dual-clip PPO parameter should greater than 1.0 but got {dual_clip}" + ) super().__init__( policy=policy, diff --git a/tianshou/algorithm/multiagent/marl.py b/tianshou/algorithm/multiagent/marl.py index 1c30a1cbc..1913a1e76 100644 --- a/tianshou/algorithm/multiagent/marl.py +++ b/tianshou/algorithm/multiagent/marl.py @@ -63,15 +63,12 @@ class MAPRolloutBatchProtocol(RolloutBatchProtocol, Protocol): # TODO: this might not be entirely correct. 
# The whole MAP data processing pipeline needs more documentation and possibly some refactoring @overload - def __getitem__(self, index: str) -> RolloutBatchProtocol: - ... + def __getitem__(self, index: str) -> RolloutBatchProtocol: ... @overload - def __getitem__(self, index: IndexType) -> Self: - ... + def __getitem__(self, index: IndexType) -> Self: ... - def __getitem__(self, index: str | IndexType) -> Any: - ... + def __getitem__(self, index: str | IndexType) -> Any: ... class MultiAgentPolicy(Policy): diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 7e1d5298d..b19f3a7e7 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -33,28 +33,28 @@ ) __all__ = [ + "AsyncCollector", + "BaseCollector", "Batch", - "to_numpy", - "to_torch", - "to_torch_as", - "SegmentTree", - "ReplayBuffer", - "PrioritizedReplayBuffer", - "HERReplayBuffer", - "ReplayBufferManager", - "PrioritizedReplayBufferManager", - "HERReplayBufferManager", - "VectorReplayBuffer", - "PrioritizedVectorReplayBuffer", - "HERVectorReplayBuffer", "CachedReplayBuffer", - "Collector", "CollectStats", "CollectStatsBase", - "AsyncCollector", + "Collector", "EpochStats", + "HERReplayBuffer", + "HERReplayBufferManager", + "HERVectorReplayBuffer", "InfoStats", + "PrioritizedReplayBuffer", + "PrioritizedReplayBufferManager", + "PrioritizedVectorReplayBuffer", + "ReplayBuffer", + "ReplayBufferManager", + "SegmentTree", "SequenceSummaryStats", "TimingStats", - "BaseCollector", + "VectorReplayBuffer", + "to_numpy", + "to_torch", + "to_torch_as", ] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 8eb88e939..ffdcf0efe 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -582,24 +582,21 @@ def dropnull(self) -> Self: def apply_values_transform( self, values_transform: Callable[[np.ndarray | torch.Tensor], Any], - ) -> Self: - ... + ) -> Self: ... @overload def apply_values_transform( self, values_transform: Callable, inplace: Literal[True], - ) -> None: - ... + ) -> None: ... @overload def apply_values_transform( self, values_transform: Callable[[np.ndarray | torch.Tensor], Any], inplace: Literal[False], - ) -> Self: - ... + ) -> Self: ... def apply_values_transform( self, @@ -713,12 +710,10 @@ def __setstate__(self, state: dict[str, Any]) -> None: self.__init__(**state) # type: ignore @overload - def __getitem__(self, index: str) -> Any: - ... + def __getitem__(self, index: str) -> Any: ... @overload - def __getitem__(self, index: IndexType) -> Self: - ... + def __getitem__(self, index: IndexType) -> Self: ... def __getitem__(self, index: str | IndexType) -> Any: """Returns either the value of a key or a sliced Batch object.""" @@ -902,8 +897,7 @@ def arr_to_torch(arr: TArr) -> TArr: # TODO: simplify if ( - dtype is not None - and arr.dtype != dtype + (dtype is not None and arr.dtype != dtype) or arr.device.type != device.type or device.index != arr.device.index ): @@ -1228,24 +1222,21 @@ def split( def apply_values_transform( self, values_transform: Callable, - ) -> Self: - ... + ) -> Self: ... @overload def apply_values_transform( self, values_transform: Callable, inplace: Literal[True], - ) -> None: - ... + ) -> None: ... @overload def apply_values_transform( self, values_transform: Callable, inplace: Literal[False], - ) -> Self: - ... + ) -> Self: ... 
def apply_values_transform( self, diff --git a/tianshou/data/buffer/buffer_base.py b/tianshou/data/buffer/buffer_base.py index 72c7af5bb..fe07f7834 100644 --- a/tianshou/data/buffer/buffer_base.py +++ b/tianshou/data/buffer/buffer_base.py @@ -28,8 +28,8 @@ class ReplayBuffer: ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style. - For the example usage of ReplayBuffer, please check out Section Buffer in - :doc:`/01_tutorials/01_concepts`. + For the example usage of ReplayBuffer, please check out Section "Buffer" in + :doc:`/01_tutorials/02_internals`. :param size: the maximum size of replay buffer. :param stack_num: the frame-stack sampling argument, should be greater than or @@ -141,6 +141,7 @@ def _get_start_stop_tuples_for_edge_crossing_interval( The buffer sliced from 4 to 5 and then from 0 to 2 will contain the transitions corresponding to the provided start and stop values. + """ if stop >= start: raise ValueError( @@ -204,6 +205,7 @@ def get_buffer_indices(self, start: int, stop: int) -> np.ndarray: :param start: The start index of the interval. :param stop: The stop index of the interval. :return: The indices of the transitions in the buffer between start and stop. + """ start_left_edge = np.searchsorted(self.subbuffer_edges, start, side="right") - 1 stop_left_edge = np.searchsorted(self.subbuffer_edges, stop - 1, side="right") - 1 @@ -215,9 +217,12 @@ def get_buffer_indices(self, start: int, stop: int) -> np.ndarray: if stop >= start: return np.arange(start, stop, dtype=int) else: - (start, upper_edge), ( - lower_edge, - stop, + ( + (start, upper_edge), + ( + lower_edge, + stop, + ), ) = self._get_start_stop_tuples_for_edge_crossing_interval( start, stop, diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 640025e25..de3144786 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1277,9 +1277,9 @@ def _collect( np.ndarray | Batch, self._current_hidden_state_in_all_envs_EH, ) - self._current_hidden_state_in_all_envs_EH[ - ready_env_ids_R - ] = collect_batch_R.hidden_state + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = ( + collect_batch_R.hidden_state + ) else: self._current_hidden_state_in_all_envs_EH = collect_batch_R.hidden_state @@ -1420,8 +1420,7 @@ def __call__( self, action_batch: CollectActionBatchProtocol, rollout_batch: RolloutBatchProtocol, - ) -> None: - ... + ) -> None: ... class StepHookAddActionDistribution(StepHook): @@ -1473,8 +1472,7 @@ class EpisodeRolloutHook(EpisodeRolloutHookProtocol, ABC): """ @abstractmethod - def __call__(self, episode_batch: EpisodeBatchProtocol) -> dict[str, np.ndarray] | None: - ... + def __call__(self, episode_batch: EpisodeBatchProtocol) -> dict[str, np.ndarray] | None: ... 
class EpisodeRolloutHookMCReturn(EpisodeRolloutHook): diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index 049ccf439..698715cd4 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -17,14 +17,14 @@ __all__ = [ "BaseVectorEnv", - "DummyVectorEnv", - "SubprocVectorEnv", - "ShmemVectorEnv", - "RayVectorEnv", - "VectorEnvWrapper", - "VectorEnvNormObs", - "PettingZooEnv", "ContinuousToDiscrete", + "DummyVectorEnv", "MultiDiscreteToDiscrete", + "PettingZooEnv", + "RayVectorEnv", + "ShmemVectorEnv", + "SubprocVectorEnv", "TruncatedAsTerminated", + "VectorEnvNormObs", + "VectorEnvWrapper", ] diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index 79b862c4c..df27fc5d8 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -58,11 +58,7 @@ def forward( class DQNet(ActionReprNetWithVectorOutput[Any]): - """Reference: Human-level control through deep reinforcement learning. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - """ + """Reference: Human-level control through deep reinforcement learning.""" def __init__( self, @@ -127,11 +123,7 @@ def forward( class C51Net(DQNet): - """Reference: A distributional perspective on reinforcement learning. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - """ + """Reference: A distributional perspective on reinforcement learning.""" def __init__( self, @@ -160,11 +152,7 @@ def forward( class RainbowNet(DQNet): - """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - """ + """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.""" def __init__( self, @@ -221,11 +209,7 @@ def forward( class QRDQNet(DQNet): - """Reference: Distributional Reinforcement Learning with Quantile Regression. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - """ + """Reference: Distributional Reinforcement Learning with Quantile Regression.""" def __init__( self, diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py index 1db4a286c..26e478c1f 100644 --- a/tianshou/env/gym_wrappers.py +++ b/tianshou/env/gym_wrappers.py @@ -26,7 +26,7 @@ def __init__(self, env: gym.Env, action_per_dim: int | list[int]) -> None: dtype=object, ) - def action(self, act: np.ndarray) -> np.ndarray: # type: ignore + def action(self, act: np.ndarray) -> np.ndarray: # modify act assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}." 
if len(act.shape) == 1: @@ -50,7 +50,7 @@ def __init__(self, env: gym.Env) -> None: self.bases[i] = self.bases[i - 1] * nvec[-i] self.action_space = gym.spaces.Discrete(np.prod(nvec)) - def action(self, act: np.ndarray) -> np.ndarray: # type: ignore + def action(self, act: np.ndarray) -> np.ndarray: converted_act = [] for b in np.flip(self.bases): converted_act.append(act // b) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index e9309f9ec..cf07705ab 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -90,13 +90,13 @@ def __init__( self.env_num = len(env_fns) self.wait_num = wait_num or len(env_fns) - assert ( - 1 <= self.wait_num <= len(env_fns) - ), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}" + assert 1 <= self.wait_num <= len(env_fns), ( + f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}" + ) self.timeout = timeout - assert ( - self.timeout is None or self.timeout > 0 - ), f"timeout is {timeout}, it should be positive if provided!" + assert self.timeout is None or self.timeout > 0, ( + f"timeout is {timeout}, it should be positive if provided!" + ) self.is_async = self.wait_num != len(env_fns) or timeout is not None self.waiting_conn: list[EnvWorker] = [] # environments in self.ready_id is actually ready @@ -109,9 +109,9 @@ def __init__( self.is_closed = False def _assert_is_not_closed(self) -> None: - assert ( - not self.is_closed - ), f"Methods of {self.__class__.__name__} cannot be called after close." + assert not self.is_closed, ( + f"Methods of {self.__class__.__name__} cannot be called after close." + ) def __len__(self) -> int: """Return len(self), which is the number of environments.""" @@ -185,9 +185,9 @@ def _wrap_id( def _assert_id(self, id: list[int] | np.ndarray) -> None: for i in id: - assert ( - i not in self.waiting_id - ), f"Cannot interact with environment {i} which is stepping now." + assert i not in self.waiting_id, ( + f"Cannot interact with environment {i} which is stepping now." + ) assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}." 
# TODO: for now, has to be kept in sync with reset in EnvPoolMixin diff --git a/tianshou/env/worker/__init__.py b/tianshou/env/worker/__init__.py index 5e3d21235..dd34b7f04 100644 --- a/tianshou/env/worker/__init__.py +++ b/tianshou/env/worker/__init__.py @@ -6,8 +6,8 @@ from tianshou.env.worker.subproc import SubprocEnvWorker __all__ = [ - "EnvWorker", "DummyEnvWorker", - "SubprocEnvWorker", + "EnvWorker", "RayEnvWorker", + "SubprocEnvWorker", ] diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 0ed0b0319..35127b714 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -15,7 +15,7 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: super().__init__(env_fn) def get_env_attr(self, key: str) -> Any: - return getattr(self.env, key) + return getattr(self.env.unwrapped, key) def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env.unwrapped, key, value) diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 76b842220..f6ba0dbb0 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code=unused-ignore import contextlib from collections.abc import Callable from typing import Any @@ -12,9 +13,6 @@ import ray -# mypy: disable-error-code="unused-ignore" - - class _SetAttrWrapper(gym.Wrapper): def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env.unwrapped, key, value) @@ -34,15 +32,15 @@ def __init__( super().__init__(env_fn) def get_env_attr(self, key: str) -> Any: - return ray.get(self.env.get_env_attr.remote(key)) + return ray.get(self.env.get_env_attr.remote(key)) # type: ignore def set_env_attr(self, key: str, value: Any) -> None: - ray.get(self.env.set_env_attr.remote(key, value)) + ray.get(self.env.set_env_attr.remote(key, value)) # type: ignore def reset(self, **kwargs: Any) -> Any: if "seed" in kwargs: super().seed(kwargs["seed"]) - return ray.get(self.env.reset.remote(**kwargs)) + return ray.get(self.env.reset.remote(**kwargs)) # type: ignore @staticmethod def wait( # type: ignore @@ -51,15 +49,15 @@ def wait( # type: ignore timeout: float | None = None, ) -> list["RayEnvWorker"]: results = [x.result for x in workers] - ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) + ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) # type: ignore return [workers[results.index(result)] for result in ready_results] def send(self, action: np.ndarray | None, **kwargs: Any) -> None: # self.result is actually a handle if action is None: - self.result = self.env.reset.remote(**kwargs) + self.result = self.env.reset.remote(**kwargs) # type: ignore else: - self.result = self.env.step.remote(action) + self.result = self.env.step.remote(action) # type: ignore def recv(self) -> gym_new_venv_step_type: return ray.get(self.result) # type: ignore @@ -67,13 +65,13 @@ def recv(self) -> gym_new_venv_step_type: def seed(self, seed: int | None = None) -> list[int] | None: super().seed(seed) try: - return ray.get(self.env.seed.remote(seed)) + return ray.get(self.env.seed.remote(seed)) # type: ignore except (AttributeError, NotImplementedError): - self.env.reset.remote(seed=seed) + self.env.reset.remote(seed=seed) # type: ignore return None def render(self, **kwargs: Any) -> Any: - return ray.get(self.env.render.remote(**kwargs)) + return ray.get(self.env.render.remote(**kwargs)) # type: ignore def close_env(self) -> None: - ray.get(self.env.close.remote()) + ray.get(self.env.close.remote()) # type: 
diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py
index a797a9049..7857037bc 100644
--- a/tianshou/exploration/random.py
+++ b/tianshou/exploration/random.py
@@ -73,7 +73,7 @@ def __call__(self, size: Sequence[int], mu: float | None = None) -> np.ndarray:
 
         Return an numpy array which size is equal to ``size``.
         """
-        if self._x is None or isinstance(self._x, np.ndarray) and self._x.shape != size:
+        if self._x is None or (isinstance(self._x, np.ndarray) and self._x.shape != size):
             self._x = 0.0
         if mu is None:
             mu = self._mu
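
The added parentheses in `__call__` only spell out what Python already does (`and` binds more tightly than `or`), so behavior is unchanged; the edit just makes the reset condition easier to read. A quick check::

    a, b, c = True, False, False
    assert (a or b and c) == (a or (b and c))
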
diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py
index deaa2cd35..109b685b5 100644
--- a/tianshou/highlevel/algorithm.py
+++ b/tianshou/highlevel/algorithm.py
@@ -34,8 +34,8 @@
 from tianshou.algorithm.modelfree.redq import REDQPolicy
 from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy
 from tianshou.algorithm.modelfree.sac import SACPolicy
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
-from tianshou.data.collector import BaseCollector, CollectStats
+from tianshou.data import ReplayBuffer, VectorReplayBuffer
+from tianshou.data.collector import BaseCollector
 from tianshou.highlevel.config import (
     OffPolicyTrainingConfig,
     OnPolicyTrainingConfig,
@@ -69,6 +69,10 @@
     TRPOParams,
 )
 from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory
+from tianshou.highlevel.params.collector import (
+    CollectorFactory,
+    CollectorFactoryDefault,
+)
 from tianshou.highlevel.params.optim import OptimizerFactoryFactory
 from tianshou.highlevel.persistence import PolicyPersistence
 from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
@@ -111,18 +115,25 @@ def __init__(self, training_config: TTrainingConfig, optim_factory: OptimizerFac
         self.optim_factory = optim_factory
         self.algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None
         self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
+        self.collector_factory: CollectorFactory = CollectorFactoryDefault()
+
+    def set_collector_factory(self, collector_factory: CollectorFactory) -> None:
+        self.collector_factory = collector_factory
 
-    def create_train_test_collector(
+    def create_train_test_collectors(
         self,
-        policy: Algorithm,
+        algorithm: Algorithm,
         envs: Environments,
         reset_collectors: bool = True,
     ) -> tuple[BaseCollector, BaseCollector]:
-        """:param policy:
-        :param envs:
+        """
+        Creates the collectors for training and test environments.
+
+        :param algorithm: the algorithm
+        :param envs: the environments wrapper
         :param reset_collectors: Whether to reset the collectors before returning them.
             Setting to True means that the envs will be reset as well.
-        :return:
+        :return: a tuple of (train_collector, test_collector)
         """
         buffer_size = self.training_config.buffer_size
         train_envs = envs.train_envs
@@ -142,13 +153,13 @@ def create_train_test_collector(
             save_only_last_obs=self.training_config.replay_buffer_save_only_last_obs,
             ignore_obs_next=self.training_config.replay_buffer_ignore_obs_next,
         )
-        train_collector = Collector[CollectStats](
-            policy,
+        train_collector = self.collector_factory.create_collector(
+            algorithm,
             train_envs,
             buffer,
             exploration_noise=True,
         )
-        test_collector = Collector[CollectStats](policy, envs.test_envs)
+        test_collector = self.collector_factory.create_collector(algorithm, envs.test_envs)
         if reset_collectors:
             train_collector.reset()
             test_collector.reset()
@@ -483,7 +494,6 @@ def _create_policy(
         action_space: gymnasium.spaces.Discrete,
         observation_space: gymnasium.spaces.Space,
     ) -> Policy:
-        pass
         return self._create_policy_from_args(
             IQNPolicy,
             params,
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 8a0d3d7e9..6f6b7c694 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -27,6 +27,8 @@
 from pprint import pformat
 from typing import TYPE_CHECKING, Any, Generic, Self, Union, cast
 
+from tianshou.highlevel.params.collector import CollectorFactory
+
 if TYPE_CHECKING:
     from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher
 
@@ -340,7 +342,7 @@ def create_experiment_world(
         (
             train_collector,
             test_collector,
-        ) = self.algorithm_factory.create_train_test_collector(
+        ) = self.algorithm_factory.create_train_test_collectors(
             policy,
             envs,
             reset_collectors=reset_collectors,
@@ -519,6 +521,7 @@ def __init__(
         self._logger_factory: LoggerFactory | None = None
         self._optim_factory: OptimizerFactoryFactory | None = None
         self._algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None
+        self._collector_factory: CollectorFactory | None = None
         self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
         self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag()
@@ -623,6 +626,15 @@ def with_name(
         self._name = name
         return self
 
+    def with_collector_factory(self, collector_factory: CollectorFactory) -> Self:
+        """Allows customizing the collector factory to use.
+
+        :param collector_factory: the factory to use for the creation of collectors
+        :return: the builder
+        """
+        self._collector_factory = collector_factory
+        return self
+
     @abstractmethod
     def _create_algorithm_factory(self) -> AlgorithmFactory:
         pass
@@ -642,6 +654,8 @@ def build(self) -> Experiment:
         algorithm_factory.set_trainer_callbacks(self._trainer_callbacks)
         if self._algorithm_wrapper_factory:
             algorithm_factory.set_policy_wrapper_factory(self._algorithm_wrapper_factory)
+        if self._collector_factory:
+            algorithm_factory.set_collector_factory(self._collector_factory)
         experiment: Experiment = Experiment(
             config=self._config,
             env_factory=self._env_factory,
diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py
index 83ae08878..3c917b1fb 100644
--- a/tianshou/highlevel/module/actor.py
+++ b/tianshou/highlevel/module/actor.py
@@ -103,7 +103,11 @@ def __init__(
 
     def _create_factory(self, envs: Environments) -> ActorFactory:
         env_type = envs.get_type()
-        factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet
+        factory: (
+            ActorFactoryContinuousDeterministicNet
+            | ActorFactoryContinuousGaussianNet
+            | ActorFactoryDiscreteNet
+        )
         if env_type == EnvType.CONTINUOUS:
             match self.continuous_actor_type:
                 case ContinuousActorType.GAUSSIAN:
@@ -274,7 +278,7 @@ def __init__(self, actor_factory: ActorFactory):
 
     def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
         actor = self.actor_factory.create_module(envs, device)
-        assert isinstance(
-            actor, ModuleWithVectorOutput
-        ), "Actor factory must produce an actor with known vector output dimension"
+        assert isinstance(actor, ModuleWithVectorOutput), (
+            "Actor factory must produce an actor with known vector output dimension"
+        )
         return IntermediateModule(actor, actor.get_output_dim())
diff --git a/tianshou/highlevel/params/collector.py b/tianshou/highlevel/params/collector.py
new file mode 100644
index 000000000..a9bf667f7
--- /dev/null
+++ b/tianshou/highlevel/params/collector.py
@@ -0,0 +1,39 @@
+from abc import ABC, abstractmethod
+
+from tianshou.algorithm import Algorithm
+from tianshou.data import BaseCollector, Collector, ReplayBuffer
+from tianshou.env import BaseVectorEnv
+
+
+class CollectorFactory(ABC):
+    @abstractmethod
+    def create_collector(
+        self,
+        algorithm: Algorithm,
+        vector_env: BaseVectorEnv,
+        buffer: ReplayBuffer | None = None,
+        exploration_noise: bool = False,
+    ) -> BaseCollector:
+        """
+        Creates a collector for the given algorithm and vectorized environment.
+
+        :param algorithm: the algorithm
+        :param vector_env: the vectorized environment
+        :param buffer: the replay buffer to be used by the collector;
+            if None, a new buffer will be created with default parameters
+        :param exploration_noise: whether actions shall be modified using the policy's exploration noise
+        :return: the collector
+        """
+
+
+class CollectorFactoryDefault(CollectorFactory):
+    def create_collector(
+        self,
+        algorithm: Algorithm,
+        vector_env: BaseVectorEnv,
+        buffer: ReplayBuffer | None = None,
+        exploration_noise: bool = False,
+    ) -> BaseCollector:
+        return Collector(
+            algorithm.policy, vector_env, buffer=buffer, exploration_noise=exploration_noise
+        )
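
To show how the new extension point composes: a custom factory implements `create_collector` and is registered via the `with_collector_factory` builder method added above. The sketch below swaps in Tianshou's `AsyncCollector`; the class name `CollectorFactoryAsync` is made up, and it is assumed (not verified here) that `AsyncCollector` accepts the same constructor arguments as `Collector`::

    from tianshou.algorithm import Algorithm
    from tianshou.data import AsyncCollector, BaseCollector, ReplayBuffer
    from tianshou.env import BaseVectorEnv
    from tianshou.highlevel.params.collector import CollectorFactory


    class CollectorFactoryAsync(CollectorFactory):
        def create_collector(
            self,
            algorithm: Algorithm,
            vector_env: BaseVectorEnv,
            buffer: ReplayBuffer | None = None,
            exploration_noise: bool = False,
        ) -> BaseCollector:
            # same call shape as CollectorFactoryDefault, different collector class
            return AsyncCollector(
                algorithm.policy, vector_env, buffer=buffer, exploration_noise=exploration_noise
            )

    # with an experiment builder (construction elided):
    # experiment = builder.with_collector_factory(CollectorFactoryAsync()).build()
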
diff --git a/tianshou/highlevel/params/env_param.py b/tianshou/highlevel/params/env_param.py
index 2b444bbd6..143164ad6 100644
--- a/tianshou/highlevel/params/env_param.py
+++ b/tianshou/highlevel/params/env_param.py
@@ -1,4 +1,5 @@
 """Factories for the generation of environment-dependent parameters."""
+
 from abc import ABC, abstractmethod
 from typing import Generic, TypeVar
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
index a23841b36..4f79f7557 100644
--- a/tianshou/utils/__init__.py
+++ b/tianshou/utils/__init__.py
@@ -8,13 +8,13 @@
 from tianshou.utils.warning import deprecation
 
 __all__ = [
+    "BaseLogger",
+    "DummyTqdm",
+    "LazyLogger",
     "MovAvg",
     "RunningMeanStd",
-    "tqdm_config",
-    "deprecation",
-    "DummyTqdm",
-    "BaseLogger",
     "TensorboardLogger",
-    "LazyLogger",
     "WandbLogger",
+    "deprecation",
+    "tqdm_config",
 ]
diff --git a/tianshou/utils/conversion.py b/tianshou/utils/conversion.py
index bae2db331..16f532cb3 100644
--- a/tianshou/utils/conversion.py
+++ b/tianshou/utils/conversion.py
@@ -4,18 +4,15 @@
 
 @overload
-def to_optional_float(x: torch.Tensor) -> float:
-    ...
+def to_optional_float(x: torch.Tensor) -> float: ...
 
 
 @overload
-def to_optional_float(x: float) -> float:
-    ...
+def to_optional_float(x: float) -> float: ...
 
 
 @overload
-def to_optional_float(x: None) -> None:
-    ...
+def to_optional_float(x: None) -> None: ...
 
 
 def to_optional_float(x: torch.Tensor | float | None) -> float | None:
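
Collapsing the `...` stub bodies onto one line is purely cosmetic (it matches the style newer formatters emit); the `@overload` declarations still give type-checkers an exact input-to-output mapping, while only the final implementation runs. What a caller sees, assuming the implementation converts tensors via `.item()`::

    import torch
    from tianshou.utils.conversion import to_optional_float

    f = to_optional_float(torch.tensor(1.0))  # checker infers float
    n = to_optional_float(None)               # checker infers None
    assert f == 1.0 and n is None
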
diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py
index dba11b555..1fc5e9d04 100644
--- a/tianshou/utils/logger/tensorboard.py
+++ b/tianshou/utils/logger/tensorboard.py
@@ -170,6 +170,7 @@ def add_value_to_innermost_nested_dict(
         >>> add_value_to_innermost_nested_dict(data_dict, "a/b/c", 1)
         >>> data_dict
         {"a": {"b": {"c": 1}}}
+
     """
     keys = key_string.split("/")
diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py
index f92c3fd2c..f5a6486eb 100644
--- a/tianshou/utils/logger/wandb.py
+++ b/tianshou/utils/logger/wandb.py
@@ -1,5 +1,4 @@
 import argparse
-import contextlib
 import logging
 import os
 from collections.abc import Callable
@@ -9,9 +8,6 @@
 from tianshou.utils import BaseLogger, TensorboardLogger
 from tianshou.utils.logger.logger_base import VALID_LOG_VALS_TYPE, TRestoredData
 
-with contextlib.suppress(ImportError):
-    import wandb
-
 log = logging.getLogger(__name__)
 
@@ -62,6 +58,8 @@ def __init__(
         disable_stats: bool = False,
         log_dir: str | None = None,
     ) -> None:
+        import wandb
+
         super().__init__(train_interval, test_interval, update_interval, info_interval)
         self.last_save_step = -1
         self.save_interval = save_interval
@@ -83,13 +81,12 @@ def __init__(
                 # monitor_gym=monitor_gym,  # currently disabled until gymnasium version is bumped to >1.0.0 https://github.com/wandb/wandb/issues/7047
                 dir=log_dir,
                 config=config,  # type: ignore
-                settings=wandb.Settings(_disable_stats=disable_stats),
+                settings=wandb.Settings(x_disable_stats=disable_stats),
             )
             if not wandb.run
             else wandb.run
         )
-        # TODO: don't access private attribute!
-        self.wandb_run._label(repo="tianshou")  # type: ignore
+        self.wandb_run._label(repo="tianshou")
 
         self.tensorboard_logger: TensorboardLogger | None = None
         self.writer: SummaryWriter | None = None
@@ -141,12 +138,14 @@ def save_data(
         :param function save_checkpoint_fn: a hook defined by user, see trainer documentation for detail.
         """
+        import wandb
+
         if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
             self.last_save_step = epoch
             checkpoint_path = save_checkpoint_fn(epoch, env_step, update_step)
             checkpoint_artifact = wandb.Artifact(
-                "run_" + self.wandb_run.id + "_checkpoint",  # type: ignore
+                "run_" + self.wandb_run.id + "_checkpoint",
                 type="model",
                 metadata={
                     "save/epoch": epoch,
@@ -156,11 +155,11 @@ def save_data(
                 },
             )
             checkpoint_artifact.add_file(str(checkpoint_path))
-            self.wandb_run.log_artifact(checkpoint_artifact)  # type: ignore
+            self.wandb_run.log_artifact(checkpoint_artifact)
 
     def restore_data(self) -> tuple[int, int, int]:
-        checkpoint_artifact = self.wandb_run.use_artifact(  # type: ignore
-            f"run_{self.wandb_run.id}_checkpoint:latest",  # type: ignore
+        checkpoint_artifact = self.wandb_run.use_artifact(
+            f"run_{self.wandb_run.id}_checkpoint:latest",
         )
         assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"
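
Replacing the module-level `contextlib.suppress(ImportError)` import with deferred `import wandb` statements inside the methods means `tianshou.utils.logger.wandb` can be imported without `wandb` installed, and a missing dependency now raises `ImportError` at `WandbLogger(...)` construction rather than a confusing `NameError` later. The pattern in isolation (the helper name is made up)::

    def _require_wandb():
        # deferred import: evaluated only when W&B functionality is first used
        import wandb

        return wandb
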
- """ + """Simple Recurrent network based on LSTM.""" def __init__( self, diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 83cddd049..e6bfe94a0 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -40,9 +40,6 @@ class ContinuousActorDeterministic(AbstractContinuousActorDeterministic): :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. :param max_action: the scale for the final action. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. """ def __init__( @@ -113,9 +110,6 @@ class ContinuousCritic(AbstractContinuousCritic): :param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before concatenating with the action) - and without the observations being modified in any way beforehand. This allows the actor's preprocessing network to be reused for the critic. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. """ def __init__( @@ -188,9 +182,6 @@ class ContinuousActorProbabilistic(AbstractContinuousActorProbabilistic): :param unbounded: whether to apply tanh activation on final logits. :param conditioned_sigma: True when sigma is calculated from the input, False when sigma is an independent parameter. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. """ def __init__( @@ -248,11 +239,7 @@ def forward( class RecurrentActorProb(nn.Module): - """Recurrent version of ActorProb. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - """ + """Recurrent version of ActorProb.""" def __init__( self, @@ -336,11 +323,7 @@ def forward( class RecurrentCritic(nn.Module): - """Recurrent version of Critic. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - """ + """Recurrent version of Critic.""" def __init__( self, @@ -403,9 +386,6 @@ class Perturbation(nn.Module): :param device: which device to create this model on. :param phi: max perturbation parameter for BCQ. - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - .. seealso:: You can refer to `examples/offline/offline_bcq.py` to see how to use it. @@ -447,9 +427,6 @@ class VAE(nn.Module): :param max_action: the maximum value of each dimension of action. :param device: which device to create this model on. - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - .. seealso:: You can refer to `examples/offline/offline_bcq.py` to see how to use it. diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 8f2e11f1e..53f68b033 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -98,9 +98,6 @@ class DiscreteCritic(ModuleWithVectorOutput): preprocess_net. Default to empty sequence (where the MLP now contains only a single linear layer). :param last_size: the output dimension of Critic network. Default to 1. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`.. """ def __init__( @@ -205,7 +202,7 @@ def forward( # type: ignore **kwargs: Any, ) -> tuple[Any, torch.Tensor]: r"""Mapping: s -> Q(s, \*).""" - logits, hidden = self.preprocess(obs, state=kwargs.get("state", None)) + logits, hidden = self.preprocess(obs, state=kwargs.get("state")) # Sample fractions. 
         batch_size = logits.size(0)
         taus = torch.rand(batch_size, sample_size, dtype=logits.dtype, device=logits.device)
@@ -299,7 +296,7 @@ def forward(  # type: ignore
         **kwargs: Any,
     ) -> tuple[Any, torch.Tensor]:
         r"""Mapping: s -> Q(s, \*)."""
-        logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
+        logits, hidden = self.preprocess(obs, state=kwargs.get("state"))
         # Propose fractions
         if fractions is None:
             taus, tau_hats, entropies = propose_model(logits.detach())
diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py
index 6273a41b4..2e6c279da 100644
--- a/tianshou/utils/torch_utils.py
+++ b/tianshou/utils/torch_utils.py
@@ -45,16 +45,14 @@ def policy_within_training_step(
 
 @overload
-def create_uniform_action_dist(action_space: spaces.Box, batch_size: int = 1) -> dist.Uniform:
-    ...
+def create_uniform_action_dist(action_space: spaces.Box, batch_size: int = 1) -> dist.Uniform: ...
 
 
 @overload
 def create_uniform_action_dist(
     action_space: spaces.Discrete,
     batch_size: int = 1,
-) -> dist.Categorical:
-    ...
+) -> dist.Categorical: ...
 
 
 def create_uniform_action_dist(