diff --git a/README.md b/README.md
index 69c152859..d6a9300b4 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,13 @@
[](https://pypi.org/project/tianshou/) [](https://github.com/conda-forge/tianshou-feedstock) [](https://tianshou.readthedocs.io/en/master) [](https://tianshou.readthedocs.io/zh/master/) [](https://github.com/thu-ml/tianshou/actions) [](https://codecov.io/gh/thu-ml/tianshou) [](https://github.com/thu-ml/tianshou/issues) [](https://github.com/thu-ml/tianshou/stargazers) [](https://github.com/thu-ml/tianshou/network) [](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
+> ⚠️ **Transition to Gymnasium**: The maintainers of OpenAI Gym have recently released [Gymnasium](http://github.com/Farama-Foundation/Gymnasium),
+> which is where future maintenance of OpenAI Gym will be taking place.
+> Tianshou has transitioned to internally using Gymnasium environments. You can still use OpenAI Gym environments with
+> Tianshou vector environments, but they will be wrapped in a compatibility layer, which could be a source of issues.
+> We recommend that you update your environment code to Gymnasium. If you want to continue using OpenAI Gym with
+> Tianshou, you need to manually install Gym and [Shimmy](https://github.com/Farama-Foundation/Shimmy) (the compatibility layer).
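+>
+> A minimal sketch of the recommended Gymnasium-based setup (assuming `CartPole-v1` and Tianshou's `DummyVectorEnv`; adapt to your own environment):
+>
+> ```python
+> import gymnasium as gym
+>
+> from tianshou.env import DummyVectorEnv
+>
+> # Environments written against the Gymnasium API work out of the box.
+> train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
+> obs, info = train_envs.reset()  # reset now returns (obs, info), following the Gymnasium API
+> ```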
+
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed modularized framework and pythonic API for building the deep reinforcement learning agent with the least number of lines of code. The supported interface algorithms currently include:
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
@@ -105,21 +112,21 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
### Comprehensive Functionality
-| RL Platform | GitHub Stars | # of Alg. (1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend |
-| ------------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------ | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- |
-| [Baselines](https://github.com/openai/baselines) | [](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 |
-| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 |
-| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 (3) | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :heavy_check_mark: | PyTorch |
-| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch |
-| [SpinningUp](https://github.com/openai/spinningup) | [](https://github.com/openai/spinningupstargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :x: | PyTorch |
-| [Dopamine](https://github.com/google/dopamine) | [](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX |
-| [ACME](https://github.com/deepmind/acme) | [](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX |
-| [keras-rl](https://github.com/keras-rl/keras-rl) | [](https://github.com/keras-rl/keras-rlstargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras |
-| [rlpyt](https://github.com/astooke/rlpyt) | [](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
-| [ChainerRL](https://github.com/chainer/chainerrl) | [](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer |
-| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [](https://github.com/alex-petrenko/sample-factory/stargazers) | 1 (4) | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
-| | | | | | | | |
-| [Tianshou](https://github.com/thu-ml/tianshou) | [](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
+| RL Platform | GitHub Stars | # of Alg. (1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend |
+| ------------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------ | ------------------------------ | --------------------------------- | ------------------ | ------------------ | ---------- |
+| [Baselines](https://github.com/openai/baselines) | [](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 |
+| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 |
+| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 (3) | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :heavy_check_mark: | PyTorch |
+| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch |
+| [SpinningUp](https://github.com/openai/spinningup) | [](https://github.com/openai/spinningupstargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :x: | PyTorch |
+| [Dopamine](https://github.com/google/dopamine) | [](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX |
+| [ACME](https://github.com/deepmind/acme) | [](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX |
+| [keras-rl](https://github.com/keras-rl/keras-rl) | [](https://github.com/keras-rl/keras-rlstargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras |
+| [rlpyt](https://github.com/astooke/rlpyt) | [](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
+| [ChainerRL](https://github.com/chainer/chainerrl) | [](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer |
+| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [](https://github.com/alex-petrenko/sample-factory/stargazers) | 1 (4) | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
+| | | | | | | | |
+| [Tianshou](https://github.com/thu-ml/tianshou) | [](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (Gymnasium) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
(1): access date: 2021-08-08
@@ -175,7 +182,8 @@ This is an example of Deep Q Network. You can also run the full script at [test/
First, import some relevant packages:
```python
-import gym, torch, numpy as np, torch.nn as nn
+import gymnasium as gym
+import torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts
```
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index c63486213..7eaf69004 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -165,3 +165,4 @@ subprocesses
isort
yapf
pydocstyle
+Args
diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst
index 4c99aa6e7..7a02f2b72 100644
--- a/docs/tutorials/cheatsheet.rst
+++ b/docs/tutorials/cheatsheet.rst
@@ -4,7 +4,7 @@ Cheat Sheet
This page shows some code snippets of how to use Tianshou to develop new
algorithms / apply algorithms to new scenarios.
-By the way, some of these issues can be resolved by using a ``gym.Wrapper``.
+By the way, some of these issues can be resolved by using a ``gymnasium.Wrapper``.
It could be a universal solution in the policy-environment interaction. But
you can also use the batch processor :ref:`preprocess_fn` or vectorized
environment wrapper :class:`~tianshou.env.VectorEnvWrapper`.
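+
+For instance, a minimal reward-scaling wrapper could look like the following sketch
+(illustrative only; ``ScaleReward`` and the use of ``CartPole-v1`` are just placeholders)::
+
+    import gymnasium as gym
+
+    class ScaleReward(gym.Wrapper):
+        """Multiply every reward by a constant factor."""
+
+        def __init__(self, env: gym.Env, factor: float = 0.1):
+            super().__init__(env)
+            self.factor = factor
+
+        def step(self, action):
+            obs, rew, terminated, truncated, info = self.env.step(action)
+            return obs, rew * self.factor, terminated, truncated, info
+
+    env = ScaleReward(gym.make("CartPole-v1"))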
@@ -159,7 +159,7 @@ toy_text and classic_control environments. For more information, please refer to
# install envpool: pip3 install envpool
import envpool
- envs = envpool.make_gym("CartPole-v0", num_envs=10)
+ envs = envpool.make_gymnasium("CartPole-v0", num_envs=10)
collector = Collector(policy, envs, buffer)
Here are some other `examples `_.
diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst
index cb6d616fe..79422bb6d 100644
--- a/docs/tutorials/concepts.rst
+++ b/docs/tutorials/concepts.rst
@@ -55,18 +55,22 @@ Buffer
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of :class:`~tianshou.data.Batch`. It stores all the data in a batch with circular-queue style.
-The current implementation of Tianshou typically use 7 reserved keys in
+The current implementation of Tianshou typically uses the following reserved keys in
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
-* ``done`` the done flag of step :math:`t` ;
+* ``terminated`` the terminated flag of step :math:`t` ;
+* ``truncated`` the truncated flag of step :math:`t` ;
+* ``done`` the done flag of step :math:`t` (can be inferred as ``terminated or truncated``);
* ``obs_next`` the observation of step :math:`t+1` ;
-* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``);
+* ``info`` the info of step :math:`t` (in ``gymnasium.Env``, the ``env.step()`` function returns 5 values, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
-The following code snippet illustrates its usage, including:
+When adding data to a replay buffer, the done flag will be inferred automatically from ``terminated`` and ``truncated``.
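+
+For instance, the following sketch (illustrative only, independent of the snippet below) adds a single transition and reads back the inferred flag::
+
+    from tianshou.data import Batch, ReplayBuffer
+
+    buf = ReplayBuffer(size=3)
+    buf.add(Batch(obs=0, act=0, rew=0.0, terminated=True, truncated=False,
+                  obs_next=1, info={}))
+    print(buf.done[0])  # True: done is inferred as terminated or truncated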
+
+The following code snippet illustrates the usage, including:
- the basic data storage: ``add()``;
- get attribute, get slicing data, ...;
@@ -80,7 +84,7 @@ The following code snippet illustrates its usage, including:
>>> from tianshou.data import Batch, ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
- ... buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))
+ ... buf.add(Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, obs_next=i + 1, info={}))
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
@@ -95,8 +99,8 @@ The following code snippet illustrates its usage, including:
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
- ... done = i % 4 == 0
- ... buf2.add(Batch(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}))
+ ... terminated = i % 4 == 0
+ ... buf2.add(Batch(obs=i, act=i, rew=i, terminated=terminated, truncated=False, obs_next=i + 1, info={}))
>>> len(buf2)
10
>>> buf2.obs
@@ -146,10 +150,10 @@ The following code snippet illustrates its usage, including:
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
- ... done = i % 5 == 0
+ ... terminated = i % 5 == 0
... ptr, ep_rew, ep_len, ep_idx = buf.add(
... Batch(obs={'id': i}, act=i, rew=i,
- ... done=done, obs_next={'id': i + 1}))
+ ... terminated=terminated, truncated=False, obs_next={'id': i + 1}))
... print(i, ep_len, ep_rew)
0 [1] [0.]
1 [0] [0.]
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index aac8c98b7..b2c5844e2 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -35,10 +35,10 @@ Here is the overall system:
Make an Environment
-------------------
-First of all, you have to make an environment for your agent to interact with. You can use ``gym.make(environment_name)`` to make an environment for your agent. For environment interfaces, we follow the convention of `OpenAI Gym `_. In your Python code, simply import Tianshou and make the environment:
+First of all, you have to make an environment for your agent to interact with. You can use ``gym.make(environment_name)`` to make an environment for your agent. For environment interfaces, we follow the convention of `Gymnasium `_. In your Python code, simply import Tianshou and make the environment:
::
- import gym
+ import gymnasium as gym
import tianshou as ts
env = gym.make('CartPole-v0')
@@ -84,8 +84,8 @@ You can also try the super-fast vectorized environment `EnvPool >> import numpy as np
>>> action = 0 # action is either an integer, or an np.ndarray with one element
- >>> obs, reward, done, info = env.step(action) # the env.step follows the api of OpenAI Gym
+ >>> obs, reward, terminated, truncated, info = env.step(action)  # env.step follows the Gymnasium API
>>> print(obs) # notice the change in the observation
{'agent_id': 'player_2', 'obs': array([[[0, 1],
[0, 0],
@@ -185,7 +185,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
from copy import deepcopy
from typing import Optional, Tuple
- import gym
+ import gymnasium as gym
import numpy as np
import torch
from pettingzoo.classic import tictactoe_v3
diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index 901b166ef..fddbfb044 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -5,7 +5,7 @@
from collections import deque
import cv2
-import gym
+import gymnasium as gym
import numpy as np
from tianshou.env import ShmemVectorEnv
@@ -324,7 +324,7 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
"please set `x = x / 255.0` inside CNN network's forward function."
)
# parameters convertion
- train_envs = env = envpool.make_gym(
+ train_envs = env = envpool.make_gymnasium(
task.replace("NoFrameskip-v4", "-v5"),
num_envs=training_num,
seed=seed,
@@ -332,7 +332,7 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
reward_clip=True,
stack_num=kwargs.get("frame_stack", 4),
)
- test_envs = envpool.make_gym(
+ test_envs = envpool.make_gymnasium(
task.replace("NoFrameskip-v4", "-v5"),
num_envs=test_num,
seed=seed,
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py
index 8d3bd46b0..76aec9e33 100644
--- a/examples/box2d/acrobot_dualdqn.py
+++ b/examples/box2d/acrobot_dualdqn.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py
index bb9d4d213..5d518cd77 100644
--- a/examples/box2d/bipedal_bdq.py
+++ b/examples/box2d/bipedal_bdq.py
@@ -3,7 +3,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index f440c8000..8105964a9 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py
index cd1b2c2c5..561060b19 100644
--- a/examples/box2d/lunarlander_dqn.py
+++ b/examples/box2d/lunarlander_dqn.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py
index 48436bf67..fb62d2a8f 100644
--- a/examples/box2d/mcc_sac.py
+++ b/examples/box2d/mcc_sac.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py
index 1b9d2da73..c2f3dab45 100644
--- a/examples/inverse/irl_gail.py
+++ b/examples/inverse/irl_gail.py
@@ -6,7 +6,7 @@
import pprint
import d4rl
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch import nn
diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py
index 289008545..1a2332512 100644
--- a/examples/mujoco/fetch_her_ddpg.py
+++ b/examples/mujoco/fetch_her_ddpg.py
@@ -6,7 +6,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
import wandb
diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index 0c567574e..412b60e62 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -1,6 +1,6 @@
import warnings
-import gym
+import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
@@ -18,8 +18,10 @@ def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
:return: a tuple of (single env, training envs, test envs).
"""
if envpool is not None:
- train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed)
- test_envs = envpool.make_gym(task, num_envs=test_num, seed=seed)
+ train_envs = env = envpool.make_gymnasium(
+ task, num_envs=training_num, seed=seed
+ )
+ test_envs = envpool.make_gymnasium(task, num_envs=test_num, seed=seed)
else:
warnings.warn(
"Recommend using envpool (pip install envpool) "
diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py
index eecf03f34..34c3f01bb 100644
--- a/examples/offline/d4rl_bcq.py
+++ b/examples/offline/d4rl_bcq.py
@@ -5,7 +5,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py
index af8b78710..23cd215ed 100644
--- a/examples/offline/d4rl_cql.py
+++ b/examples/offline/d4rl_cql.py
@@ -5,7 +5,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py
index 54dde85ce..9aeddf3db 100644
--- a/examples/offline/d4rl_il.py
+++ b/examples/offline/d4rl_il.py
@@ -5,7 +5,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py
index 43ee21b31..9899b4d2a 100644
--- a/examples/offline/d4rl_td3_bc.py
+++ b/examples/offline/d4rl_td3_bc.py
@@ -5,7 +5,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/offline/utils.py b/examples/offline/utils.py
index 07c693cdc..f12374280 100644
--- a/examples/offline/utils.py
+++ b/examples/offline/utils.py
@@ -1,7 +1,7 @@
from typing import Tuple
import d4rl
-import gym
+import gymnasium as gym
import h5py
import numpy as np
diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py
index 63555f733..011cb2c28 100644
--- a/examples/vizdoom/env.py
+++ b/examples/vizdoom/env.py
@@ -1,7 +1,7 @@
import os
import cv2
-import gym
+import gymnasium as gym
import numpy as np
import vizdoom as vzd
@@ -131,7 +131,7 @@ def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_n
}
if "battle" in task:
reward_config["HEALTH"] = [1.0, -1.0]
- env = train_envs = envpool.make_gym(
+ env = train_envs = envpool.make_gymnasium(
task_id,
frame_skip=frame_skip,
stack_num=res[0],
@@ -142,7 +142,7 @@ def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_n
max_episode_steps=2625,
use_inter_area_resize=False,
)
- test_envs = envpool.make_gym(
+ test_envs = envpool.make_gymnasium(
task_id,
frame_skip=frame_skip,
stack_num=res[0],
diff --git a/setup.py b/setup.py
index 3d96d75a5..5912cd172 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@ def get_version() -> str:
def get_install_requires() -> str:
return [
- "gym>=0.23.1",
+ "gymnasium>=0.26.0",
"tqdm",
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard>=2.5.0",
@@ -48,18 +48,20 @@ def get_extras_require() -> str:
"doc8",
"scipy",
"pillow",
- "pettingzoo>=1.17",
+ "pettingzoo>=1.22",
"pygame>=2.1.0", # pettingzoo test cases pistonball
"pymunk>=6.2.1", # pettingzoo test cases pistonball
"nni>=2.3,<3.0", # expect breaking changes at next major version
"pytorch_lightning",
+ "gym>=0.22.0",
+ "shimmy",
],
"atari": ["atari_py", "opencv-python"],
"mujoco": ["mujoco_py"],
"pybullet": ["pybullet"],
}
if sys.platform == "linux":
- req["dev"].append("envpool>=0.5.3")
+ req["dev"].append("envpool>=0.7.0")
return req
diff --git a/test/3rd_party/test_nni.py b/test/3rd_party/test_nni.py
index ddeeff6f7..e088bf560 100644
--- a/test/3rd_party/test_nni.py
+++ b/test/3rd_party/test_nni.py
@@ -8,6 +8,7 @@
import nni.nas.execution.api
import nni.nas.nn.pytorch as nn
import nni.nas.strategy as strategy
+import pytest
import torch
import torch.nn.functional as F
from nni.nas.execution import wait_models
@@ -107,6 +108,9 @@ def _get_model_and_mutators(**kwargs):
return base_model_ir, mutators
+@pytest.mark.skip(
+ reason="NNI currently uses OpenAI Gym"
+) # TODO: Remove once NNI transitions to Gymnasium
def test_rl():
rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
engine = MockExecutionEngine(failure_prob=0.2)
diff --git a/test/base/env.py b/test/base/env.py
index 8c6333d0b..1dd3aab1c 100644
--- a/test/base/env.py
+++ b/test/base/env.py
@@ -2,10 +2,10 @@
import time
from copy import deepcopy
-import gym
+import gymnasium as gym
import networkx as nx
import numpy as np
-from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
+from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
class MyTestEnv(gym.Env):
@@ -74,11 +74,13 @@ def __init__(
self.terminated = False
self.index = 0
- def reset(self, state=0, seed=None):
+ def reset(self, seed=None, options=None):
+ if options is None:
+ options = {"state": 0}
super().reset(seed=seed)
self.terminated = False
self.do_sleep()
- self.index = state
+ self.index = options["state"]
return self._get_state(), {'key': 1, 'env': self}
def _get_reward(self):
@@ -174,7 +176,7 @@ def __init__(self, *args, **kwargs):
assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, \
"dict_state / recurse_state not supported"
super().__init__(*args, **kwargs)
- obs, _ = super().reset(state=0)
+ obs, _ = super().reset(options={"state": 0})
obs, _, _, _, _ = super().step(1)
self._goal = obs * self.size
super_obsv = self.observation_space
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 02d140d1e..cf011b7f1 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -196,7 +196,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
- obs, info = env.reset(1)
+ obs, info = env.reset(options={"state": 1})
for _ in range(16):
obs_next, rew, terminated, truncated, info = env.step(1)
done = terminated or truncated
@@ -233,7 +233,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
)
obs = obs_next
if done:
- obs, info = env.reset(1)
+ obs, info = env.reset(options={"state": 1})
indices = np.arange(len(buf))
assert np.allclose(
buf.get(indices, 'obs')[..., 0], [
@@ -1017,7 +1017,7 @@ def test_multibuf_stack():
bufsize, stack_num=stack_num, ignore_obs_next=True, sample_avail=True
), cached_num, size
)
- obs, info = env.reset(1)
+ obs, info = env.reset(options={"state": 1})
for i in range(18):
obs_next, rew, terminated, truncated, info = env.step(1)
done = terminated or truncated
@@ -1045,7 +1045,7 @@ def test_multibuf_stack():
assert np.all(buf4.truncated == buf5.truncated)
obs = obs_next
if done:
- obs, info = env.reset(1)
+ obs, info = env.reset(options={"state": 1})
# check the `add` order is correct
assert np.allclose(
buf4.obs.reshape(-1),
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 6843a394a..4b5304c1e 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -684,7 +684,9 @@ def test_collector_with_atari_setting():
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_collector_envpool_gym_reset_return_info():
- envs = envpool.make_gym("Pendulum-v1", num_envs=4, gym_reset_return_info=True)
+ envs = envpool.make_gymnasium(
+ "Pendulum-v1", num_envs=4, gym_reset_return_info=True
+ )
policy = MyPolicy(action_shape=(len(envs), 1))
c0 = Collector(
diff --git a/test/base/test_env.py b/test/base/test_env.py
index 1c91c0fa1..ec2b91415 100644
--- a/test/base/test_env.py
+++ b/test/base/test_env.py
@@ -1,10 +1,10 @@
import sys
import time
-import gym
+import gymnasium as gym
import numpy as np
import pytest
-from gym.spaces.discrete import Discrete
+from gymnasium.spaces.discrete import Discrete
from tianshou.data import Batch
from tianshou.env import (
@@ -391,10 +391,10 @@ def step(self, act):
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool():
- raw = envpool.make_gym("Ant-v3", num_envs=4)
- train = VectorEnvNormObs(envpool.make_gym("Ant-v3", num_envs=4))
+ raw = envpool.make_gymnasium("Ant-v3", num_envs=4)
+ train = VectorEnvNormObs(envpool.make_gymnasium("Ant-v3", num_envs=4))
test = VectorEnvNormObs(
- envpool.make_gym("Ant-v3", num_envs=4), update_obs_rms=False
+ envpool.make_gymnasium("Ant-v3", num_envs=4), update_obs_rms=False
)
test.set_obs_rms(train.get_obs_rms())
actions = [
@@ -407,7 +407,9 @@ def test_venv_wrapper_envpool():
def test_venv_wrapper_envpool_gym_reset_return_info():
num_envs = 4
env = VectorEnvNormObs(
- envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True)
+ envpool.make_gymnasium(
+ "Ant-v3", num_envs=num_envs, gym_reset_return_info=True
+ )
)
obs, info = env.reset()
assert obs.shape[0] == num_envs
diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py
index 54b438507..938a358c7 100644
--- a/test/base/test_env_finite.py
+++ b/test/base/test_env_finite.py
@@ -3,7 +3,7 @@
import copy
from collections import Counter
-import gym
+import gymnasium as gym
import numpy as np
from torch.utils.data import DataLoader, Dataset, DistributedSampler
@@ -45,15 +45,15 @@ def reset(self):
try:
self.current_sample, self.step_count = next(self.iterator)
self.current_step = 0
- return self.current_sample
+ return self.current_sample, {}
except StopIteration:
self.iterator = None
- return None
+ return None, {}
def step(self, action):
self.current_step += 1
assert self.current_step <= self.step_count
- return 0, 1.0, self.current_step >= self.step_count, \
+ return 0, 1.0, self.current_step >= self.step_count, False, \
{'sample': self.current_sample, 'action': action, 'metric': 2.0}
@@ -94,10 +94,12 @@ def reset(self, id=None):
# ask super to reset alive envs and remap to current index
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
obs = [None] * len(id)
+ infos = [None] * len(id)
id2idx = {i: k for k, i in enumerate(id)}
if request_id:
- for i, o in zip(request_id, super().reset(request_id)):
- obs[id2idx[i]] = o
+ for k, o, info in zip(request_id, *super().reset(request_id)):
+ obs[id2idx[k]] = o
+ infos[id2idx[k]] = info
for i, o in zip(id, obs):
if o is None and i in self._alive_env_ids:
self._alive_env_ids.remove(i)
@@ -105,21 +107,24 @@ def reset(self, id=None):
# fill empty observation with default(fake) observation
for o in obs:
self._set_default_obs(o)
+
for i in range(len(obs)):
if obs[i] is None:
obs[i] = self._get_default_obs()
+ if infos[i] is None:
+ infos[i] = self._get_default_info()
if not self._alive_env_ids:
self.reset()
raise StopIteration
- return np.stack(obs)
+ return np.stack(obs), infos
def step(self, action, id=None):
id = self._wrap_id(id)
id2idx = {i: k for k, i in enumerate(id)}
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
- result = [[None, 0., False, None] for _ in range(len(id))]
+ result = [[None, 0., False, False, None] for _ in range(len(id))]
# ask super to step alive envs and remap to current index
if request_id:
@@ -133,13 +138,13 @@ def step(self, action, id=None):
self.tracker.log(*r)
# fill empty observation/info with default(fake)
- for _, __, ___, i in result:
+ for _, __, ___, ____, i in result:
self._set_default_info(i)
for i in range(len(result)):
if result[i][0] is None:
result[i][0] = self._get_default_obs()
- if result[i][3] is None:
- result[i][3] = self._get_default_info()
+ if result[i][-1] is None:
+ result[i][-1] = self._get_default_info()
return list(map(np.stack, zip(*result)))
@@ -171,8 +176,9 @@ def __init__(self):
self.counter = Counter()
self.finished = set()
- def log(self, obs, rew, done, info):
+ def log(self, obs, rew, terminated, truncated, info):
assert rew == 1.
+ done = terminated or truncated
index = info['sample']
if done:
assert index not in self.finished
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index cc9aafe5d..517ec1f4d 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py
index b7ec50c82..aab35fe44 100644
--- a/test/continuous/test_npg.py
+++ b/test/continuous/test_npg.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch import nn
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 80cf6e330..323b85d5b 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.distributions import Independent, Normal
diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py
index 8b649a2be..b6bd5269d 100644
--- a/test/continuous/test_redq.py
+++ b/test/continuous/test_redq.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index 6d775e1eb..599c47020 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -59,10 +59,12 @@ def get_args():
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_sac_with_il(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
- train_envs = env = envpool.make_gym(
+ train_envs = env = envpool.make_gymnasium(
args.task, num_envs=args.training_num, seed=args.seed
)
- test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed)
+ test_envs = envpool.make_gymnasium(
+ args.task, num_envs=args.test_num, seed=args.seed
+ )
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
@@ -185,7 +187,7 @@ def stop_fn(mean_rewards):
)
il_test_collector = Collector(
il_policy,
- envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed),
+ envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed),
)
train_collector.reset()
result = offpolicy_trainer(
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index ac557ba12..aa9411a4e 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py
index cc8b39f61..9bcc01487 100644
--- a/test/continuous/test_trpo.py
+++ b/test/continuous/test_trpo.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch import nn
diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index 0a453e1eb..05ddd7ee9 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import pytest
import torch
@@ -60,10 +60,12 @@ def get_args():
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_a2c_with_il(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
- train_envs = env = envpool.make_gym(
- args.task, num_envs=args.training_num, seed=args.seed
+ train_envs = env = envpool.make(
+ args.task, env_type="gymnasium", num_envs=args.training_num, seed=args.seed
+ )
+ test_envs = envpool.make(
+ args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed
)
- test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if args.reward_threshold is None:
@@ -146,7 +148,9 @@ def stop_fn(mean_rewards):
il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
il_test_collector = Collector(
il_policy,
- envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed),
+ envpool.make(
+ args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed
+ ),
)
train_collector.reset()
result = offpolicy_trainer(
diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py
index 604fd886b..30a2e11a2 100644
--- a/test/discrete/test_bdq.py
+++ b/test/discrete/test_bdq.py
@@ -1,7 +1,7 @@
import argparse
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py
index 3912b0740..01f94ab03 100644
--- a/test/discrete/test_c51.py
+++ b/test/discrete/test_c51.py
@@ -3,7 +3,7 @@
import pickle
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 2644dc998..e8359616b 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/discrete/test_dqn_gym.py b/test/discrete/test_dqn_gym.py
new file mode 100644
index 000000000..304a88256
--- /dev/null
+++ b/test/discrete/test_dqn_gym.py
@@ -0,0 +1,172 @@
+"""Same as test_dqn.py, but uses OpenAI Gym environments"""
+import argparse
+import os
+import pprint
+
+import gym
+import numpy as np
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
+from tianshou.policy import DQNPolicy
+from tianshou.trainer import offpolicy_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--task', type=str, default='CartPole-v0')
+ parser.add_argument('--reward-threshold', type=float, default=None)
+ parser.add_argument('--seed', type=int, default=1626)
+ parser.add_argument('--eps-test', type=float, default=0.05)
+ parser.add_argument('--eps-train', type=float, default=0.1)
+ parser.add_argument('--buffer-size', type=int, default=20000)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--gamma', type=float, default=0.9)
+ parser.add_argument('--n-step', type=int, default=3)
+ parser.add_argument('--target-update-freq', type=int, default=320)
+ parser.add_argument('--epoch', type=int, default=20)
+ parser.add_argument('--step-per-epoch', type=int, default=10000)
+ parser.add_argument('--step-per-collect', type=int, default=10)
+ parser.add_argument('--update-per-step', type=float, default=0.1)
+ parser.add_argument('--batch-size', type=int, default=64)
+ parser.add_argument(
+ '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
+ )
+ parser.add_argument('--training-num', type=int, default=10)
+ parser.add_argument('--test-num', type=int, default=100)
+ parser.add_argument('--logdir', type=str, default='log')
+ parser.add_argument('--render', type=float, default=0.)
+ parser.add_argument('--prioritized-replay', action="store_true", default=False)
+ parser.add_argument('--alpha', type=float, default=0.6)
+ parser.add_argument('--beta', type=float, default=0.4)
+ parser.add_argument(
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
+ args = parser.parse_known_args()[0]
+ return args
+
+
+def test_dqn(args=get_args()):
+ env = gym.make(args.task)
+ args.state_shape = env.observation_space.shape or env.observation_space.n
+ args.action_shape = env.action_space.shape or env.action_space.n
+ if args.reward_threshold is None:
+ default_reward_threshold = {"CartPole-v0": 195}
+ args.reward_threshold = default_reward_threshold.get(
+ args.task, env.spec.reward_threshold
+ )
+ # train_envs = gym.make(args.task)
+ # you can also use tianshou.env.SubprocVectorEnv
+ train_envs = DummyVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
+ # test_envs = gym.make(args.task)
+ test_envs = DummyVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
+ # seed
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ train_envs.seed(args.seed)
+ test_envs.seed(args.seed)
+ # Q_param = V_param = {"hidden_sizes": [128]}
+ # model
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ # dueling=(Q_param, V_param),
+ ).to(args.device)
+ optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+ policy = DQNPolicy(
+ net,
+ optim,
+ args.gamma,
+ args.n_step,
+ target_update_freq=args.target_update_freq,
+ )
+ # buffer
+ if args.prioritized_replay:
+ buf = PrioritizedVectorReplayBuffer(
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ alpha=args.alpha,
+ beta=args.beta,
+ )
+ else:
+ buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
+ # collector
+ train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
+ test_collector = Collector(policy, test_envs, exploration_noise=True)
+ # policy.set_eps(1)
+ train_collector.collect(n_step=args.batch_size * args.training_num)
+ # log
+ log_path = os.path.join(args.logdir, args.task, 'dqn')
+ writer = SummaryWriter(log_path)
+ logger = TensorboardLogger(writer)
+
+ def save_best_fn(policy):
+ torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+ def stop_fn(mean_rewards):
+ return mean_rewards >= args.reward_threshold
+
+ def train_fn(epoch, env_step):
+        # eps annealing, just a demo
+ if env_step <= 10000:
+ policy.set_eps(args.eps_train)
+ elif env_step <= 50000:
+ eps = args.eps_train - (env_step - 10000) / \
+ 40000 * (0.9 * args.eps_train)
+ policy.set_eps(eps)
+ else:
+ policy.set_eps(0.1 * args.eps_train)
+
+ def test_fn(epoch, env_step):
+ policy.set_eps(args.eps_test)
+
+ # trainer
+ result = offpolicy_trainer(
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_best_fn=save_best_fn,
+ logger=logger,
+ )
+ assert stop_fn(result['best_reward'])
+
+ if __name__ == '__main__':
+ pprint.pprint(result)
+ # Let's watch its performance!
+ env = gym.make(args.task)
+ policy.eval()
+ policy.set_eps(args.eps_test)
+ collector = Collector(policy, env)
+ result = collector.collect(n_episode=1, render=args.render)
+ rews, lens = result["rews"], result["lens"]
+ print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
+
+
+def test_pdqn(args=get_args()):
+ args.prioritized_replay = True
+ args.gamma = .95
+ args.seed = 1
+ test_dqn(args)
+
+
+if __name__ == '__main__':
+ test_dqn(get_args())
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 36ff5fa76..c01108c71 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py
index e25c42997..3e5381f4e 100644
--- a/test/discrete/test_fqf.py
+++ b/test/discrete/test_fqf.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py
index 725c9a9d5..249cdc542 100644
--- a/test/discrete/test_iqn.py
+++ b/test/discrete/test_iqn.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 0ea8f745a..a8a2792ed 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index b7dba97c9..20d6d5060 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py
index 9c699b4e3..2de7ca9f4 100644
--- a/test/discrete/test_qrdqn.py
+++ b/test/discrete/test_qrdqn.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py
index 78b43dff9..6ecaaaa8c 100644
--- a/test/discrete/test_rainbow.py
+++ b/test/discrete/test_rainbow.py
@@ -3,7 +3,7 @@
import pickle
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py
index ac8cf7253..06831058e 100644
--- a/test/discrete/test_sac.py
+++ b/test/discrete/test_sac.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py
index fba0b5523..bb5750875 100644
--- a/test/modelbased/test_dqn_icm.py
+++ b/test/modelbased/test_dqn_icm.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py
index 6efd96277..77f0f8300 100644
--- a/test/modelbased/test_ppo_icm.py
+++ b/test/modelbased/test_ppo_icm.py
@@ -2,7 +2,7 @@
import os
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py
index 2c5c3fcd3..1332ae506 100644
--- a/test/modelbased/test_psrl.py
+++ b/test/modelbased/test_psrl.py
@@ -48,10 +48,12 @@ def get_args():
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_psrl(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
- train_envs = env = envpool.make_gym(
+ train_envs = env = envpool.make_gymnasium(
args.task, num_envs=args.training_num, seed=args.seed
)
- test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed)
+ test_envs = envpool.make_gymnasium(
+ args.task, num_envs=args.test_num, seed=args.seed
+ )
if args.reward_threshold is None:
default_reward_threshold = {"NChain-v0": 3400}
args.reward_threshold = default_reward_threshold.get(
diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py
index 2e01723f3..bbea01d19 100644
--- a/test/offline/gather_cartpole_data.py
+++ b/test/offline/gather_cartpole_data.py
@@ -2,7 +2,7 @@
import os
import pickle
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py
index f60ec3696..b124aad61 100644
--- a/test/offline/gather_pendulum_data.py
+++ b/test/offline/gather_pendulum_data.py
@@ -2,7 +2,7 @@
import os
import pickle
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py
index a0fea4f50..687304bbb 100644
--- a/test/offline/test_bcq.py
+++ b/test/offline/test_bcq.py
@@ -4,7 +4,7 @@
import pickle
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py
index 6d52bca7a..b5dfafc14 100644
--- a/test/offline/test_cql.py
+++ b/test/offline/test_cql.py
@@ -4,7 +4,7 @@
import pickle
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py
index 66e42f3e9..1c9fcd591 100644
--- a/test/offline/test_discrete_bcq.py
+++ b/test/offline/test_discrete_bcq.py
@@ -3,7 +3,7 @@
import pickle
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py
index eca810fb6..99f60933f 100644
--- a/test/offline/test_discrete_cql.py
+++ b/test/offline/test_discrete_cql.py
@@ -3,7 +3,7 @@
import pickle
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py
index 9f47b32e7..c12513c30 100644
--- a/test/offline/test_discrete_crr.py
+++ b/test/offline/test_discrete_crr.py
@@ -3,7 +3,7 @@
import pickle
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py
index aa123ec87..060fc0aab 100644
--- a/test/offline/test_gail.py
+++ b/test/offline/test_gail.py
@@ -3,7 +3,7 @@
import pickle
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.distributions import Independent, Normal
diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py
index e34dfc6d2..2adab61cf 100644
--- a/test/offline/test_td3_bc.py
+++ b/test/offline/test_td3_bc.py
@@ -4,7 +4,7 @@
import pickle
import pprint
-import gym
+import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py
index 48a5f202a..984d765fd 100644
--- a/test/pettingzoo/pistonball.py
+++ b/test/pettingzoo/pistonball.py
@@ -3,7 +3,7 @@
import warnings
from typing import List, Optional, Tuple
-import gym
+import gymnasium as gym
import numpy as np
import pettingzoo.butterfly.pistonball_v6 as pistonball_v6
import torch
diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py
index 8f4818440..6ffb5ec62 100644
--- a/test/pettingzoo/pistonball_continuous.py
+++ b/test/pettingzoo/pistonball_continuous.py
@@ -3,7 +3,7 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
import pettingzoo.butterfly.pistonball_v6 as pistonball_v6
import torch
diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py
index 7610cee98..0b3b38262 100644
--- a/test/pettingzoo/tic_tac_toe.py
+++ b/test/pettingzoo/tic_tac_toe.py
@@ -4,7 +4,6 @@
from functools import partial
from typing import Optional, Tuple
-import gym
import gymnasium
import numpy as np
import torch
@@ -106,7 +105,7 @@ def get_agents(
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
env = get_env()
observation_space = env.observation_space['observation'] if isinstance(
- env.observation_space, (gym.spaces.Dict, gymnasium.spaces.Dict)
+ env.observation_space, gymnasium.spaces.Dict
) else env.observation_space
args.state_shape = observation_space.shape or observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py
index b14552eb8..45a1eeb8b 100644
--- a/test/throughput/test_buffer_profile.py
+++ b/test/throughput/test_buffer_profile.py
@@ -1,7 +1,7 @@
import sys
import time
-import gym
+import gymnasium as gym
import numpy as np
import tqdm
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 88f6c76ba..3ff1ab26e 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -2,7 +2,7 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch
@@ -143,24 +143,14 @@ def reset_buffer(self, keep_statistics: bool = False) -> None:
def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
"""Reset all of the environments."""
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
- rval = self.env.reset(**gym_reset_kwargs)
- returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
- isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
- )
- if returns_info:
- obs, info = rval
- if self.preprocess_fn:
- processed_data = self.preprocess_fn(
- obs=obs, info=info, env_id=np.arange(self.env_num)
- )
- obs = processed_data.get("obs", obs)
- info = processed_data.get("info", info)
- self.data.info = info
- else:
- obs = rval
- if self.preprocess_fn:
- obs = self.preprocess_fn(obs=obs, env_id=np.arange(self.env_num
- )).get("obs", obs)
+ obs, info = self.env.reset(**gym_reset_kwargs)
+ if self.preprocess_fn:
+ processed_data = self.preprocess_fn(
+ obs=obs, info=info, env_id=np.arange(self.env_num)
+ )
+ obs = processed_data.get("obs", obs)
+ info = processed_data.get("info", info)
+ self.data.info = info
self.data.obs = obs
def _reset_state(self, id: Union[int, List[int]]) -> None:
@@ -181,24 +171,15 @@ def _reset_env_with_ids(
gym_reset_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
- rval = self.env.reset(global_ids, **gym_reset_kwargs)
- returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
- isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
- )
- if returns_info:
- obs_reset, info = rval
- if self.preprocess_fn:
- processed_data = self.preprocess_fn(
- obs=obs_reset, info=info, env_id=global_ids
- )
- obs_reset = processed_data.get("obs", obs_reset)
- info = processed_data.get("info", info)
- self.data.info[local_ids] = info
- else:
- obs_reset = rval
- if self.preprocess_fn:
- obs_reset = self.preprocess_fn(obs=obs_reset, env_id=global_ids
- ).get("obs", obs_reset)
+ obs_reset, info = self.env.reset(global_ids, **gym_reset_kwargs)
+ if self.preprocess_fn:
+ processed_data = self.preprocess_fn(
+ obs=obs_reset, info=info, env_id=global_ids
+ )
+ obs_reset = processed_data.get("obs", obs_reset)
+ info = processed_data.get("info", info)
+ self.data.info[local_ids] = info
+
self.data.obs_next[local_ids] = obs_reset
def collect(
@@ -311,24 +292,11 @@ def collect(
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
- result = self.env.step(action_remap, ready_env_ids) # type: ignore
- if len(result) == 5:
- obs_next, rew, terminated, truncated, info = result
- done = np.logical_or(terminated, truncated)
- elif len(result) == 4:
- obs_next, rew, done, info = result
- if isinstance(info, dict):
- truncated = info["TimeLimit.truncated"]
- else:
- truncated = np.array(
- [
- info_item.get("TimeLimit.truncated", False)
- for info_item in info
- ]
- )
- terminated = np.logical_and(done, ~truncated)
- else:
- raise ValueError()
+ obs_next, rew, terminated, truncated, info = self.env.step(
+ action_remap, # type: ignore
+ ready_env_ids
+ )
+ done = np.logical_or(terminated, truncated)
self.data.update(
obs_next=obs_next,
@@ -583,25 +551,11 @@ def collect(
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
- result = self.env.step(action_remap, ready_env_ids) # type: ignore
-
- if len(result) == 5:
- obs_next, rew, terminated, truncated, info = result
- done = np.logical_or(terminated, truncated)
- elif len(result) == 4:
- obs_next, rew, done, info = result
- if isinstance(info, dict):
- truncated = info["TimeLimit.truncated"]
- else:
- truncated = np.array(
- [
- info_item.get("TimeLimit.truncated", False)
- for info_item in info
- ]
- )
- terminated = np.logical_and(done, ~truncated)
- else:
- raise ValueError()
+ obs_next, rew, terminated, truncated, info = self.env.step(
+ action_remap, # type: ignore
+ ready_env_ids
+ )
+ done = np.logical_or(terminated, truncated)
# change self.data here because ready_env_ids has changed
try:
diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py
index b906b79b9..c9ce66acb 100644
--- a/tianshou/env/gym_wrappers.py
+++ b/tianshou/env/gym_wrappers.py
@@ -1,6 +1,6 @@
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, SupportsFloat, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
from packaging import version
@@ -26,7 +26,7 @@ def __init__(self, env: gym.Env, action_per_dim: Union[int, List[int]]) -> None:
dtype=object
)
- def action(self, act: np.ndarray) -> np.ndarray:
+ def action(self, act: np.ndarray) -> np.ndarray: # type: ignore
# modify act
assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}."
if len(act.shape) == 1:
@@ -50,7 +50,7 @@ def __init__(self, env: gym.Env) -> None:
self.bases[i] = self.bases[i - 1] * nvec[-i]
self.action_space = gym.spaces.Discrete(np.prod(nvec))
- def action(self, act: np.ndarray) -> np.ndarray:
+ def action(self, act: np.ndarray) -> np.ndarray: # type: ignore
converted_act = []
for b in np.flip(self.bases):
converted_act.append(act // b)
@@ -74,7 +74,8 @@ def __init__(self, env: gym.Env):
{gym.__version__}"
)
- def step(self, act: np.ndarray) -> Tuple[Any, float, bool, bool, Dict[Any, Any]]:
+ def step(self,
+ act: np.ndarray) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
observation, reward, terminated, truncated, info = super().step(act)
terminated = (terminated or truncated)
return observation, reward, terminated, truncated, info
diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py
index 28173336e..9ab828097 100644
--- a/tianshou/env/pettingzoo_env.py
+++ b/tianshou/env/pettingzoo_env.py
@@ -1,6 +1,6 @@
import warnings
from abc import ABC
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple
import pettingzoo
from gymnasium import spaces
@@ -65,22 +65,11 @@ def __init__(self, env: BaseWrapper):
self.reset()
- def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
+ def reset(self, *args: Any, **kwargs: Any) -> Tuple[dict, dict]:
self.env.reset(*args, **kwargs)
- # Here, we do not label the return values explicitly to keep compatibility with
- # old step API. TODO: Change once PettingZoo>=1.21.0 is required
- last_return = self.env.last(self)
+ observation, reward, terminated, truncated, info = self.env.last(self)
- if len(last_return) == 4:
- warnings.warn(
- "The PettingZoo environment is using the old step API. "
- "This API may not be supported in future versions of tianshou. "
- "We recommend that you update the environment code or apply a "
- "compatibility wrapper.", DeprecationWarning
- )
-
- observation, info = last_return[0], last_return[-1]
if isinstance(observation, dict) and 'action_mask' in observation:
observation_dict = {
'agent_id': self.env.agent_selection,
@@ -101,21 +90,13 @@ def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
'obs': observation,
}
- if "return_info" in kwargs and kwargs["return_info"]:
- return observation_dict, info
- else:
- return observation_dict
+ return observation_dict, info
- def step(
- self, action: Any
- ) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool,
- Dict]]:
+ def step(self, action: Any) -> Tuple[Dict, List[int], bool, bool, Dict]:
self.env.step(action)
- # Here, we do not label the return values explicitly to keep compatibility with
- # old step API. TODO: Change once PettingZoo>=1.21.0 is required
- last_return = self.env.last()
- observation = last_return[0]
+ observation, rew, term, trunc, info = self.env.last()
+
if isinstance(observation, dict) and 'action_mask' in observation:
obs = {
'agent_id': self.env.agent_selection,
@@ -135,7 +116,7 @@ def step(
for agent_id, reward in self.env.rewards.items():
self.rewards[self.agent_idx[agent_id]] = reward
- return (obs, self.rewards, *last_return[2:]) # type: ignore
+ return obs, self.rewards, term, trunc, info
def close(self) -> None:
self.env.close()
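
After removing the old-API fallbacks, `PettingZooEnv.reset` always returns an `(observation, info)` pair and `step` always returns a five-tuple. A usage sketch, assuming PettingZoo's classic `tictactoe_v3` environment is installed and that the observation dict exposes a boolean `"mask"` entry following Tianshou's convention:

```python
from pettingzoo.classic import tictactoe_v3

from tianshou.env.pettingzoo_env import PettingZooEnv

env = PettingZooEnv(tictactoe_v3.env())
observation, info = env.reset()              # two-tuple, Gymnasium-style
action = observation["mask"].index(True)     # pick the first legal action
obs, rewards, terminated, truncated, info = env.step(action)
```
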
diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py
index 8868edfe8..ac671016e 100644
--- a/tianshou/env/utils.py
+++ b/tianshou/env/utils.py
@@ -1,9 +1,16 @@
-from typing import Any, Tuple
+from typing import TYPE_CHECKING, Any, Tuple, Union
import cloudpickle
+import gymnasium
import numpy as np
-gym_old_venv_step_type = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
+from tianshou.env.pettingzoo_env import PettingZooEnv
+
+if TYPE_CHECKING:
+ import gym
+
+ENV_TYPE = Union[gymnasium.Env, "gym.Env", PettingZooEnv]
+
gym_new_venv_step_type = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray,
np.ndarray]
diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py
index 5e5909af6..66470289c 100644
--- a/tianshou/env/venv_wrappers.py
+++ b/tianshou/env/venv_wrappers.py
@@ -2,7 +2,7 @@
import numpy as np
-from tianshou.env.utils import gym_new_venv_step_type, gym_old_venv_step_type
+from tianshou.env.utils import gym_new_venv_step_type
from tianshou.env.venvs import GYM_RESERVED_KEYS, BaseVectorEnv
from tianshou.utils import RunningMeanStd
@@ -42,14 +42,14 @@ def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
+ ) -> Tuple[np.ndarray, Union[dict, List[dict]]]:
return self.venv.reset(id, **kwargs)
def step(
self,
action: np.ndarray,
id: Optional[Union[int, List[int], np.ndarray]] = None,
- ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]:
+ ) -> gym_new_venv_step_type:
return self.venv.step(action, id)
def seed(
@@ -85,17 +85,10 @@ def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
- rval = self.venv.reset(id, **kwargs)
- returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and (
- isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
- )
- if returns_info:
- obs, info = rval
- else:
- obs = rval
+ ) -> Tuple[np.ndarray, Union[dict, List[dict]]]:
+ obs, info = self.venv.reset(id, **kwargs)
- if isinstance(obs, tuple):
+ if isinstance(obs, tuple): # type: ignore
raise TypeError(
"Tuple observation space is not supported. ",
"Please change it to array or dict space",
@@ -104,20 +97,17 @@ def reset(
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs)
obs = self._norm_obs(obs)
- if returns_info:
- return obs, info
- else:
- return obs
+ return obs, info
def step(
self,
action: np.ndarray,
id: Optional[Union[int, List[int], np.ndarray]] = None,
- ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]:
+ ) -> gym_new_venv_step_type:
step_results = self.venv.step(action, id)
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(step_results[0])
- return (self._norm_obs(step_results[0]), *step_results[1:]) # type:ignore
+ return (self._norm_obs(step_results[0]), *step_results[1:])
def _norm_obs(self, obs: np.ndarray) -> np.ndarray:
if self.obs_rms:
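
With the `returns_info` branching gone, the normalization wrapper unconditionally forwards the `(obs, info)` pair from `reset` and the five-tuple from `step`, normalizing observations in both. A usage sketch (the `Pendulum-v1` environment and the zero action are illustrative assumptions):

```python
import gymnasium as gym
import numpy as np

from tianshou.env import DummyVectorEnv, VectorEnvNormObs

venv = VectorEnvNormObs(DummyVectorEnv([lambda: gym.make("Pendulum-v1")]))
obs, info = venv.reset()                       # normalized obs, list of info dicts
act = np.zeros((1, 1), dtype=np.float32)       # one action for the single env
obs, rew, terminated, truncated, info = venv.step(act)
```
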
diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py
index fc5b757b1..a0fb0b538 100644
--- a/tianshou/env/venvs.py
+++ b/tianshou/env/venvs.py
@@ -1,9 +1,12 @@
+import warnings
from typing import Any, Callable, List, Optional, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
+import packaging.version
-from tianshou.env.utils import gym_new_venv_step_type, gym_old_venv_step_type
+from tianshou.env.pettingzoo_env import PettingZooEnv
+from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type
from tianshou.env.worker import (
DummyEnvWorker,
EnvWorker,
@@ -11,11 +14,76 @@
SubprocEnvWorker,
)
+try:
+ import gym as old_gym
+ has_old_gym = True
+except ImportError:
+ has_old_gym = False
+
GYM_RESERVED_KEYS = [
"metadata", "reward_range", "spec", "action_space", "observation_space"
]
+def _patch_env_generator(fn: Callable[[], ENV_TYPE]) -> Callable[[], gym.Env]:
+ """Takes an environment generator and patches it to return Gymnasium envs.
+
+ This function takes the environment generator `fn` and returns a patched
+ generator, without invoking `fn`. The original generator may return
+ Gymnasium or OpenAI Gym environments, but the patched generator wraps
+ the result of `fn` in a shimmy wrapper to convert it to Gymnasium,
+ if necessary.
+ """
+
+ def patched() -> gym.Env:
+ assert callable(
+ fn
+ ), "Env generators that are provided to vector environemnts must be callable."
+
+ env = fn()
+ if isinstance(env, (gym.Env, PettingZooEnv)):
+ return env
+
+ if not has_old_gym or not isinstance(env, old_gym.Env):
+ raise ValueError(
+ f"Environment generator returned a {type(env)}, not a Gymnasium "
+ f"environment. In this case, we expect OpenAI Gym to be "
+ f"installed and the environment to be an OpenAI Gym environment."
+ )
+
+ try:
+ import shimmy
+ except ImportError as e:
+ raise ImportError(
+ "Missing shimmy installation. You provided an environment generator "
+ "that returned an OpenAI Gym environment. "
+ "Tianshou has transitioned to using Gymnasium internally. "
+ "In order to use OpenAI Gym environments with tianshou, you need to "
+ "install shimmy (`pip install shimmy`)."
+ ) from e
+
+ warnings.warn(
+ "You provided an environment generator that returned an OpenAI Gym "
+ "environment. We strongly recommend transitioning to Gymnasium "
+ "environments. "
+ "Tianshou is automatically wrapping your environments in a compatibility "
+ "layer, which could potentially cause issues."
+ )
+
+ gym_version = packaging.version.parse(old_gym.__version__)
+ if gym_version >= packaging.version.parse("0.26.0"):
+ return shimmy.GymV26CompatibilityV0(env=env)
+ elif gym_version >= packaging.version.parse("0.22.0"):
+ return shimmy.GymV22CompatibilityV0(env=env)
+ else:
+ raise Exception(
+ f"Found OpenAI Gym version {gym.__version__}. "
+ f"Tianshou only supports OpenAI Gym environments of version>=0.22.0"
+ )
+
+ return patched
+
+
class BaseVectorEnv(object):
"""Base class for vectorized environments.
@@ -69,7 +137,7 @@ def seed(self, seed):
def __init__(
self,
- env_fns: List[Callable[[], gym.Env]],
+ env_fns: List[Callable[[], ENV_TYPE]],
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
@@ -77,7 +145,7 @@ def __init__(
self._env_fns = env_fns
# A VectorEnv contains a pool of EnvWorkers, which corresponds to
# interact with the given envs (one worker <-> one env).
- self.workers = [worker_fn(fn) for fn in env_fns]
+ self.workers = [worker_fn(_patch_env_generator(fn)) for fn in env_fns]
self.worker_class = type(self.workers[0])
assert issubclass(self.worker_class, EnvWorker)
assert all([isinstance(w, self.worker_class) for w in self.workers])
@@ -186,7 +254,7 @@ def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
+ ) -> Tuple[np.ndarray, Union[dict, List[dict]]]:
"""Reset the state of some envs and return initial observations.
If id is None, reset the state of all the environments and return
@@ -203,15 +271,13 @@ def reset(
self.workers[i].send(None, **kwargs)
ret_list = [self.workers[i].recv() for i in id]
- reset_returns_info = isinstance(ret_list[0], (tuple, list)) and len(
+ assert isinstance(ret_list[0], (tuple, list)) and len(
ret_list[0]
) == 2 and isinstance(ret_list[0][1], dict)
- if reset_returns_info:
- obs_list = [r[0] for r in ret_list]
- else:
- obs_list = ret_list
- if isinstance(obs_list[0], tuple):
+ obs_list = [r[0] for r in ret_list]
+
+ if isinstance(obs_list[0], tuple): # type: ignore
raise TypeError(
"Tuple observation space is not supported. ",
"Please change it to array or dict space",
@@ -221,17 +287,14 @@ def reset(
except ValueError: # different len(obs)
obs = np.array(obs_list, dtype=object)
- if reset_returns_info:
- infos = [r[1] for r in ret_list]
- return obs, infos # type: ignore
- else:
- return obs
+ infos = [r[1] for r in ret_list]
+ return obs, infos # type: ignore
def step(
self,
action: np.ndarray,
id: Optional[Union[int, List[int], np.ndarray]] = None,
- ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]:
+ ) -> gym_new_venv_step_type:
"""Run one timestep of some environments' dynamics.
If id is None, run one timestep of all the environments’ dynamics;
@@ -246,16 +309,6 @@ def step(
:return: A tuple consisting of either:
- * ``obs`` a numpy.ndarray, the agent's observation of current environments
- * ``rew`` a numpy.ndarray, the amount of rewards returned after \
- previous actions
- * ``done`` a numpy.ndarray, whether these episodes have ended, in \
- which case further step() calls will return undefined results
- * ``info`` a numpy.ndarray, contains auxiliary diagnostic \
- information (helpful for debugging, and sometimes learning)
-
- or:
-
* ``obs`` a numpy.ndarray, the agent's observation of current environments
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
previous actions
@@ -265,9 +318,6 @@ def step(
* ``info`` a numpy.ndarray, contains auxiliary diagnostic \
information (helpful for debugging, and sometimes learning)
- The case distinction is made based on whether the underlying environment
- uses the old step API (first case) or the new step API (second case).
-
For the async simulation:
Provide the given action to the environments. The action sequence
@@ -312,14 +362,14 @@ def step(
env_return[-1]["env_id"] = env_id # Add `env_id` to info
result.append(env_return)
self.ready_id.append(env_id)
- return_lists = tuple(zip(*result))
- obs_list = return_lists[0]
+ obs_list, rew_list, term_list, trunc_list, info_list = tuple(zip(*result))
try:
obs_stack = np.stack(obs_list)
except ValueError: # different len(obs)
obs_stack = np.array(obs_list, dtype=object)
- other_stacks = map(np.stack, return_lists[1:])
- return (obs_stack, *other_stacks) # type: ignore
+ return obs_stack, np.stack(rew_list), np.stack(term_list), np.stack(
+ trunc_list
+ ), np.stack(info_list)
def seed(
self,
@@ -374,7 +424,7 @@ class DummyVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""
- def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+ def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
super().__init__(env_fns, DummyEnvWorker, **kwargs)
@@ -386,7 +436,7 @@ class SubprocVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""
- def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+ def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False)
@@ -404,7 +454,7 @@ class ShmemVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""
- def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+ def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True)
@@ -422,7 +472,7 @@ class RayVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""
- def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+ def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
try:
import ray
except ImportError as exception:
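
Taken together, the changes above mean every vector env class accepts factories that return either Gymnasium or legacy OpenAI Gym environments; the latter are transparently wrapped via Shimmy. A sketch, assuming both `gym` and `shimmy` are installed:

```python
import gymnasium

from tianshou.env import DummyVectorEnv

# Gymnasium factory: used as-is.
venv = DummyVectorEnv([lambda: gymnasium.make("CartPole-v1")])

# Legacy OpenAI Gym factory: wrapped in a Shimmy compatibility layer
# (a warning is emitted; requires `pip install gym shimmy`).
import gym

venv_legacy = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
```
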
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py
index 779a8229a..773d56bce 100644
--- a/tianshou/env/worker/base.py
+++ b/tianshou/env/worker/base.py
@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
-from tianshou.env.utils import gym_new_venv_step_type, gym_old_venv_step_type
+from tianshou.env.utils import gym_new_venv_step_type
from tianshou.utils import deprecation
@@ -14,8 +14,7 @@ class EnvWorker(ABC):
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
self._env_fn = env_fn
self.is_closed = False
- self.result: Union[gym_old_venv_step_type, gym_new_venv_step_type,
- Tuple[np.ndarray, dict], np.ndarray]
+ self.result: Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]
self.action_space = self.get_env_attr("action_space") # noqa: B009
self.is_reset = False
@@ -48,8 +47,7 @@ def send(self, action: Optional[np.ndarray]) -> None:
def recv(
self
- ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type, Tuple[np.ndarray, dict],
- np.ndarray]: # noqa:E125
+    ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]: # noqa:E125
"""Receive result from low-level worker.
If the last "send" function sends a NULL action, it only returns a
@@ -67,12 +65,10 @@ def recv(
return self.result
@abstractmethod
- def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+ def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]:
pass
- def step(
- self, action: np.ndarray
- ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]:
+ def step(self, action: np.ndarray) -> gym_new_venv_step_type:
"""Perform one timestep of the environment's dynamic.
"send" and "recv" are coupled in sync simulation, so users only call
diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py
index 7613e6784..4eec4e0fa 100644
--- a/tianshou/env/worker/dummy.py
+++ b/tianshou/env/worker/dummy.py
@@ -1,6 +1,6 @@
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple
-import gym
+import gymnasium as gym
import numpy as np
from tianshou.env.worker import EnvWorker
@@ -19,7 +19,7 @@ def get_env_attr(self, key: str) -> Any:
def set_env_attr(self, key: str, value: Any) -> None:
setattr(self.env.unwrapped, key, value)
- def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+ def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]:
if "seed" in kwargs:
super().seed(kwargs["seed"])
return self.env.reset(**kwargs)
diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py
index e73d3b7f3..fe2b8fe8d 100644
--- a/tianshou/env/worker/ray.py
+++ b/tianshou/env/worker/ray.py
@@ -1,9 +1,9 @@
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional
-import gym
+import gymnasium as gym
import numpy as np
-from tianshou.env.utils import gym_new_venv_step_type, gym_old_venv_step_type
+from tianshou.env.utils import gym_new_venv_step_type
from tianshou.env.worker import EnvWorker
try:
@@ -14,22 +14,6 @@
class _SetAttrWrapper(gym.Wrapper):
- def __init__(self, env: gym.Env) -> None:
- """Constructor of this wrapper.
-
- For Gym 0.25, wrappers will automatically
- change to the old step API. We need to check
- which API ``env`` follows and adjust the
- wrapper accordingly.
- """
- env.reset()
- step_result = env.step(env.action_space.sample())
- new_step_api = len(step_result) == 5
- try:
- super().__init__(env, new_step_api=new_step_api) # type: ignore
- except TypeError: # The kwarg `new_step_api` was removed in Gym 0.26
- super().__init__(env)
-
def set_env_attr(self, key: str, value: Any) -> None:
setattr(self.env.unwrapped, key, value)
@@ -72,7 +56,7 @@ def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
else:
self.result = self.env.step.remote(action)
- def recv(self) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]:
+ def recv(self) -> gym_new_venv_step_type:
return ray.get(self.result) # type: ignore
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index 0c943f8bb..68f34e687 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -5,14 +5,10 @@
from multiprocessing.context import Process
from typing import Any, Callable, List, Optional, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
-from tianshou.env.utils import (
- CloudpickleWrapper,
- gym_new_venv_step_type,
- gym_old_venv_step_type,
-)
+from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type
from tianshou.env.worker import EnvWorker
_NP_TO_CT = {
@@ -97,21 +93,11 @@ def _encode_obs(
env_return = (None, *env_return[1:])
p.send(env_return)
elif cmd == "reset":
- retval = env.reset(**data)
- reset_returns_info = isinstance(
- retval, (tuple, list)
- ) and len(retval) == 2 and isinstance(retval[1], dict)
- if reset_returns_info:
- obs, info = retval
- else:
- obs = retval
+ obs, info = env.reset(**data)
if obs_bufs is not None:
_encode_obs(obs, obs_bufs)
obs = None
- if reset_returns_info:
- p.send((obs, info))
- else:
- p.send(obs)
+ p.send((obs, info))
elif cmd == "close":
p.send(env.close())
p.close()
@@ -214,8 +200,7 @@ def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
def recv(
self
- ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type, Tuple[np.ndarray, dict],
- np.ndarray]: # noqa:E125
+ ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]: # noqa:E125
result = self.parent_remote.recv()
if isinstance(result, tuple):
if len(result) == 2:
@@ -233,7 +218,7 @@ def recv(
obs = self._decode_obs()
return obs
- def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+ def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]:
if "seed" in kwargs:
super().seed(kwargs["seed"])
self.parent_remote.send(["reset", kwargs])
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 5f8a85664..584bb0928 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch
-from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
+from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete
from numba import njit
from torch import nn