diff --git a/README.md b/README.md index 69c152859..d6a9300b4 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,13 @@ [![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) [![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock) [![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master) [![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/master/) [![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) [![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues) [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) [![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE) +> ⚠️ **Transition to Gymnasium**: The maintainers of OpenAI Gym have recently released [Gymnasium](https://github.com/Farama-Foundation/Gymnasium), +> which is where future maintenance of OpenAI Gym will take place. +> Tianshou has transitioned to using Gymnasium environments internally. You can still use OpenAI Gym environments with +> Tianshou vector environments, but they will be wrapped in a compatibility layer, which could be a source of issues. +> We recommend that you update your environment code to Gymnasium. If you want to continue using OpenAI Gym with +> Tianshou, you need to manually install Gym and [Shimmy](https://github.com/Farama-Foundation/Shimmy) (the compatibility layer). + **Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or run slowly, Tianshou provides a fast, modularized framework and pythonic API for building deep reinforcement learning agents with the fewest possible lines of code. The supported interface algorithms currently include: - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) @@ -105,21 +112,21 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma ### Comprehensive Functionality -| RL Platform | GitHub Stars | # of Alg. 
(1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend | -| ------------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------ | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- | -| [Baselines](https://github.com/openai/baselines) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 | -| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 | -| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [![GitHub stars](https://img.shields.io/github/stars/DLR-RM/stable-baselines3)](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 (3) | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :heavy_check_mark: | PyTorch | -| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch | -| [SpinningUp](https://github.com/openai/spinningup) | [![GitHub stars](https://img.shields.io/github/stars/openai/spinningup)](https://github.com/openai/spinningupstargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :x: | PyTorch | -| [Dopamine](https://github.com/google/dopamine) | [![GitHub stars](https://img.shields.io/github/stars/google/dopamine)](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX | -| [ACME](https://github.com/deepmind/acme) | [![GitHub stars](https://img.shields.io/github/stars/deepmind/acme)](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX | -| [keras-rl](https://github.com/keras-rl/keras-rl) | [![GitHub stars](https://img.shields.io/github/stars/keras-rl/keras-rl)](https://github.com/keras-rl/keras-rlstargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras | -| [rlpyt](https://github.com/astooke/rlpyt) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | -| [ChainerRL](https://github.com/chainer/chainerrl) | [![GitHub stars](https://img.shields.io/github/stars/chainer/chainerrl)](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer | -| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [![GitHub stars](https://img.shields.io/github/stars/alex-petrenko/sample-factory)](https://github.com/alex-petrenko/sample-factory/stargazers) | 1 (4) | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | -| | | | | | | | | -| [Tianshou](https://github.com/thu-ml/tianshou) | [![GitHub 
stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | +| RL Platform | GitHub Stars | # of Alg. (1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend | +| ------------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------ | ------------------------------ | --------------------------------- | ------------------ | ------------------ | ---------- | +| [Baselines](https://github.com/openai/baselines) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 | +| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 | +| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [![GitHub stars](https://img.shields.io/github/stars/DLR-RM/stable-baselines3)](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 (3) | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :heavy_check_mark: | PyTorch | +| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch | +| [SpinningUp](https://github.com/openai/spinningup) | [![GitHub stars](https://img.shields.io/github/stars/openai/spinningup)](https://github.com/openai/spinningup/stargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :x: | PyTorch | +| [Dopamine](https://github.com/google/dopamine) | [![GitHub stars](https://img.shields.io/github/stars/google/dopamine)](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX | +| [ACME](https://github.com/deepmind/acme) | [![GitHub stars](https://img.shields.io/github/stars/deepmind/acme)](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX | +| [keras-rl](https://github.com/keras-rl/keras-rl) | [![GitHub stars](https://img.shields.io/github/stars/keras-rl/keras-rl)](https://github.com/keras-rl/keras-rl/stargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras | +| [rlpyt](https://github.com/astooke/rlpyt) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | +| [ChainerRL](https://github.com/chainer/chainerrl) | [![GitHub stars](https://img.shields.io/github/stars/chainer/chainerrl)](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer | +| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [![GitHub 
stars](https://img.shields.io/github/stars/alex-petrenko/sample-factory)](https://github.com/alex-petrenko/sample-factory/stargazers) | 1 (4) | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | +| | | | | | | | | +| [Tianshou](https://github.com/thu-ml/tianshou) | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (Gymnasium) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | (1): access date: 2021-08-08 @@ -175,7 +182,8 @@ This is an example of Deep Q Network. You can also run the full script at [test/ First, import some relevant packages: ```python -import gym, torch, numpy as np, torch.nn as nn +import gymnasium as gym +import torch, numpy as np, torch.nn as nn from torch.utils.tensorboard import SummaryWriter import tianshou as ts ``` diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index c63486213..7eaf69004 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -165,3 +165,4 @@ subprocesses isort yapf pydocstyle +Args diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 4c99aa6e7..7a02f2b72 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -4,7 +4,7 @@ Cheat Sheet This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios. -By the way, some of these issues can be resolved by using a ``gym.Wrapper``. +By the way, some of these issues can be resolved by using a ``gymnasium.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn` or vectorized environment wrapper :class:`~tianshou.env.VectorEnvWrapper`. @@ -159,7 +159,7 @@ toy_text and classic_control environments. For more information, please refer to # install envpool: pip3 install envpool import envpool - envs = envpool.make_gym("CartPole-v0", num_envs=10) + envs = envpool.make_gymnasium("CartPole-v0", num_envs=10) collector = Collector(policy, envs, buffer) Here are some other `examples `_. diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index cb6d616fe..79422bb6d 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -55,18 +55,22 @@ Buffer :class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of :class:`~tianshou.data.Batch`. It stores all the data in a batch with circular-queue style. 
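+For orientation, the reserved keys below mirror the Gymnasium ``step`` contract, which looks as follows (a minimal sketch, assuming ``gymnasium`` and its bundled ``CartPole-v1`` environment; illustrative only)::
+
+    >>> import gymnasium as gym
+    >>> env = gym.make("CartPole-v1")
+    >>> obs, info = env.reset(seed=0)
+    >>> obs_next, rew, terminated, truncated, info = env.step(env.action_space.sample())
+    >>> done = terminated or truncated  # the combined flag Tianshou derives for you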
-The current implementation of Tianshou typically use 7 reserved keys in +The current implementation of Tianshou typically uses the following reserved keys in :class:`~tianshou.data.Batch`: * ``obs`` the observation of step :math:`t` ; * ``act`` the action of step :math:`t` ; * ``rew`` the reward of step :math:`t` ; -* ``done`` the done flag of step :math:`t` ; +* ``terminated`` the terminated flag of step :math:`t` ; +* ``truncated`` the truncated flag of step :math:`t` ; +* ``done`` the done flag of step :math:`t` (can be inferred as ``terminated or truncated``); * ``obs_next`` the observation of step :math:`t+1` ; -* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``); +* ``info`` the info of step :math:`t` (in ``gymnasium.Env``, the ``env.step()`` function returns 5 values, and the last one is ``info``); * ``policy`` the data computed by policy in step :math:`t`; -The following code snippet illustrates its usage, including: +When adding data to a replay buffer, the done flag will be inferred automatically from ``terminated`` and ``truncated``. + +The following code snippet illustrates the usage, including: - the basic data storage: ``add()``; - get attribute, get slicing data, ...; @@ -80,7 +84,7 @@ The following code snippet illustrates its usage, including: >>> from tianshou.data import Batch, ReplayBuffer >>> buf = ReplayBuffer(size=20) >>> for i in range(3): - ... buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})) + ... buf.add(Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, obs_next=i + 1, info={})) >>> buf.obs # since we set size = 20, len(buf.obs) == 20. @@ -95,8 +99,8 @@ >>> buf2 = ReplayBuffer(size=10) >>> for i in range(15): - ... done = i % 4 == 0 - ... buf2.add(Batch(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})) + ... terminated = i % 4 == 0 + ... buf2.add(Batch(obs=i, act=i, rew=i, terminated=terminated, truncated=False, obs_next=i + 1, info={})) >>> len(buf2) 10 >>> buf2.obs @@ -146,10 +150,10 @@ The following code snippet illustrates its usage, including: >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) >>> for i in range(16): - ... done = i % 5 == 0 + ... terminated = i % 5 == 0 ... ptr, ep_rew, ep_len, ep_idx = buf.add( ... Batch(obs={'id': i}, act=i, rew=i, - ... done=done, obs_next={'id': i + 1})) + ... terminated=terminated, truncated=False, obs_next={'id': i + 1})) ... print(i, ep_len, ep_rew) 0 [1] [0.] 1 [0] [0.] diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index aac8c98b7..b2c5844e2 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -35,10 +35,10 @@ Here is the overall system: Make an Environment ------------------- -First of all, you have to make an environment for your agent to interact with. You can use ``gym.make(environment_name)`` to make an environment for your agent. For environment interfaces, we follow the convention of `OpenAI Gym `_. In your Python code, simply import Tianshou and make the environment: +First of all, you have to make an environment for your agent to interact with. You can use ``gym.make(environment_name)`` to do this. For environment interfaces, we follow the convention of `Gymnasium <https://github.com/Farama-Foundation/Gymnasium>`_. 
In your Python code, simply import Tianshou and make the environment: :: - import gym + import gymnasium as gym import tianshou as ts env = gym.make('CartPole-v0') @@ -84,8 +84,8 @@ You can also try the super-fast vectorized environment `EnvPool  >>> import numpy as np >>> action = 0 # action is either an integer, or an np.ndarray with one element - >>> obs, reward, done, info = env.step(action) # the env.step follows the api of OpenAI Gym + >>> obs, reward, terminated, truncated, info = env.step(action) # the env.step follows the api of Gymnasium >>> print(obs) # notice the change in the observation {'agent_id': 'player_2', 'obs': array([[[0, 1], [0, 0], @@ -185,7 +185,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul from copy import deepcopy from typing import Optional, Tuple - import gym + import gymnasium as gym import numpy as np import torch from pettingzoo.classic import tictactoe_v3 diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 901b166ef..fddbfb044 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -5,7 +5,7 @@ from collections import deque import cv2 -import gym +import gymnasium as gym import numpy as np from tianshou.env import ShmemVectorEnv @@ -324,7 +324,7 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs): "please set `x = x / 255.0` inside CNN network's forward function." ) # parameters conversion - train_envs = env = envpool.make_gym( + train_envs = env = envpool.make_gymnasium( task.replace("NoFrameskip-v4", "-v5"), num_envs=training_num, seed=seed, @@ -332,7 +332,7 @@ reward_clip=True, stack_num=kwargs.get("frame_stack", 4), ) - test_envs = envpool.make_gym( + test_envs = envpool.make_gymnasium( task.replace("NoFrameskip-v4", "-v5"), num_envs=test_num, seed=seed, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 8d3bd46b0..76aec9e33 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index bb9d4d213..5d518cd77 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -3,7 +3,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index f440c8000..8105964a9 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index cd1b2c2c5..561060b19 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 48436bf67..fb62d2a8f 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as 
gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 1b9d2da73..c2f3dab45 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -6,7 +6,7 @@ import pprint import d4rl -import gym +import gymnasium as gym import numpy as np import torch from torch import nn diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 289008545..1a2332512 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -6,7 +6,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch import wandb diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index 0c567574e..412b60e62 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -1,6 +1,6 @@ import warnings -import gym +import gymnasium as gym from tianshou.env import ShmemVectorEnv, VectorEnvNormObs @@ -18,8 +18,10 @@ def make_mujoco_env(task, seed, training_num, test_num, obs_norm): :return: a tuple of (single env, training envs, test envs). """ if envpool is not None: - train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed) - test_envs = envpool.make_gym(task, num_envs=test_num, seed=seed) + train_envs = env = envpool.make_gymnasium( + task, num_envs=training_num, seed=seed + ) + test_envs = envpool.make_gymnasium(task, num_envs=test_num, seed=seed) else: warnings.warn( "Recommend using envpool (pip install envpool) " diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index eecf03f34..34c3f01bb 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -5,7 +5,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index af8b78710..23cd215ed 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -5,7 +5,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 54dde85ce..9aeddf3db 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -5,7 +5,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 43ee21b31..9899b4d2a 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -5,7 +5,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/examples/offline/utils.py b/examples/offline/utils.py index 07c693cdc..f12374280 100644 --- a/examples/offline/utils.py +++ b/examples/offline/utils.py @@ -1,7 +1,7 @@ from typing import Tuple import d4rl -import gym +import gymnasium as gym import h5py import numpy as np diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 63555f733..011cb2c28 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -1,7 +1,7 @@ import os import cv2 -import gym +import gymnasium as gym import numpy as np import vizdoom as vzd @@ -131,7 +131,7 @@ def make_vizdoom_env(task, frame_skip, res, 
save_lmp, seed, training_num, test_n } if "battle" in task: reward_config["HEALTH"] = [1.0, -1.0] - env = train_envs = envpool.make_gym( + env = train_envs = envpool.make_gymnasium( task_id, frame_skip=frame_skip, stack_num=res[0], @@ -142,7 +142,7 @@ def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_n max_episode_steps=2625, use_inter_area_resize=False, ) - test_envs = envpool.make_gym( + test_envs = envpool.make_gymnasium( task_id, frame_skip=frame_skip, stack_num=res[0], diff --git a/setup.py b/setup.py index 3d96d75a5..5912cd172 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def get_version() -> str: def get_install_requires() -> str: return [ - "gym>=0.23.1", + "gymnasium>=0.26.0", "tqdm", "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 "tensorboard>=2.5.0", @@ -48,18 +48,20 @@ def get_extras_require() -> str: "doc8", "scipy", "pillow", - "pettingzoo>=1.17", + "pettingzoo>=1.22", "pygame>=2.1.0", # pettingzoo test cases pistonball "pymunk>=6.2.1", # pettingzoo test cases pistonball "nni>=2.3,<3.0", # expect breaking changes at next major version "pytorch_lightning", + "gym>=0.22.0", + "shimmy", ], "atari": ["atari_py", "opencv-python"], "mujoco": ["mujoco_py"], "pybullet": ["pybullet"], } if sys.platform == "linux": - req["dev"].append("envpool>=0.5.3") + req["dev"].append("envpool>=0.7.0") return req diff --git a/test/3rd_party/test_nni.py b/test/3rd_party/test_nni.py index ddeeff6f7..e088bf560 100644 --- a/test/3rd_party/test_nni.py +++ b/test/3rd_party/test_nni.py @@ -8,6 +8,7 @@ import nni.nas.execution.api import nni.nas.nn.pytorch as nn import nni.nas.strategy as strategy +import pytest import torch import torch.nn.functional as F from nni.nas.execution import wait_models @@ -107,6 +108,9 @@ def _get_model_and_mutators(**kwargs): return base_model_ir, mutators +@pytest.mark.skip( + reason="NNI currently uses OpenAI Gym" +) # TODO: Remove once NNI transitions to Gymnasium def test_rl(): rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10) engine = MockExecutionEngine(failure_prob=0.2) diff --git a/test/base/env.py b/test/base/env.py index 8c6333d0b..1dd3aab1c 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -2,10 +2,10 @@ import time from copy import deepcopy -import gym +import gymnasium as gym import networkx as nx import numpy as np -from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple +from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple class MyTestEnv(gym.Env): @@ -74,11 +74,13 @@ def __init__( self.terminated = False self.index = 0 - def reset(self, state=0, seed=None): + def reset(self, seed=None, options=None): + if options is None: + options = {"state": 0} super().reset(seed=seed) self.terminated = False self.do_sleep() - self.index = state + self.index = options["state"] return self._get_state(), {'key': 1, 'env': self} def _get_reward(self): @@ -174,7 +176,7 @@ def __init__(self, *args, **kwargs): assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, \ "dict_state / recurse_state not supported" super().__init__(*args, **kwargs) - obs, _ = super().reset(state=0) + obs, _ = super().reset(options={"state": 0}) obs, _, _, _, _ = super().step(1) self._goal = obs * self.size super_obsv = self.observation_space diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 02d140d1e..cf011b7f1 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -196,7 +196,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): buf = 
ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) - obs, info = env.reset(1) + obs, info = env.reset(options={"state": 1}) for _ in range(16): obs_next, rew, terminated, truncated, info = env.step(1) done = terminated or truncated @@ -233,7 +233,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): ) obs = obs_next if done: - obs, info = env.reset(1) + obs, info = env.reset(options={"state": 1}) indices = np.arange(len(buf)) assert np.allclose( buf.get(indices, 'obs')[..., 0], [ @@ -1017,7 +1017,7 @@ def test_multibuf_stack(): bufsize, stack_num=stack_num, ignore_obs_next=True, sample_avail=True ), cached_num, size ) - obs, info = env.reset(1) + obs, info = env.reset(options={"state": 1}) for i in range(18): obs_next, rew, terminated, truncated, info = env.step(1) done = terminated or truncated @@ -1045,7 +1045,7 @@ def test_multibuf_stack(): assert np.all(buf4.truncated == buf5.truncated) obs = obs_next if done: - obs, info = env.reset(1) + obs, info = env.reset(options={"state": 1}) # check the `add` order is correct assert np.allclose( buf4.obs.reshape(-1), diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 6843a394a..4b5304c1e 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -684,7 +684,9 @@ def test_collector_with_atari_setting(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_collector_envpool_gym_reset_return_info(): - envs = envpool.make_gym("Pendulum-v1", num_envs=4, gym_reset_return_info=True) + envs = envpool.make_gymnasium( + "Pendulum-v1", num_envs=4, gym_reset_return_info=True + ) policy = MyPolicy(action_shape=(len(envs), 1)) c0 = Collector( diff --git a/test/base/test_env.py b/test/base/test_env.py index 1c91c0fa1..ec2b91415 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,10 +1,10 @@ import sys import time -import gym +import gymnasium as gym import numpy as np import pytest -from gym.spaces.discrete import Discrete +from gymnasium.spaces.discrete import Discrete from tianshou.data import Batch from tianshou.env import ( @@ -391,10 +391,10 @@ def step(self, act): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_venv_wrapper_envpool(): - raw = envpool.make_gym("Ant-v3", num_envs=4) - train = VectorEnvNormObs(envpool.make_gym("Ant-v3", num_envs=4)) + raw = envpool.make_gymnasium("Ant-v3", num_envs=4) + train = VectorEnvNormObs(envpool.make_gymnasium("Ant-v3", num_envs=4)) test = VectorEnvNormObs( - envpool.make_gym("Ant-v3", num_envs=4), update_obs_rms=False + envpool.make_gymnasium("Ant-v3", num_envs=4), update_obs_rms=False ) test.set_obs_rms(train.get_obs_rms()) actions = [ @@ -407,7 +407,9 @@ def test_venv_wrapper_envpool(): def test_venv_wrapper_envpool_gym_reset_return_info(): num_envs = 4 env = VectorEnvNormObs( - envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True) + envpool.make_gymnasium( + "Ant-v3", num_envs=num_envs, gym_reset_return_info=True + ) ) obs, info = env.reset() assert obs.shape[0] == num_envs diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 54b438507..938a358c7 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -3,7 +3,7 @@ import copy from collections import Counter -import gym +import gymnasium as gym import numpy as np from torch.utils.data import DataLoader, 
Dataset, DistributedSampler @@ -45,15 +45,15 @@ def reset(self): try: self.current_sample, self.step_count = next(self.iterator) self.current_step = 0 - return self.current_sample + return self.current_sample, {} except StopIteration: self.iterator = None - return None + return None, {} def step(self, action): self.current_step += 1 assert self.current_step <= self.step_count - return 0, 1.0, self.current_step >= self.step_count, \ + return 0, 1.0, self.current_step >= self.step_count, False, \ {'sample': self.current_sample, 'action': action, 'metric': 2.0} @@ -94,10 +94,12 @@ def reset(self, id=None): # ask super to reset alive envs and remap to current index request_id = list(filter(lambda i: i in self._alive_env_ids, id)) obs = [None] * len(id) + infos = [None] * len(id) id2idx = {i: k for k, i in enumerate(id)} if request_id: - for i, o in zip(request_id, super().reset(request_id)): - obs[id2idx[i]] = o + for k, o, info in zip(request_id, *super().reset(request_id)): + obs[id2idx[k]] = o + infos[id2idx[k]] = info for i, o in zip(id, obs): if o is None and i in self._alive_env_ids: self._alive_env_ids.remove(i) @@ -105,21 +107,24 @@ def reset(self, id=None): # fill empty observation with default(fake) observation for o in obs: self._set_default_obs(o) + for i in range(len(obs)): if obs[i] is None: obs[i] = self._get_default_obs() + if infos[i] is None: + infos[i] = self._get_default_info() if not self._alive_env_ids: self.reset() raise StopIteration - return np.stack(obs) + return np.stack(obs), infos def step(self, action, id=None): id = self._wrap_id(id) id2idx = {i: k for k, i in enumerate(id)} request_id = list(filter(lambda i: i in self._alive_env_ids, id)) - result = [[None, 0., False, None] for _ in range(len(id))] + result = [[None, 0., False, False, None] for _ in range(len(id))] # ask super to step alive envs and remap to current index if request_id: @@ -133,13 +138,13 @@ def step(self, action, id=None): self.tracker.log(*r) # fill empty observation/info with default(fake) - for _, __, ___, i in result: + for _, __, ___, ____, i in result: self._set_default_info(i) for i in range(len(result)): if result[i][0] is None: result[i][0] = self._get_default_obs() - if result[i][3] is None: - result[i][3] = self._get_default_info() + if result[i][-1] is None: + result[i][-1] = self._get_default_info() return list(map(np.stack, zip(*result))) @@ -171,8 +176,9 @@ def __init__(self): self.counter = Counter() self.finished = set() - def log(self, obs, rew, done, info): + def log(self, obs, rew, terminated, truncated, info): assert rew == 1. 
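+ # the new step API reports terminated and truncated separately; recombine them into done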
+ done = terminated or truncated index = info['sample'] if done: assert index not in self.finished diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index cc9aafe5d..517ec1f4d 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index b7ec50c82..aab35fe44 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch import nn diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 80cf6e330..323b85d5b 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.distributions import Independent, Normal diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 8b649a2be..b6bd5269d 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 6d775e1eb..599c47020 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -59,10 +59,12 @@ def get_args(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_sac_with_il(args=get_args()): # if you want to use python vector env, please refer to other test scripts - train_envs = env = envpool.make_gym( + train_envs = env = envpool.make_gymnasium( args.task, num_envs=args.training_num, seed=args.seed ) - test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed) + test_envs = envpool.make_gymnasium( + args.task, num_envs=args.test_num, seed=args.seed + ) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] @@ -185,7 +187,7 @@ def stop_fn(mean_rewards): ) il_test_collector = Collector( il_policy, - envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed), + envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), ) train_collector.reset() result = offpolicy_trainer( diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index ac557ba12..aa9411a4e 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index cc8b39f61..9bcc01487 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch import nn diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 0a453e1eb..05ddd7ee9 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym 
import numpy as np import pytest import torch @@ -60,10 +60,12 @@ def get_args(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_a2c_with_il(args=get_args()): # if you want to use python vector env, please refer to other test scripts - train_envs = env = envpool.make_gym( - args.task, num_envs=args.training_num, seed=args.seed + train_envs = env = envpool.make( + args.task, env_type="gymnasium", num_envs=args.training_num, seed=args.seed + ) + test_envs = envpool.make( + args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed ) - test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: @@ -146,7 +148,9 @@ def stop_fn(mean_rewards): il_policy = ImitationPolicy(net, optim, action_space=env.action_space) il_test_collector = Collector( il_policy, - envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed), + envpool.make( + args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed + ), ) train_collector.reset() result = offpolicy_trainer( diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 604fd886b..30a2e11a2 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -1,7 +1,7 @@ import argparse import pprint -import gym +import gymnasium as gym import numpy as np import torch diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 3912b0740..01f94ab03 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -3,7 +3,7 @@ import pickle import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 2644dc998..e8359616b 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/discrete/test_dqn_gym.py b/test/discrete/test_dqn_gym.py new file mode 100644 index 000000000..304a88256 --- /dev/null +++ b/test/discrete/test_dqn_gym.py @@ -0,0 +1,172 @@ +"""Same as test_dqn.py, but uses OpenAI Gym environments""" +import argparse +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import DQNPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--reward-threshold', type=float, default=None) + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + 
parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--prioritized-replay', action="store_true", default=False) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + args = parser.parse_known_args()[0] + return args + + +def test_dqn(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + if args.reward_threshold is None: + default_reward_threshold = {"CartPole-v0": 195} + args.reward_threshold = default_reward_threshold.get( + args.task, env.spec.reward_threshold + ) + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # Q_param = V_param = {"hidden_sizes": [128]} + # model + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + # dueling=(Q_param, V_param), + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = DQNPolicy( + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq, + ) + # buffer + if args.prioritized_replay: + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta, + ) + else: + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, args.task, 'dqn') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_best_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= args.reward_threshold + + def train_fn(epoch, env_step): + # eps annealing, just a demo + if env_step <= 10000: + policy.set_eps(args.eps_train) + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) + policy.set_eps(eps) + else: + policy.set_eps(0.1 * args.eps_train) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + 
args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +def test_pdqn(args=get_args()): + args.prioritized_replay = True + args.gamma = .95 + args.seed = 1 + test_dqn(args) + + +if __name__ == '__main__': + test_dqn(get_args()) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 36ff5fa76..c01108c71 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index e25c42997..3e5381f4e 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 725c9a9d5..249cdc542 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 0ea8f745a..a8a2792ed 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index b7dba97c9..20d6d5060 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 9c699b4e3..2de7ca9f4 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 78b43dff9..6ecaaaa8c 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -3,7 +3,7 @@ import pickle import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index ac8cf7253..06831058e 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index fba0b5523..bb5750875 100644 
--- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 6efd96277..77f0f8300 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -2,7 +2,7 @@ import os import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 2c5c3fcd3..1332ae506 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -48,10 +48,12 @@ def get_args(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_psrl(args=get_args()): # if you want to use python vector env, please refer to other test scripts - train_envs = env = envpool.make_gym( + train_envs = env = envpool.make_gymnasium( args.task, num_envs=args.training_num, seed=args.seed ) - test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed) + test_envs = envpool.make_gymnasium( + args.task, num_envs=args.test_num, seed=args.seed + ) if args.reward_threshold is None: default_reward_threshold = {"NChain-v0": 3400} args.reward_threshold = default_reward_threshold.get( diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 2e01723f3..bbea01d19 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -2,7 +2,7 @@ import os import pickle -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index f60ec3696..b124aad61 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -2,7 +2,7 @@ import os import pickle -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index a0fea4f50..687304bbb 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -4,7 +4,7 @@ import pickle import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 6d52bca7a..b5dfafc14 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -4,7 +4,7 @@ import pickle import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 66e42f3e9..1c9fcd591 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -3,7 +3,7 @@ import pickle import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index eca810fb6..99f60933f 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -3,7 +3,7 @@ import pickle import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git 
a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 9f47b32e7..c12513c30 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -3,7 +3,7 @@ import pickle import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index aa123ec87..060fc0aab 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -3,7 +3,7 @@ import pickle import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.distributions import Independent, Normal diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index e34dfc6d2..2adab61cf 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -4,7 +4,7 @@ import pickle import pprint -import gym +import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 48a5f202a..984d765fd 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -3,7 +3,7 @@ import warnings from typing import List, Optional, Tuple -import gym +import gymnasium as gym import numpy as np import pettingzoo.butterfly.pistonball_v6 as pistonball_v6 import torch diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 8f4818440..6ffb5ec62 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -3,7 +3,7 @@ import warnings from typing import Any, Dict, List, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np import pettingzoo.butterfly.pistonball_v6 as pistonball_v6 import torch diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 7610cee98..0b3b38262 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -4,7 +4,6 @@ from functools import partial from typing import Optional, Tuple -import gym import gymnasium import numpy as np import torch @@ -106,7 +105,7 @@ def get_agents( ) -> Tuple[BasePolicy, torch.optim.Optimizer, list]: env = get_env() observation_space = env.observation_space['observation'] if isinstance( - env.observation_space, (gym.spaces.Dict, gymnasium.spaces.Dict) + env.observation_space, gymnasium.spaces.Dict ) else env.observation_space args.state_shape = observation_space.shape or observation_space.n args.action_shape = env.action_space.shape or env.action_space.n diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index b14552eb8..45a1eeb8b 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -1,7 +1,7 @@ import sys import time -import gym +import gymnasium as gym import numpy as np import tqdm diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 88f6c76ba..3ff1ab26e 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -2,7 +2,7 @@ import warnings from typing import Any, Callable, Dict, List, Optional, Union -import gym +import gymnasium as gym import numpy as np import torch @@ -143,24 +143,14 @@ def reset_buffer(self, keep_statistics: bool = False) -> None: def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None: """Reset all of the environments.""" gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} - rval = 
self.env.reset(**gym_reset_kwargs) - returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( - isinstance(rval[1], dict) or isinstance(rval[1][0], dict) - ) - if returns_info: - obs, info = rval - if self.preprocess_fn: - processed_data = self.preprocess_fn( - obs=obs, info=info, env_id=np.arange(self.env_num) - ) - obs = processed_data.get("obs", obs) - info = processed_data.get("info", info) - self.data.info = info - else: - obs = rval - if self.preprocess_fn: - obs = self.preprocess_fn(obs=obs, env_id=np.arange(self.env_num - )).get("obs", obs) + obs, info = self.env.reset(**gym_reset_kwargs) + if self.preprocess_fn: + processed_data = self.preprocess_fn( + obs=obs, info=info, env_id=np.arange(self.env_num) + ) + obs = processed_data.get("obs", obs) + info = processed_data.get("info", info) + self.data.info = info self.data.obs = obs def _reset_state(self, id: Union[int, List[int]]) -> None: @@ -181,24 +171,15 @@ def _reset_env_with_ids( gym_reset_kwargs: Optional[Dict[str, Any]] = None, ) -> None: gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} - rval = self.env.reset(global_ids, **gym_reset_kwargs) - returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( - isinstance(rval[1], dict) or isinstance(rval[1][0], dict) - ) - if returns_info: - obs_reset, info = rval - if self.preprocess_fn: - processed_data = self.preprocess_fn( - obs=obs_reset, info=info, env_id=global_ids - ) - obs_reset = processed_data.get("obs", obs_reset) - info = processed_data.get("info", info) - self.data.info[local_ids] = info - else: - obs_reset = rval - if self.preprocess_fn: - obs_reset = self.preprocess_fn(obs=obs_reset, env_id=global_ids - ).get("obs", obs_reset) + obs_reset, info = self.env.reset(global_ids, **gym_reset_kwargs) + if self.preprocess_fn: + processed_data = self.preprocess_fn( + obs=obs_reset, info=info, env_id=global_ids + ) + obs_reset = processed_data.get("obs", obs_reset) + info = processed_data.get("info", info) + self.data.info[local_ids] = info + self.data.obs_next[local_ids] = obs_reset def collect( @@ -311,24 +292,11 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - result = self.env.step(action_remap, ready_env_ids) # type: ignore - if len(result) == 5: - obs_next, rew, terminated, truncated, info = result - done = np.logical_or(terminated, truncated) - elif len(result) == 4: - obs_next, rew, done, info = result - if isinstance(info, dict): - truncated = info["TimeLimit.truncated"] - else: - truncated = np.array( - [ - info_item.get("TimeLimit.truncated", False) - for info_item in info - ] - ) - terminated = np.logical_and(done, ~truncated) - else: - raise ValueError() + obs_next, rew, terminated, truncated, info = self.env.step( + action_remap, # type: ignore + ready_env_ids + ) + done = np.logical_or(terminated, truncated) self.data.update( obs_next=obs_next, @@ -583,25 +551,11 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - result = self.env.step(action_remap, ready_env_ids) # type: ignore - - if len(result) == 5: - obs_next, rew, terminated, truncated, info = result - done = np.logical_or(terminated, truncated) - elif len(result) == 4: - obs_next, rew, done, info = result - if isinstance(info, dict): - truncated = info["TimeLimit.truncated"] - else: - truncated = np.array( - [ - info_item.get("TimeLimit.truncated", False) - for 
diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py
index b906b79b9..c9ce66acb 100644
--- a/tianshou/env/gym_wrappers.py
+++ b/tianshou/env/gym_wrappers.py
@@ -1,6 +1,6 @@
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, SupportsFloat, Tuple, Union

-import gym
+import gymnasium as gym
 import numpy as np
 from packaging import version

@@ -26,7 +26,7 @@ def __init__(self, env: gym.Env, action_per_dim: Union[int, List[int]]) -> None:
             dtype=object
         )

-    def action(self, act: np.ndarray) -> np.ndarray:
+    def action(self, act: np.ndarray) -> np.ndarray:  # type: ignore
         # modify act
         assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}."
         if len(act.shape) == 1:
@@ -50,7 +50,7 @@ def __init__(self, env: gym.Env) -> None:
             self.bases[i] = self.bases[i - 1] * nvec[-i]
         self.action_space = gym.spaces.Discrete(np.prod(nvec))

-    def action(self, act: np.ndarray) -> np.ndarray:
+    def action(self, act: np.ndarray) -> np.ndarray:  # type: ignore
         converted_act = []
         for b in np.flip(self.bases):
             converted_act.append(act // b)
@@ -74,7 +74,8 @@ def __init__(self, env: gym.Env):
                 {gym.__version__}"
             )

-    def step(self, act: np.ndarray) -> Tuple[Any, float, bool, bool, Dict[Any, Any]]:
+    def step(self,
+             act: np.ndarray) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
         observation, reward, terminated, truncated, info = super().step(act)
         terminated = (terminated or truncated)
         return observation, reward, terminated, truncated, info
diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py
index 28173336e..9ab828097 100644
--- a/tianshou/env/pettingzoo_env.py
+++ b/tianshou/env/pettingzoo_env.py
@@ -1,6 +1,6 @@
 import warnings
 from abc import ABC
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple

 import pettingzoo
 from gymnasium import spaces
@@ -65,22 +65,11 @@ def __init__(self, env: BaseWrapper):

         self.reset()

-    def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
+    def reset(self, *args: Any, **kwargs: Any) -> Tuple[dict, dict]:
         self.env.reset(*args, **kwargs)

-        # Here, we do not label the return values explicitly to keep compatibility with
-        # old step API. TODO: Change once PettingZoo>=1.21.0 is required
-        last_return = self.env.last(self)
+        observation, reward, terminated, truncated, info = self.env.last(self)

-        if len(last_return) == 4:
-            warnings.warn(
-                "The PettingZoo environment is using the old step API. "
-                "This API may not be supported in future versions of tianshou. "
-                "We recommend that you update the environment code or apply a "
-                "compatibility wrapper.", DeprecationWarning
-            )
-
-        observation, info = last_return[0], last_return[-1]
         if isinstance(observation, dict) and 'action_mask' in observation:
             observation_dict = {
                 'agent_id': self.env.agent_selection,
@@ -101,21 +90,13 @@ def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
                 'obs': observation,
             }

-        if "return_info" in kwargs and kwargs["return_info"]:
-            return observation_dict, info
-        else:
-            return observation_dict
+        return observation_dict, info

-    def step(
-        self, action: Any
-    ) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool,
-               Dict]]:
+    def step(self, action: Any) -> Tuple[Dict, List[int], bool, bool, Dict]:
         self.env.step(action)
-        # Here, we do not label the return values explicitly to keep compatibility with
-        # old step API. TODO: Change once PettingZoo>=1.21.0 is required
-        last_return = self.env.last()
-        observation = last_return[0]
+        observation, rew, term, trunc, info = self.env.last()
+
         if isinstance(observation, dict) and 'action_mask' in observation:
             obs = {
                 'agent_id': self.env.agent_selection,
@@ -135,7 +116,7 @@ def step(

         for agent_id, reward in self.env.rewards.items():
             self.rewards[self.agent_idx[agent_id]] = reward
-        return (obs, self.rewards, *last_return[2:])  # type: ignore
+        return obs, self.rewards, term, trunc, info

     def close(self) -> None:
         self.env.close()
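`PettingZooEnv.reset` and `PettingZooEnv.step` now have a single, Gymnasium-style shape: reset always yields `(observation_dict, info)`, and step always yields a five-tuple whose reward slot holds the per-agent reward list. A hedged usage sketch, assuming the PettingZoo Tic-Tac-Toe environment is installed:

```python
import numpy as np
from pettingzoo.classic import tictactoe_v3

from tianshou.env.pettingzoo_env import PettingZooEnv

env = PettingZooEnv(tictactoe_v3.env())
obs, info = env.reset()  # always a 2-tuple now; no return_info flag
# obs bundles the acting agent with its action mask:
# {'agent_id': ..., 'obs': ..., 'mask': [...]}
action = int(np.flatnonzero(obs["mask"])[0])  # pick the first legal move
obs, rewards, terminated, truncated, info = env.step(action)  # new 5-tuple
```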
" - "We recommend that you update the environment code or apply a " - "compatibility wrapper.", DeprecationWarning - ) - - observation, info = last_return[0], last_return[-1] if isinstance(observation, dict) and 'action_mask' in observation: observation_dict = { 'agent_id': self.env.agent_selection, @@ -101,21 +90,13 @@ def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: 'obs': observation, } - if "return_info" in kwargs and kwargs["return_info"]: - return observation_dict, info - else: - return observation_dict + return observation_dict, info - def step( - self, action: Any - ) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool, - Dict]]: + def step(self, action: Any) -> Tuple[Dict, List[int], bool, bool, Dict]: self.env.step(action) - # Here, we do not label the return values explicitly to keep compatibility with - # old step API. TODO: Change once PettingZoo>=1.21.0 is required - last_return = self.env.last() - observation = last_return[0] + observation, rew, term, trunc, info = self.env.last() + if isinstance(observation, dict) and 'action_mask' in observation: obs = { 'agent_id': self.env.agent_selection, @@ -135,7 +116,7 @@ def step( for agent_id, reward in self.env.rewards.items(): self.rewards[self.agent_idx[agent_id]] = reward - return (obs, self.rewards, *last_return[2:]) # type: ignore + return obs, self.rewards, term, trunc, info def close(self) -> None: self.env.close() diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py index 8868edfe8..ac671016e 100644 --- a/tianshou/env/utils.py +++ b/tianshou/env/utils.py @@ -1,9 +1,16 @@ -from typing import Any, Tuple +from typing import TYPE_CHECKING, Any, Tuple, Union import cloudpickle +import gymnasium import numpy as np -gym_old_venv_step_type = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] +from tianshou.env.pettingzoo_env import PettingZooEnv + +if TYPE_CHECKING: + import gym + +ENV_TYPE = Union[gymnasium.Env, "gym.Env", PettingZooEnv] + gym_new_venv_step_type = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index 5e5909af6..66470289c 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -2,7 +2,7 @@ import numpy as np -from tianshou.env.utils import gym_new_venv_step_type, gym_old_venv_step_type +from tianshou.env.utils import gym_new_venv_step_type from tianshou.env.venvs import GYM_RESERVED_KEYS, BaseVectorEnv from tianshou.utils import RunningMeanStd @@ -42,14 +42,14 @@ def reset( self, id: Optional[Union[int, List[int], np.ndarray]] = None, **kwargs: Any, - ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]: + ) -> Tuple[np.ndarray, Union[dict, List[dict]]]: return self.venv.reset(id, **kwargs) def step( self, action: np.ndarray, id: Optional[Union[int, List[int], np.ndarray]] = None, - ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]: + ) -> gym_new_venv_step_type: return self.venv.step(action, id) def seed( @@ -85,17 +85,10 @@ def reset( self, id: Optional[Union[int, List[int], np.ndarray]] = None, **kwargs: Any, - ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]: - rval = self.venv.reset(id, **kwargs) - returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and ( - isinstance(rval[1], dict) or isinstance(rval[1][0], dict) - ) - if returns_info: - obs, info = rval - else: - obs = rval + ) -> Tuple[np.ndarray, Union[dict, List[dict]]]: + obs, info = self.venv.reset(id, **kwargs) - 
diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py
index fc5b757b1..a0fb0b538 100644
--- a/tianshou/env/venvs.py
+++ b/tianshou/env/venvs.py
@@ -1,9 +1,12 @@
+import warnings
 from typing import Any, Callable, List, Optional, Tuple, Union

-import gym
+import gymnasium as gym
 import numpy as np
+import packaging.version

-from tianshou.env.utils import gym_new_venv_step_type, gym_old_venv_step_type
+from tianshou.env.pettingzoo_env import PettingZooEnv
+from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type
 from tianshou.env.worker import (
     DummyEnvWorker,
     EnvWorker,
@@ -11,11 +14,76 @@
     SubprocEnvWorker,
 )

+try:
+    import gym as old_gym
+    has_old_gym = True
+except ImportError:
+    has_old_gym = False
+
 GYM_RESERVED_KEYS = [
     "metadata", "reward_range", "spec", "action_space", "observation_space"
 ]


+def _patch_env_generator(fn: Callable[[], ENV_TYPE]) -> Callable[[], gym.Env]:
+    """Takes an environment generator and patches it to return Gymnasium envs.
+
+    This function takes the environment generator `fn` and returns a patched
+    generator, without invoking `fn`. The original generator may return
+    Gymnasium or OpenAI Gym environments, but the patched generator wraps
+    the result of `fn` in a shimmy wrapper to convert it to Gymnasium,
+    if necessary.
+    """
+
+    def patched() -> gym.Env:
+        assert callable(
+            fn
+        ), "Env generators that are provided to vector environments must be callable."
+
+        env = fn()
+        if isinstance(env, (gym.Env, PettingZooEnv)):
+            return env
+
+        if not has_old_gym or not isinstance(env, old_gym.Env):
+            raise ValueError(
+                f"Environment generator returned a {type(env)}, not a Gymnasium "
+                f"environment. In this case, we expect OpenAI Gym to be "
+                f"installed and the environment to be an OpenAI Gym environment."
+            )
+
+        try:
+            import shimmy
+        except ImportError as e:
+            raise ImportError(
+                "Missing shimmy installation. You provided an environment generator "
+                "that returned an OpenAI Gym environment. "
+                "Tianshou has transitioned to using Gymnasium internally. "
+                "In order to use OpenAI Gym environments with tianshou, you need to "
+                "install shimmy (`pip install shimmy`)."
+            ) from e
+
+        warnings.warn(
+            "You provided an environment generator that returned an OpenAI Gym "
+            "environment. We strongly recommend transitioning to Gymnasium "
+            "environments. "
+            "Tianshou is automatically wrapping your environments in a compatibility "
+            "layer, which could potentially cause issues."
+        )
+
+        gym_version = packaging.version.parse(old_gym.__version__)
+        if gym_version >= packaging.version.parse("0.26.0"):
+            return shimmy.GymV26CompatibilityV0(env=env)
+        elif gym_version >= packaging.version.parse("0.22.0"):
+            return shimmy.GymV22CompatibilityV0(env=env)
+        else:
+            raise Exception(
+                f"Found OpenAI Gym version {old_gym.__version__}. "
+                f"Tianshou only supports OpenAI Gym environments of version>=0.22.0"
+            )
+
+    return patched
+
+
 class BaseVectorEnv(object):
     """Base class for vectorized environments.
@@ -69,7 +137,7 @@ def seed(self, seed):

     def __init__(
         self,
-        env_fns: List[Callable[[], gym.Env]],
+        env_fns: List[Callable[[], ENV_TYPE]],
         worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
         wait_num: Optional[int] = None,
         timeout: Optional[float] = None,
@@ -77,7 +145,7 @@ def __init__(
         self._env_fns = env_fns
         # A VectorEnv contains a pool of EnvWorkers, which corresponds to
         # interact with the given envs (one worker <-> one env).
-        self.workers = [worker_fn(fn) for fn in env_fns]
+        self.workers = [worker_fn(_patch_env_generator(fn)) for fn in env_fns]
         self.worker_class = type(self.workers[0])
         assert issubclass(self.worker_class, EnvWorker)
         assert all([isinstance(w, self.worker_class) for w in self.workers])
@@ -186,7 +254,7 @@ def reset(
         self,
         id: Optional[Union[int, List[int], np.ndarray]] = None,
         **kwargs: Any,
-    ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
+    ) -> Tuple[np.ndarray, Union[dict, List[dict]]]:
         """Reset the state of some envs and return initial observations.

         If id is None, reset the state of all the environments and return
@@ -203,15 +271,13 @@ def reset(
             self.workers[i].send(None, **kwargs)
         ret_list = [self.workers[i].recv() for i in id]

-        reset_returns_info = isinstance(ret_list[0], (tuple, list)) and len(
+        assert isinstance(ret_list[0], (tuple, list)) and len(
             ret_list[0]
         ) == 2 and isinstance(ret_list[0][1], dict)
-        if reset_returns_info:
-            obs_list = [r[0] for r in ret_list]
-        else:
-            obs_list = ret_list

-        if isinstance(obs_list[0], tuple):
+        obs_list = [r[0] for r in ret_list]
+
+        if isinstance(obs_list[0], tuple):  # type: ignore
             raise TypeError(
                 "Tuple observation space is not supported. ",
                 "Please change it to array or dict space",
@@ -221,17 +287,14 @@ def reset(
         except ValueError:  # different len(obs)
             obs = np.array(obs_list, dtype=object)

-        if reset_returns_info:
-            infos = [r[1] for r in ret_list]
-            return obs, infos  # type: ignore
-        else:
-            return obs
+        infos = [r[1] for r in ret_list]
+        return obs, infos  # type: ignore

     def step(
         self,
         action: np.ndarray,
         id: Optional[Union[int, List[int], np.ndarray]] = None,
-    ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]:
+    ) -> gym_new_venv_step_type:
         """Run one timestep of some environments' dynamics.

        If id is None, run one timestep of all the environments’ dynamics;
@@ -246,16 +309,6 @@ def step(

-        :return: A tuple consisting of either:
+        :return: A tuple consisting of:

-        * ``obs`` a numpy.ndarray, the agent's observation of current environments
-        * ``rew`` a numpy.ndarray, the amount of rewards returned after \
-            previous actions
-        * ``done`` a numpy.ndarray, whether these episodes have ended, in \
-            which case further step() calls will return undefined results
-        * ``info`` a numpy.ndarray, contains auxiliary diagnostic \
-            information (helpful for debugging, and sometimes learning)
-
-        or:
-
         * ``obs`` a numpy.ndarray, the agent's observation of current environments
         * ``rew`` a numpy.ndarray, the amount of rewards returned after \
             previous actions
         * ``terminated`` a numpy.ndarray, whether these episodes have been terminated
         * ``truncated`` a numpy.ndarray, whether these episodes have been truncated
         * ``info`` a numpy.ndarray, contains auxiliary diagnostic \
             information (helpful for debugging, and sometimes learning)

-        The case distinction is made based on whether the underlying environment
-        uses the old step API (first case) or the new step API (second case).
-
         For the async simulation:

         Provide the given action to the environments. The action sequence
@@ -312,14 +362,14 @@ def step(
                 env_return[-1]["env_id"] = env_id  # Add `env_id` to info
                 result.append(env_return)
                 self.ready_id.append(env_id)
-        return_lists = tuple(zip(*result))
-        obs_list = return_lists[0]
+        obs_list, rew_list, term_list, trunc_list, info_list = tuple(zip(*result))
         try:
             obs_stack = np.stack(obs_list)
         except ValueError:  # different len(obs)
             obs_stack = np.array(obs_list, dtype=object)
-        other_stacks = map(np.stack, return_lists[1:])
-        return (obs_stack, *other_stacks)  # type: ignore
+        return obs_stack, np.stack(rew_list), np.stack(term_list), np.stack(
+            trunc_list
+        ), np.stack(info_list)

     def seed(
@@ -374,7 +424,7 @@ class DummyVectorEnv(BaseVectorEnv):
     Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """

-    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+    def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
         super().__init__(env_fns, DummyEnvWorker, **kwargs)


@@ -386,7 +436,7 @@ class SubprocVectorEnv(BaseVectorEnv):
     Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """

-    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+    def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:

         def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
             return SubprocEnvWorker(fn, share_memory=False)

@@ -404,7 +454,7 @@ class ShmemVectorEnv(BaseVectorEnv):
     Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """

-    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+    def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:

         def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
             return SubprocEnvWorker(fn, share_memory=True)

@@ -422,7 +472,7 @@ class RayVectorEnv(BaseVectorEnv):
     Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """

-    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+    def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
         try:
             import ray
         except ImportError as exception:
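`_patch_env_generator` is what keeps old factories usable: Gymnasium (and PettingZoo) environments pass through untouched, while legacy OpenAI Gym environments are wrapped in a Shimmy compatibility layer at worker-creation time. A sketch of both paths, assuming legacy `gym` and `shimmy` are installed for the second one:

```python
import gymnasium

from tianshou.env import DummyVectorEnv

# Gymnasium factory: passed through unchanged.
venv = DummyVectorEnv([lambda: gymnasium.make("CartPole-v1")])

# Legacy OpenAI Gym factory: wrapped via shimmy, with a warning.
import gym  # requires `pip install gym shimmy`

venv_legacy = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
obs, info = venv_legacy.reset()  # the wrapped env already speaks the new API
```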
""" - def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: + def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None: try: import ray except ImportError as exception: diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 779a8229a..773d56bce 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod from typing import Any, Callable, List, Optional, Tuple, Union -import gym +import gymnasium as gym import numpy as np -from tianshou.env.utils import gym_new_venv_step_type, gym_old_venv_step_type +from tianshou.env.utils import gym_new_venv_step_type from tianshou.utils import deprecation @@ -14,8 +14,7 @@ class EnvWorker(ABC): def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False - self.result: Union[gym_old_venv_step_type, gym_new_venv_step_type, - Tuple[np.ndarray, dict], np.ndarray] + self.result: Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]] self.action_space = self.get_env_attr("action_space") # noqa: B009 self.is_reset = False @@ -48,8 +47,7 @@ def send(self, action: Optional[np.ndarray]) -> None: def recv( self - ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type, Tuple[np.ndarray, dict], - np.ndarray]: # noqa:E125 + ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict], ]: # noqa:E125 """Receive result from low-level worker. If the last "send" function sends a NULL action, it only returns a @@ -67,12 +65,10 @@ def recv( return self.result @abstractmethod - def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]: pass - def step( - self, action: np.ndarray - ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]: + def step(self, action: np.ndarray) -> gym_new_venv_step_type: """Perform one timestep of the environment's dynamic. "send" and "recv" are coupled in sync simulation, so users only call diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 7613e6784..4eec4e0fa 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -1,6 +1,6 @@ -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple -import gym +import gymnasium as gym import numpy as np from tianshou.env.worker import EnvWorker @@ -19,7 +19,7 @@ def get_env_attr(self, key: str) -> Any: def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env.unwrapped, key, value) - def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]: if "seed" in kwargs: super().seed(kwargs["seed"]) return self.env.reset(**kwargs) diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index e73d3b7f3..fe2b8fe8d 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -1,9 +1,9 @@ -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional -import gym +import gymnasium as gym import numpy as np -from tianshou.env.utils import gym_new_venv_step_type, gym_old_venv_step_type +from tianshou.env.utils import gym_new_venv_step_type from tianshou.env.worker import EnvWorker try: @@ -14,22 +14,6 @@ class _SetAttrWrapper(gym.Wrapper): - def __init__(self, env: gym.Env) -> None: - """Constructor of this wrapper. 
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index 0c943f8bb..68f34e687 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -5,14 +5,10 @@
 from multiprocessing.context import Process
 from typing import Any, Callable, List, Optional, Tuple, Union

-import gym
+import gymnasium as gym
 import numpy as np

-from tianshou.env.utils import (
-    CloudpickleWrapper,
-    gym_new_venv_step_type,
-    gym_old_venv_step_type,
-)
+from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type
 from tianshou.env.worker import EnvWorker

 _NP_TO_CT = {
@@ -97,21 +93,11 @@ def _encode_obs(
                     env_return = (None, *env_return[1:])
                 p.send(env_return)
             elif cmd == "reset":
-                retval = env.reset(**data)
-                reset_returns_info = isinstance(
-                    retval, (tuple, list)
-                ) and len(retval) == 2 and isinstance(retval[1], dict)
-                if reset_returns_info:
-                    obs, info = retval
-                else:
-                    obs = retval
+                obs, info = env.reset(**data)
                 if obs_bufs is not None:
                     _encode_obs(obs, obs_bufs)
                     obs = None
-                if reset_returns_info:
-                    p.send((obs, info))
-                else:
-                    p.send(obs)
+                p.send((obs, info))
             elif cmd == "close":
                 p.send(env.close())
                 p.close()
@@ -214,8 +200,7 @@ def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:

     def recv(
         self
-    ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type, Tuple[np.ndarray, dict],
-               np.ndarray]:  # noqa:E125
+    ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]:  # noqa:E125
         result = self.parent_remote.recv()
         if isinstance(result, tuple):
             if len(result) == 2:
@@ -233,7 +218,7 @@ def recv(
                 obs = self._decode_obs()
             return obs

-    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+    def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]:
         if "seed" in kwargs:
             super().seed(kwargs["seed"])
         self.parent_remote.send(["reset", kwargs])
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 5f8a85664..584bb0928 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -1,10 +1,10 @@
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

-import gym
+import gymnasium as gym
 import numpy as np
 import torch
-from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
+from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete
 from numba import njit
 from torch import nn
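End to end, the patch means user environments should follow the Gymnasium signatures. A minimal sketch of the expected shape (`MyEnv` is illustrative, not part of the patch):

```python
from typing import Any, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np


class MyEnv(gym.Env):
    """Illustrative env written directly against the Gymnasium API."""

    def __init__(self) -> None:
        self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
        self.action_space = gym.spaces.Discrete(2)
        self._elapsed = 0

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        super().reset(seed=seed)  # seeds self.np_random
        self._elapsed = 0
        return self.observation_space.sample(), {}  # (obs, info) pair

    def step(
        self, action: int
    ) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        self._elapsed += 1
        obs = self.observation_space.sample()
        terminated = False  # a real success/failure condition would go here
        truncated = self._elapsed >= 200  # time-limit style cutoff
        return obs, 1.0, terminated, truncated, {}  # the new 5-tuple
```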