From 558a5eabf698eadc425c52fd3065f613117b6ce6 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Mon, 25 Apr 2022 23:25:03 -0400 Subject: [PATCH 01/27] upgrade version of cartpole --- README.md | 4 ++-- docs/tutorials/cheatsheet.rst | 2 +- docs/tutorials/concepts.rst | 4 ++-- docs/tutorials/dqn.rst | 18 +++++++++--------- docs/tutorials/get_started.rst | 2 +- test/discrete/test_a2c_with_il.py | 6 +++--- test/discrete/test_c51.py | 4 ++-- test/discrete/test_dqn.py | 4 ++-- test/discrete/test_drqn.py | 4 ++-- test/discrete/test_fqf.py | 4 ++-- test/discrete/test_iqn.py | 4 ++-- test/discrete/test_pg.py | 4 ++-- test/discrete/test_ppo.py | 4 ++-- test/discrete/test_qrdqn.py | 6 +++--- test/discrete/test_rainbow.py | 4 ++-- test/discrete/test_sac.py | 4 ++-- test/modelbased/test_dqn_icm.py | 4 ++-- test/modelbased/test_ppo_icm.py | 4 ++-- test/offline/gather_cartpole_data.py | 6 +++--- test/offline/test_discrete_bcq.py | 4 ++-- test/offline/test_discrete_cql.py | 4 ++-- test/offline/test_discrete_crr.py | 4 ++-- 22 files changed, 52 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index aaf610549..0849ce75d 100644 --- a/README.md +++ b/README.md @@ -191,7 +191,7 @@ import tianshou as ts Define some hyper-parameters: ```python -task = 'CartPole-v0' +task = 'CartPole-v1' lr, epoch, batch_size = 1e-3, 10, 64 train_num, test_num = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 @@ -268,7 +268,7 @@ $ tensorboard --logdir log/dqn You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage. -It's worth a try: here is a test on a laptop (i7-8750H + GTX1060). It only uses **3** seconds for training an agent based on vanilla policy gradient on the CartPole-v0 task: (seed may be different across different platform and device) +It's worth a try: here is a test on a laptop (i7-8750H + GTX1060). It only uses **3** seconds for training an agent based on vanilla policy gradient on the CartPole-v1 task: (seed may be different across different platform and device) ```bash $ python3 test/discrete/test_pg.py --seed 0 --render 0.03 diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 6273c06af..2ccbdd2e7 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -130,7 +130,7 @@ Currently it supports Atari, VizDoom, toy_text and classic_control environments. # install envpool: pip3 install envpool import envpool - envs = envpool.make_gym("CartPole-v0", num_envs=10) + envs = envpool.make_gym("CartPole-v1", num_envs=10) collector = Collector(policy, envs, buffer) Here are some examples: https://github.com/sail-sg/envpool/tree/master/examples/tianshou_examples diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index cb6d616fe..14f6175e0 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -349,7 +349,7 @@ The general explanation is listed in :ref:`pseudocode`. Other usages of collecto :: policy = PGPolicy(...) # or other policies if you wish - env = gym.make("CartPole-v0") + env = gym.make("CartPole-v1") replay_buffer = ReplayBuffer(size=10000) @@ -359,7 +359,7 @@ The general explanation is listed in :ref:`pseudocode`. 
Other usages of collecto # the collector supports vectorized environments as well vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3) # buffer_num should be equal to (suggested) or larger than #envs - envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(3)]) + envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(3)]) collector = Collector(policy, envs, buffer=vec_buffer) # collect 3 episodes diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 5810b2b62..b136a2bf8 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -41,11 +41,11 @@ First of all, you have to make an environment for your agent to interact with. Y import gym import tianshou as ts - env = gym.make('CartPole-v0') + env = gym.make('CartPole-v1') -CartPole-v0 includes a cart carrying a pole moving on a track. This is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both. +CartPole-v1 includes a cart carrying a pole moving on a track. This is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both. -Here is the detail of useful fields of CartPole-v0: +Here is the detail of useful fields of CartPole-v1: - ``state``: the position of the cart, the velocity of the cart, the angle of the pole and the velocity of the tip of the pole; - ``action``: can only be one of ``[0, 1, 2]``, for moving the cart left, no move, and right; @@ -62,8 +62,8 @@ Setup Vectorized Environment If you want to use the original ``gym.Env``: :: - train_envs = gym.make('CartPole-v0') - test_envs = gym.make('CartPole-v0') + train_envs = gym.make('CartPole-v1') + test_envs = gym.make('CartPole-v1') Tianshou supports vectorized environment for all algorithms. It provides four types of vectorized environment wrapper: @@ -74,8 +74,8 @@ Tianshou supports vectorized environment for all algorithms. It provides four ty :: - train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]) - test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)]) + train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(10)]) + test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(100)]) Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``. 
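For context on this version bump: CartPole-v0 caps episodes at 200 steps with a solve threshold of 195, while CartPole-v1 raises these to 500 steps and 475, so the hard-coded reward thresholds in the tests touched later in this patch are worth rechecking. A minimal check of the registry values, assuming a standard Gym install (illustrative only, not part of the diff)::

    import gym

    for task in ("CartPole-v0", "CartPole-v1"):
        spec = gym.spec(task)
        # prints 200 / 195.0 for v0 and 500 / 475.0 for v1
        print(task, spec.max_episode_steps, spec.reward_threshold)
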
@@ -84,8 +84,8 @@ You can also try the super-fast vectorized environment `EnvPool `_ - L1: `Batch `_ diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index b777e2e03..20698b083 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) @@ -60,7 +60,7 @@ def test_a2c_with_il(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) @@ -131,7 +131,7 @@ def stop_fn(mean_rewards): policy.eval() # here we define an imitation collector with a trivial policy - # if args.task == 'CartPole-v0': + # if args.task == 'CartPole-v1': # env.spec.reward_threshold = 190 # lower the goal net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) net = Actor(net, args.action_shape, device=args.device).to(args.device) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 993c4a80e..2f14942d8 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--eps-test', type=float, default=0.05) @@ -60,7 +60,7 @@ def test_c51(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 2644dc998..30a90d669 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -17,7 +17,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--eps-test', type=float, default=0.05) @@ -54,7 +54,7 @@ def test_dqn(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 36ff5fa76..47e553741 100644 --- a/test/discrete/test_drqn.py 
+++ b/test/discrete/test_drqn.py @@ -17,7 +17,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) @@ -50,7 +50,7 @@ def test_drqn(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index e25c42997..30b9b860c 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) @@ -57,7 +57,7 @@ def test_fqf(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 725c9a9d5..b433922a4 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=0) parser.add_argument('--eps-test', type=float, default=0.05) @@ -57,7 +57,7 @@ def test_iqn(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 1f5007f7a..3eabf498b 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -17,7 +17,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) @@ -46,7 +46,7 @@ def test_pg(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} 
+ default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index b7dba97c9..eff4231c2 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=20000) @@ -57,7 +57,7 @@ def test_ppo(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 9c699b4e3..5301d40f4 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -17,7 +17,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) @@ -52,12 +52,12 @@ def get_args(): def test_qrdqn(args=get_args()): env = gym.make(args.task) - if args.task == 'CartPole-v0': + if args.task == 'CartPole-v1': env.spec.reward_threshold = 190 # lower the goal args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 5e2345300..357622dee 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -19,7 +19,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--eps-test', type=float, default=0.05) @@ -63,7 +63,7 @@ def test_rainbow(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 6593c9864..d972150e8 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, 
default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=0) parser.add_argument('--buffer-size', type=int, default=20000) @@ -53,7 +53,7 @@ def test_discrete_sac(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 180} # lower the goal + default_reward_threshold = {"CartPole-v1": 180} # lower the goal args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index fba0b5523..4c177310c 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--eps-test', type=float, default=0.05) @@ -73,7 +73,7 @@ def test_dqn_icm(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 6efd96277..e3df9b1ed 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=20000) @@ -75,7 +75,7 @@ def test_ppo(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 195} + default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 2e01723f3..0e7783dff 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -16,12 +16,12 @@ def expert_file_name(): - return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl") + return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl") def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--task', type=str, default='CartPole-v1') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) @@ -61,7 +61,7 @@ def gather_data(): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = 
env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 190} + default_reward_threshold = {"CartPole-v1": 190} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 51380fb62..f9fe16e64 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -24,7 +24,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) @@ -59,7 +59,7 @@ def test_discrete_bcq(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 190} + default_reward_threshold = {"CartPole-v1": 190} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index eca810fb6..4155f799b 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -23,7 +23,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) @@ -56,7 +56,7 @@ def test_discrete_cql(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 170} + default_reward_threshold = {"CartPole-v1": 170} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 9f47b32e7..d5aa810c3 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -24,7 +24,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--task", type=str, default="CartPole-v1") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=7e-4) @@ -54,7 +54,7 @@ def test_discrete_crr(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v0": 180} + default_reward_threshold = {"CartPole-v1": 180} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) From bbbcd814d7c3e79ab9b558a79bab8d060d8da8ac Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Mon, 25 Apr 2022 23:26:27 -0400 Subject: [PATCH 02/27] np_random.randint --> np_random.integers --- examples/atari/atari_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
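Background for this one-line change: newer Gym releases expose ``env.np_random`` as a ``numpy.random.Generator``, which has no ``randint`` method; the replacement is ``integers``, and both calls use a half-open ``[low, high)`` interval, so ``randint(1, noop_max + 1)`` maps directly onto ``integers(1, noop_max + 1)``. A minimal sketch of the equivalence (illustrative, not part of the diff)::

    import numpy as np

    # legacy RandomState API (what older Gym exposed as env.np_random)
    legacy = np.random.RandomState(0)
    print(legacy.randint(1, 31))   # draws from 1..30 inclusive

    # Generator API (what newer Gym exposes as env.np_random)
    rng = np.random.default_rng(0)
    print(rng.integers(1, 31))     # same half-open interval, 1..30
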
diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 4aca61218..b89c1d46f 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -32,7 +32,7 @@ def __init__(self, env, noop_max=30): def reset(self): self.env.reset() - noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) for _ in range(noops): obs, _, done, _ = self.env.step(self.noop_action) if done: From c29642753908eabc3f7376b5e2b52d546c4fcda2 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Tue, 26 Apr 2022 23:51:18 -0400 Subject: [PATCH 03/27] try to use env.reset(seed=seed) instead of env.seed(seed) --- setup.py | 2 +- test/base/env.py | 10 ++++------ test/base/test_env.py | 1 - test/discrete/test_sac.py | 2 +- tianshou/env/pettingzoo_env.py | 4 +++- tianshou/env/worker/dummy.py | 4 ++-- tianshou/env/worker/ray.py | 4 ++-- tianshou/env/worker/subproc.py | 5 ++--- 8 files changed, 15 insertions(+), 17 deletions(-) diff --git a/setup.py b/setup.py index 83ac36580..3f4480752 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def get_version() -> str: def get_install_requires() -> str: return [ - "gym>=0.15.4", + "gym>=0.23.1", "tqdm", "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 "tensorboard>=2.5.0", diff --git a/test/base/env.py b/test/base/env.py index 0a649e546..437699d5c 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -73,13 +73,11 @@ def __init__( self.action_space = Discrete(2) self.done = False self.index = 0 - self.seed() + self.reset(seed=0) - def seed(self, seed=0): - self.rng = np.random.RandomState(seed) - return [seed] - - def reset(self, state=0): + def reset(self, state=0, seed=None): + if seed is not None: + self.rng = np.random.RandomState(seed) self.done = False self.do_sleep() self.index = state diff --git a/test/base/test_env.py b/test/base/test_env.py index 7e507c4ed..64476c0a9 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -47,7 +47,6 @@ def test_async_env(size=10000, num=8, sleep=0.1): test_cls += [RayVectorEnv] for cls in test_cls: v = cls(env_fns, wait_num=num // 2, timeout=1e-3) - v.seed(None) v.reset() # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un} # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1} diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index d972150e8..0c2a0dffb 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -53,7 +53,7 @@ def test_discrete_sac(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 180} # lower the goal + default_reward_threshold = {"CartPole-v1": 160} # lower the goal args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 25c34f994..b65f7cc85 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -55,7 +55,9 @@ def __init__(self, env: BaseWrapper): self.reset() - def reset(self) -> dict: + def reset(self, seed=None) -> dict: + if seed is not None: + self.env.seed(seed) self.env.reset() observation = self.env.observe(self.env.agent_selection) if isinstance(observation, dict) and 'action_mask' in observation: diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index be873861c..5f13c860b 
100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -35,9 +35,9 @@ def send(self, action: Optional[np.ndarray]) -> None: else: self.result = self.env.step(action) # type: ignore - def seed(self, seed: Optional[int] = None) -> List[int]: + def seed(self, seed: Optional[int] = None) -> None: super().seed(seed) - return self.env.seed(seed) + self.env.reset(seed=seed) def render(self, **kwargs: Any) -> Any: return self.env.render(**kwargs) diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index e094692d6..cf93bde53 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -47,7 +47,7 @@ def wait( # type: ignore return [workers[results.index(result)] for result in ready_results] def send(self, action: Optional[np.ndarray]) -> None: - # self.action is actually a handle + # self.result is actually a handle if action is None: self.result = self.env.reset.remote() else: @@ -60,7 +60,7 @@ def recv( def seed(self, seed: Optional[int] = None) -> List[int]: super().seed(seed) - return ray.get(self.env.seed.remote(seed)) + return ray.get(self.env.reset.remote(seed=seed)) def render(self, **kwargs: Any) -> Any: return ray.get(self.env.render.remote(**kwargs)) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index c2119ab50..5753fb47e 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -104,7 +104,7 @@ def _encode_obs( elif cmd == "render": p.send(env.render(**data) if hasattr(env, "render") else None) elif cmd == "seed": - p.send(env.seed(data) if hasattr(env, "seed") else None) + env.reset(seed=data) elif cmd == "getattr": p.send(getattr(env, data) if hasattr(env, data) else None) elif cmd == "setattr": @@ -204,10 +204,9 @@ def recv( obs = self._decode_obs() return obs - def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: + def seed(self, seed: Optional[int] = None) -> None: super().seed(seed) self.parent_remote.send(["seed", seed]) - return self.parent_remote.recv() def render(self, **kwargs: Any) -> Any: self.parent_remote.send(["render", kwargs]) From 95d6d87f35b0f32bd5bfdc98d7f8c68ba8cb1ed0 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Wed, 27 Apr 2022 09:23:28 -0400 Subject: [PATCH 04/27] commit checks --- tianshou/env/pettingzoo_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index b65f7cc85..632631467 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import gym.spaces from pettingzoo.utils.env import AECEnv @@ -55,7 +55,7 @@ def __init__(self, env: BaseWrapper): self.reset() - def reset(self, seed=None) -> dict: + def reset(self, seed: Optional[int] = None) -> dict: if seed is not None: self.env.seed(seed) self.env.reset() From b1ceefda57b57a294f5d9b9a3eb493ea7081d283 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Tue, 10 May 2022 22:36:27 -0400 Subject: [PATCH 05/27] Revert "upgrade version of cartpole" This reverts commit 558a5eabf698eadc425c52fd3065f613117b6ce6. 
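Note that this revert only undoes the CartPole-v0 to v1 rename from the first commit; the seeding changes introduced in the commits above, which seed through ``env.reset(seed=...)`` rather than ``env.seed(...)``, stay in place and match the ``gym>=0.23.1`` floor set in setup.py. A minimal sketch of the two call patterns, assuming Gym 0.23 semantics (illustrative, not part of the diff)::

    import gym

    env = gym.make("CartPole-v0")

    # old style (pre-0.23): env.seed(42); obs = env.reset()
    # new style used by this series:
    obs = env.reset(seed=42)
    # the later commits also thread through return_info:
    # obs, info = env.reset(seed=42, return_info=True)
    obs, reward, done, info = env.step(env.action_space.sample())
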
--- README.md | 4 ++-- docs/tutorials/cheatsheet.rst | 2 +- docs/tutorials/concepts.rst | 4 ++-- docs/tutorials/dqn.rst | 18 +++++++++--------- docs/tutorials/get_started.rst | 2 +- test/discrete/test_a2c_with_il.py | 6 +++--- test/discrete/test_c51.py | 4 ++-- test/discrete/test_dqn.py | 4 ++-- test/discrete/test_drqn.py | 4 ++-- test/discrete/test_fqf.py | 4 ++-- test/discrete/test_iqn.py | 4 ++-- test/discrete/test_pg.py | 4 ++-- test/discrete/test_ppo.py | 4 ++-- test/discrete/test_qrdqn.py | 6 +++--- test/discrete/test_rainbow.py | 4 ++-- test/discrete/test_sac.py | 4 ++-- test/modelbased/test_dqn_icm.py | 4 ++-- test/modelbased/test_ppo_icm.py | 4 ++-- test/offline/gather_cartpole_data.py | 6 +++--- test/offline/test_discrete_bcq.py | 4 ++-- test/offline/test_discrete_cql.py | 4 ++-- test/offline/test_discrete_crr.py | 4 ++-- 22 files changed, 52 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index 497f67a6c..0807fd8bb 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ import tianshou as ts Define some hyper-parameters: ```python -task = 'CartPole-v1' +task = 'CartPole-v0' lr, epoch, batch_size = 1e-3, 10, 64 train_num, test_num = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 @@ -260,7 +260,7 @@ $ tensorboard --logdir log/dqn You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage. -It's worth a try: here is a test on a laptop (i7-8750H + GTX1060). It only uses **3** seconds for training an agent based on vanilla policy gradient on the CartPole-v1 task: (seed may be different across different platform and device) +It's worth a try: here is a test on a laptop (i7-8750H + GTX1060). It only uses **3** seconds for training an agent based on vanilla policy gradient on the CartPole-v0 task: (seed may be different across different platform and device) ```bash $ python3 test/discrete/test_pg.py --seed 0 --render 0.03 diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 33d1ecf84..08aac4451 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -134,7 +134,7 @@ toy_text and classic_control environments. For more information, please refer to # install envpool: pip3 install envpool import envpool - envs = envpool.make_gym("CartPole-v1", num_envs=10) + envs = envpool.make_gym("CartPole-v0", num_envs=10) collector = Collector(policy, envs, buffer) Here are some other `examples `_. diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 14f6175e0..cb6d616fe 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -349,7 +349,7 @@ The general explanation is listed in :ref:`pseudocode`. Other usages of collecto :: policy = PGPolicy(...) # or other policies if you wish - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v0") replay_buffer = ReplayBuffer(size=10000) @@ -359,7 +359,7 @@ The general explanation is listed in :ref:`pseudocode`. 
Other usages of collecto # the collector supports vectorized environments as well vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3) # buffer_num should be equal to (suggested) or larger than #envs - envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(3)]) + envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(3)]) collector = Collector(policy, envs, buffer=vec_buffer) # collect 3 episodes diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index b136a2bf8..5810b2b62 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -41,11 +41,11 @@ First of all, you have to make an environment for your agent to interact with. Y import gym import tianshou as ts - env = gym.make('CartPole-v1') + env = gym.make('CartPole-v0') -CartPole-v1 includes a cart carrying a pole moving on a track. This is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both. +CartPole-v0 includes a cart carrying a pole moving on a track. This is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both. -Here is the detail of useful fields of CartPole-v1: +Here is the detail of useful fields of CartPole-v0: - ``state``: the position of the cart, the velocity of the cart, the angle of the pole and the velocity of the tip of the pole; - ``action``: can only be one of ``[0, 1, 2]``, for moving the cart left, no move, and right; @@ -62,8 +62,8 @@ Setup Vectorized Environment If you want to use the original ``gym.Env``: :: - train_envs = gym.make('CartPole-v1') - test_envs = gym.make('CartPole-v1') + train_envs = gym.make('CartPole-v0') + test_envs = gym.make('CartPole-v0') Tianshou supports vectorized environment for all algorithms. It provides four types of vectorized environment wrapper: @@ -74,8 +74,8 @@ Tianshou supports vectorized environment for all algorithms. It provides four ty :: - train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(10)]) - test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(100)]) + train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]) + test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)]) Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``. 
@@ -84,8 +84,8 @@ You can also try the super-fast vectorized environment `EnvPool `_ - L1: `Batch `_ diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 20698b083..b777e2e03 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) @@ -60,7 +60,7 @@ def test_a2c_with_il(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) @@ -131,7 +131,7 @@ def stop_fn(mean_rewards): policy.eval() # here we define an imitation collector with a trivial policy - # if args.task == 'CartPole-v1': + # if args.task == 'CartPole-v0': # env.spec.reward_threshold = 190 # lower the goal net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) net = Actor(net, args.action_shape, device=args.device).to(args.device) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 2f14942d8..993c4a80e 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--eps-test', type=float, default=0.05) @@ -60,7 +60,7 @@ def test_c51(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 30a90d669..2644dc998 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -17,7 +17,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--eps-test', type=float, default=0.05) @@ -54,7 +54,7 @@ def test_dqn(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 47e553741..36ff5fa76 100644 --- a/test/discrete/test_drqn.py 
+++ b/test/discrete/test_drqn.py @@ -17,7 +17,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) @@ -50,7 +50,7 @@ def test_drqn(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 30b9b860c..e25c42997 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) @@ -57,7 +57,7 @@ def test_fqf(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index b433922a4..725c9a9d5 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=0) parser.add_argument('--eps-test', type=float, default=0.05) @@ -57,7 +57,7 @@ def test_iqn(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 3eabf498b..1f5007f7a 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -17,7 +17,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) @@ -46,7 +46,7 @@ def test_pg(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} 
+ default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index eff4231c2..b7dba97c9 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=20000) @@ -57,7 +57,7 @@ def test_ppo(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 5301d40f4..9c699b4e3 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -17,7 +17,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) @@ -52,12 +52,12 @@ def get_args(): def test_qrdqn(args=get_args()): env = gym.make(args.task) - if args.task == 'CartPole-v1': + if args.task == 'CartPole-v0': env.spec.reward_threshold = 190 # lower the goal args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 357622dee..5e2345300 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -19,7 +19,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--eps-test', type=float, default=0.05) @@ -63,7 +63,7 @@ def test_rainbow(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 0c2a0dffb..6593c9864 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, 
default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=0) parser.add_argument('--buffer-size', type=int, default=20000) @@ -53,7 +53,7 @@ def test_discrete_sac(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 160} # lower the goal + default_reward_threshold = {"CartPole-v0": 180} # lower the goal args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 4c177310c..fba0b5523 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--eps-test', type=float, default=0.05) @@ -73,7 +73,7 @@ def test_dqn_icm(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index e3df9b1ed..6efd96277 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=20000) @@ -75,7 +75,7 @@ def test_ppo(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 195} + default_reward_threshold = {"CartPole-v0": 195} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 0e7783dff..2e01723f3 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -16,12 +16,12 @@ def expert_file_name(): - return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl") + return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl") def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v1') + parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) @@ -61,7 +61,7 @@ def gather_data(): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = 
env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 190} + default_reward_threshold = {"CartPole-v0": 190} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index f9fe16e64..51380fb62 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -24,7 +24,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v1") + parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument('--reward-threshold', type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) @@ -59,7 +59,7 @@ def test_discrete_bcq(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 190} + default_reward_threshold = {"CartPole-v0": 190} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 4155f799b..eca810fb6 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -23,7 +23,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v1") + parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) @@ -56,7 +56,7 @@ def test_discrete_cql(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 170} + default_reward_threshold = {"CartPole-v0": 170} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index d5aa810c3..9f47b32e7 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -24,7 +24,7 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="CartPole-v1") + parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=7e-4) @@ -54,7 +54,7 @@ def test_discrete_crr(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: - default_reward_threshold = {"CartPole-v1": 180} + default_reward_threshold = {"CartPole-v0": 180} args.reward_threshold = default_reward_threshold.get( args.task, env.spec.reward_threshold ) From 4eae57c29dd8cce0473346f14bbe74e180160e96 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Tue, 10 May 2022 22:47:12 -0400 Subject: [PATCH 06/27] fix merge error --- tianshou/env/pettingzoo_env.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git 
a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index bce1e4778..4b362a558 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -55,10 +55,8 @@ def __init__(self, env: BaseWrapper): self.reset() - def reset(self, seed: Optional[int] = None) -> dict: - if seed is not None: - self.env.seed(seed) - self.env.reset() + def reset(self, *args: Any, **kwargs: Any) -> dict: + self.env.reset(*args, **kwargs) observation = self.env.observe(self.env.agent_selection) if isinstance(observation, dict) and 'action_mask' in observation: return { From 61358d5c0f35a6ddfb53321ddfa08bef6dd9ad06 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Wed, 11 May 2022 22:59:15 -0400 Subject: [PATCH 07/27] make venvs and env workers to support reset()->[obs, info] --- test/base/env.py | 4 ++++ tianshou/env/pettingzoo_env.py | 25 +++++++++++++++----- tianshou/env/venv_wrappers.py | 22 +++++++++++------- tianshou/env/venvs.py | 24 ++++++++++++------- tianshou/env/worker/base.py | 6 ++--- tianshou/env/worker/dummy.py | 13 +++++++---- tianshou/env/worker/ray.py | 11 +++++---- tianshou/env/worker/subproc.py | 42 ++++++++++++++++++++++++++-------- 8 files changed, 104 insertions(+), 43 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index 437699d5c..872a7c1f2 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -75,6 +75,10 @@ def __init__( self.index = 0 self.reset(seed=0) + def seed(self, seed=0): + self.rng = np.random.RandomState(seed) + return [seed] + def reset(self, state=0, seed=None): if seed is not None: self.rng = np.random.RandomState(seed) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 4b362a558..146037413 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import gym.spaces from pettingzoo.utils.env import AECEnv @@ -55,11 +55,19 @@ def __init__(self, env: BaseWrapper): self.reset() - def reset(self, *args: Any, **kwargs: Any) -> dict: - self.env.reset(*args, **kwargs) + def reset( + self, + seed: Optional[int] = None, + return_info: bool = False, + *args: Any, + **kwargs: Any, + ) -> Union[dict, Tuple[dict, dict]]: + self.env.reset(seed=seed, *args, **kwargs) observation = self.env.observe(self.env.agent_selection) + observation, _, _, info = self.env.last(self) + if isinstance(observation, dict) and 'action_mask' in observation: - return { + observation_dict = { 'agent_id': self.env.agent_selection, 'obs': observation['observation'], 'mask': @@ -67,13 +75,18 @@ def reset(self, *args: Any, **kwargs: Any) -> dict: } else: if isinstance(self.action_space, gym.spaces.Discrete): - return { + observation_dict = { 'agent_id': self.env.agent_selection, 'obs': observation, 'mask': [True] * self.env.action_space(self.env.agent_selection).n } else: - return {'agent_id': self.env.agent_selection, 'obs': observation} + observation_dict = {'agent_id': self.env.agent_selection, 'obs': observation} + + if return_info: + return observation_dict, info + else: + return observation_dict def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: self.env.step(action) diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index 860c390d9..5499028cb 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -37,11 +37,10 @@ def set_env_attr( ) -> None: return self.venv.set_env_attr(key, value, id) - # TODO: 
compatible issue with reset -> (obs, info) def reset( - self, id: Optional[Union[int, List[int], np.ndarray]] = None - ) -> np.ndarray: - return self.venv.reset(id) + self, id: Optional[Union[int, List[int], np.ndarray]] = None, **kwargs, + ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: + return self.venv.reset(id, **kwargs) def step( self, @@ -86,14 +85,21 @@ def __init__( self.clip_max = clip_obs self.eps = epsilon - # TODO: compatible issue with reset -> (obs, info) def reset( - self, id: Optional[Union[int, List[int], np.ndarray]] = None + self, id: Optional[Union[int, List[int], np.ndarray]] = None, + **kwargs, ) -> np.ndarray: - obs = self.venv.reset(id) + if "return_info" in kwargs and kwargs["return_info"]: + obs, info = self.venv.reset(id, **kwargs) + else: + obs = self.venv.reset(id) if self.obs_rms and self.update_obs_rms: self.obs_rms.update(obs) - return self._norm_obs(obs) + obs = self._norm_obs(obs) + if "return_info" in kwargs and kwargs["return_info"]: + return obs, info + else: + return obs def step( self, diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 93558d9ef..f284490bc 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -181,10 +181,11 @@ def _assert_id(self, id: Union[List[int], np.ndarray]) -> None: assert i in self.ready_id, \ f"Can only interact with ready environments {self.ready_id}." - # TODO: compatible issue with reset -> (obs, info) def reset( - self, id: Optional[Union[int, List[int], np.ndarray]] = None - ) -> np.ndarray: + self, + id: Optional[Union[int, List[int], np.ndarray]] = None, + **kwargs, + ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: """Reset the state of some envs and return initial observations. If id is None, reset the state of all the environments and return @@ -195,15 +196,22 @@ def reset( id = self._wrap_id(id) if self.is_async: self._assert_id(id) - # send(None) == reset() in worker - for i in id: - self.workers[i].send(None) - obs_list = [self.workers[i].recv() for i in id] + + ret = [self.workers[i].reset(**kwargs) for i in id] + if "return_info" in kwargs and kwargs["return_info"]: + obs_list = [r[0] for r in ret] + else: + obs_list = ret try: obs = np.stack(obs_list) except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) - return obs + + if "return_info" in kwargs and kwargs["return_info"]: + infos = [r[1] for r in ret] + return obs, infos + else: + return obs def step( self, diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index b861a15d5..2f40f4bb2 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -63,9 +63,9 @@ def recv( self.result = self.get_result() # type: ignore return self.result - def reset(self) -> np.ndarray: - self.send(None) - return self.recv() # type: ignore + @abstractmethod + def reset(self, **kwargs) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + pass def step( self, action: np.ndarray diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 5f13c860b..43754387e 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Tuple, Union import gym import numpy as np @@ -19,8 +19,8 @@ def get_env_attr(self, key: str) -> Any: def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env, key, value) - def reset(self) -> Any: - return self.env.reset() + def reset(self, **kwargs) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + return 
self.env.reset(**kwargs) @staticmethod def wait( # type: ignore @@ -35,9 +35,12 @@ def send(self, action: Optional[np.ndarray]) -> None: else: self.result = self.env.step(action) # type: ignore - def seed(self, seed: Optional[int] = None) -> None: + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: super().seed(seed) - self.env.reset(seed=seed) + try: + return self.env.seed(seed) + except NotImplementedError: + self.env.reset(seed=seed) def render(self, **kwargs: Any) -> Any: return self.env.render(**kwargs) diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index cf93bde53..01a79408d 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -35,8 +35,8 @@ def get_env_attr(self, key: str) -> Any: def set_env_attr(self, key: str, value: Any) -> None: ray.get(self.env.set_env_attr.remote(key, value)) - def reset(self) -> Any: - return ray.get(self.env.reset.remote()) + def reset(self, **kwargs) -> Any: + return ray.get(self.env.reset.remote(**kwargs)) @staticmethod def wait( # type: ignore @@ -58,9 +58,12 @@ def recv( ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: return ray.get(self.result) # type: ignore - def seed(self, seed: Optional[int] = None) -> List[int]: + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: super().seed(seed) - return ray.get(self.env.reset.remote(seed=seed)) + try: + return ray.get(self.env.seed.remote(seed)) + except NotImplementedError: + self.env.reset.remote(seed=seed) def render(self, **kwargs: Any) -> Any: return ray.get(self.env.render.remote(**kwargs)) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 5753fb47e..f4a8cf3b2 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -86,17 +86,23 @@ def _encode_obs( p.close() break if cmd == "step": - if data is None: # reset - obs = env.reset() + obs, reward, done, info = env.step(data) + if obs_bufs is not None: + _encode_obs(obs, obs_bufs) + obs = None + p.send((obs, reward, done, info)) + elif cmd == "reset": + if "return_info" in data and data["return_info"]: + obs, info = env.reset(**data) else: - obs, reward, done, info = env.step(data) + obs = env.reset(**data) if obs_bufs is not None: _encode_obs(obs, obs_bufs) obs = None - if data is None: - p.send(obs) + if "return_info" in data and data["return_info"]: + p.send((obs, info)) else: - p.send((obs, reward, done, info)) + p.send(obs) elif cmd == "close": p.send(env.close()) p.close() @@ -104,7 +110,11 @@ def _encode_obs( elif cmd == "render": p.send(env.render(**data) if hasattr(env, "render") else None) elif cmd == "seed": - env.reset(seed=data) + if hasattr(env, "seed"): + p.send(env.seed(data)) + else: + env.reset(seed=data) + p.send(None) elif cmd == "getattr": p.send(getattr(env, data) if hasattr(env, data) else None) elif cmd == "setattr": @@ -140,7 +150,6 @@ def __init__( self.process = Process(target=_worker, args=args, daemon=True) self.process.start() self.child_remote.close() - self.is_reset = False super().__init__(env_fn) def get_env_attr(self, key: str) -> Any: @@ -204,9 +213,24 @@ def recv( obs = self._decode_obs() return obs - def seed(self, seed: Optional[int] = None) -> None: + def reset(self, **kwargs) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + self.parent_remote.send(["reset", kwargs]) + result = self.parent_remote.recv() + if isinstance(result, tuple): + obs, info = result + if self.share_memory: + obs = self._decode_obs() + return obs, info + else: + obs = result + if 
self.share_memory: + obs = self._decode_obs() + return obs + + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: super().seed(seed) self.parent_remote.send(["seed", seed]) + return self.parent_remote.recv() def render(self, **kwargs: Any) -> Any: self.parent_remote.send(["render", kwargs]) From fcc2cfc32d74902dd74ca128aa77080b72d0b8c2 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Fri, 13 May 2022 00:06:57 -0400 Subject: [PATCH 08/27] clean up --- setup.py | 2 +- tianshou/data/collector.py | 2 +- tianshou/env/pettingzoo_env.py | 5 ++++- tianshou/env/venv_wrappers.py | 11 +++++++---- tianshou/env/venvs.py | 14 +++++++++----- tianshou/env/worker/base.py | 4 ++-- tianshou/env/worker/dummy.py | 10 +++++----- tianshou/env/worker/ray.py | 7 ++++--- tianshou/env/worker/subproc.py | 19 ++++++++++++++----- 9 files changed, 47 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index 1f207d847..0ae409056 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def get_version() -> str: def get_install_requires() -> str: return [ - "gym>=0.23.1", + "gym>=0.15.4", "tqdm", "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 "tensorboard>=2.5.0", diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 537986cf6..21f9707bb 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -64,7 +64,7 @@ def __init__( super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") - self.env = DummyVectorEnv([lambda: env]) # type: ignore + self.env = DummyVectorEnv([lambda: env]) else: self.env = env # type: ignore self.env_num = len(self.env) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 146037413..e30d3dfc3 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -81,7 +81,10 @@ def reset( 'mask': [True] * self.env.action_space(self.env.agent_selection).n } else: - observation_dict = {'agent_id': self.env.agent_selection, 'obs': observation} + observation_dict = { + 'agent_id': self.env.agent_selection, + 'obs': observation, + } if return_info: return observation_dict, info diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index 5499028cb..b475c2a92 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -38,7 +38,9 @@ def set_env_attr( return self.venv.set_env_attr(key, value, id) def reset( - self, id: Optional[Union[int, List[int], np.ndarray]] = None, **kwargs, + self, + id: Optional[Union[int, List[int], np.ndarray]] = None, + **kwargs: Any, ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: return self.venv.reset(id, **kwargs) @@ -86,9 +88,10 @@ def __init__( self.eps = epsilon def reset( - self, id: Optional[Union[int, List[int], np.ndarray]] = None, - **kwargs, - ) -> np.ndarray: + self, + id: Optional[Union[int, List[int], np.ndarray]] = None, + **kwargs: Any, + ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: if "return_info" in kwargs and kwargs["return_info"]: obs, info = self.venv.reset(id, **kwargs) else: diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index f284490bc..4b1b3fa50 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -184,7 +184,7 @@ def _assert_id(self, id: Union[List[int], np.ndarray]) -> None: def reset( self, id: Optional[Union[int, List[int], np.ndarray]] = None, - **kwargs, + **kwargs: Any, ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: """Reset the state of some 
envs and return initial observations. @@ -197,18 +197,22 @@ def reset( if self.is_async: self._assert_id(id) - ret = [self.workers[i].reset(**kwargs) for i in id] + # send(None) == reset() in worker + for i in id: + self.workers[i].send(None, **kwargs) + ret_list = [self.workers[i].recv() for i in id] + if "return_info" in kwargs and kwargs["return_info"]: - obs_list = [r[0] for r in ret] + obs_list = [r[0] for r in ret_list] else: - obs_list = ret + obs_list = ret_list try: obs = np.stack(obs_list) except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) if "return_info" in kwargs and kwargs["return_info"]: - infos = [r[1] for r in ret] + infos = [r[1] for r in ret_list] return obs, infos else: return obs diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 2f40f4bb2..3f905ebe6 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -40,7 +40,7 @@ def send(self, action: Optional[np.ndarray]) -> None: ) if action is None: self.is_reset = True - self.result = self.reset() + self.result = self.reset() # type: ignore else: self.is_reset = False self.send_action(action) # type: ignore @@ -64,7 +64,7 @@ def recv( return self.result @abstractmethod - def reset(self, **kwargs) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: pass def step( diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 43754387e..6fe5ba61c 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -19,7 +19,7 @@ def get_env_attr(self, key: str) -> Any: def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env, key, value) - def reset(self, **kwargs) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: return self.env.reset(**kwargs) @staticmethod @@ -29,18 +29,18 @@ def wait( # type: ignore # Sequential EnvWorker objects are always ready return workers - def send(self, action: Optional[np.ndarray]) -> None: + def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: if action is None: - self.result = self.env.reset() # type: ignore + self.result = self.env.reset(**kwargs) else: - self.result = self.env.step(action) # type: ignore + self.result = self.env.step(action) def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: super().seed(seed) try: return self.env.seed(seed) except NotImplementedError: - self.env.reset(seed=seed) + return self.env.reset(seed=seed) def render(self, **kwargs: Any) -> Any: return self.env.render(**kwargs) diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 01a79408d..038a7fcaf 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -35,7 +35,7 @@ def get_env_attr(self, key: str) -> Any: def set_env_attr(self, key: str, value: Any) -> None: ray.get(self.env.set_env_attr.remote(key, value)) - def reset(self, **kwargs) -> Any: + def reset(self, **kwargs: Any) -> Any: return ray.get(self.env.reset.remote(**kwargs)) @staticmethod @@ -46,10 +46,10 @@ def wait( # type: ignore ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) return [workers[results.index(result)] for result in ready_results] - def send(self, action: Optional[np.ndarray]) -> None: + def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: # self.result is actually a handle if action is None: - self.result = self.env.reset.remote() + self.result = 
self.env.reset.remote(**kwargs) else: self.result = self.env.step.remote(action) @@ -64,6 +64,7 @@ def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: return ray.get(self.env.seed.remote(seed)) except NotImplementedError: self.env.reset.remote(seed=seed) + return None def render(self, **kwargs: Any) -> Any: return ray.get(self.env.render.remote(**kwargs)) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index f4a8cf3b2..f445f758e 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -53,7 +53,7 @@ def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]: assert isinstance(space.spaces, tuple) return tuple([_setup_buf(t) for t in space.spaces]) else: - return ShArray(space.dtype, space.shape) # type: ignore + return ShArray(space.dtype, space.shape) def _worker( @@ -195,14 +195,22 @@ def wait( # type: ignore remain_conns = [conn for conn in remain_conns if conn not in ready_conns] return [workers[conns.index(con)] for con in ready_conns] - def send(self, action: Optional[np.ndarray]) -> None: - self.parent_remote.send(["step", action]) + def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: + if action is None: + self.parent_remote.send(["reset", kwargs]) + else: + self.parent_remote.send(["step", action]) def recv( self - ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, dict], np.ndarray]: result = self.parent_remote.recv() if isinstance(result, tuple): + if len(result) == 2: + obs, info = result + if self.share_memory: + obs = self._decode_obs() + return obs, info obs, rew, done, info = result if self.share_memory: obs = self._decode_obs() @@ -213,8 +221,9 @@ def recv( obs = self._decode_obs() return obs - def reset(self, **kwargs) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: self.parent_remote.send(["reset", kwargs]) + result = self.parent_remote.recv() if isinstance(result, tuple): obs, info = result From 2563439cf0ce7a7a9f5a4fa9c76f5d23625cad78 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Fri, 13 May 2022 00:28:38 -0400 Subject: [PATCH 09/27] add test case for reset with optional kwargs --- test/base/env.py | 9 ++++++--- test/base/test_env.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index 872a7c1f2..4e97ee1da 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -73,19 +73,22 @@ def __init__( self.action_space = Discrete(2) self.done = False self.index = 0 - self.reset(seed=0) + self.seed() def seed(self, seed=0): self.rng = np.random.RandomState(seed) return [seed] - def reset(self, state=0, seed=None): + def reset(self, state=0, seed=None, return_info=False): if seed is not None: self.rng = np.random.RandomState(seed) self.done = False self.do_sleep() self.index = state - return self._get_state() + if return_info: + return self._get_state(), {'key': 1, 'env': self} + else: + return self._get_state() def _get_reward(self): """Generate a non-scalar reward if ma_rew is True.""" diff --git a/test/base/test_env.py b/test/base/test_env.py index 073a23e2d..d7a408908 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -60,6 +60,7 @@ def test_async_env(size=10000, num=8, sleep=0.1): test_cls += [RayVectorEnv] for cls in test_cls: v = cls(env_fns, wait_num=num // 2, timeout=1e-3) + v.seed(None) v.reset() # 
for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un} # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1} @@ -218,6 +219,21 @@ def test_env_obs_dtype(): assert obs.dtype == object +def test_env_reset_optional_kwargs(size=10000, num=8): + env_fns = [ + lambda i=i: MyTestEnv(size=i) + for i in range(size, size + num) + ] + test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] + if has_ray(): + test_cls += [RayVectorEnv] + for cls in test_cls: + v = cls(env_fns, wait_num=num // 2, timeout=1e-3) + _, info = v.reset(seed=1, return_info=True) + assert len(info) == len(env_fns) + assert isinstance(info[0], dict) + + def run_align_norm_obs(raw_env, train_env, test_env, action_list): eps = np.finfo(np.float32).eps.item() raw_obs, train_obs = [raw_env.reset()], [train_env.reset()] From b82eb0e25ef5c9432aab8c371c69486239215791 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Fri, 13 May 2022 01:14:25 -0400 Subject: [PATCH 10/27] satisfy checks --- test/base/env.py | 2 +- test/base/test_env.py | 5 +---- tianshou/env/pettingzoo_env.py | 28 ++++++---------------------- tianshou/env/venvs.py | 6 +++--- tianshou/env/worker/base.py | 7 ++++--- tianshou/env/worker/dummy.py | 2 ++ tianshou/env/worker/ray.py | 2 ++ tianshou/env/worker/subproc.py | 7 ++++++- 8 files changed, 25 insertions(+), 34 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index 4e97ee1da..e29f7ffa5 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -86,7 +86,7 @@ def reset(self, state=0, seed=None, return_info=False): self.do_sleep() self.index = state if return_info: - return self._get_state(), {'key': 1, 'env': self} + return self._get_state(), {'key': 1, 'env': self} else: return self._get_state() diff --git a/test/base/test_env.py b/test/base/test_env.py index d7a408908..7d0e93229 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -220,10 +220,7 @@ def test_env_obs_dtype(): def test_env_reset_optional_kwargs(size=10000, num=8): - env_fns = [ - lambda i=i: MyTestEnv(size=i) - for i in range(size, size + num) - ] + env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)] test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] if has_ray(): test_cls += [RayVectorEnv] diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index e30d3dfc3..c406872dc 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple import gym.spaces from pettingzoo.utils.env import AECEnv @@ -55,19 +55,11 @@ def __init__(self, env: BaseWrapper): self.reset() - def reset( - self, - seed: Optional[int] = None, - return_info: bool = False, - *args: Any, - **kwargs: Any, - ) -> Union[dict, Tuple[dict, dict]]: - self.env.reset(seed=seed, *args, **kwargs) + def reset(self, *args: Any, **kwargs: Any) -> dict: + self.env.reset(*args, **kwargs) observation = self.env.observe(self.env.agent_selection) - observation, _, _, info = self.env.last(self) - if isinstance(observation, dict) and 'action_mask' in observation: - observation_dict = { + return { 'agent_id': self.env.agent_selection, 'obs': observation['observation'], 'mask': @@ -75,21 +67,13 @@ def reset( } else: if isinstance(self.action_space, gym.spaces.Discrete): - observation_dict = { + return { 'agent_id': self.env.agent_selection, 'obs': observation, 'mask': [True] * self.env.action_space(self.env.agent_selection).n } else: - observation_dict = { - 
'agent_id': self.env.agent_selection, - 'obs': observation, - } - - if return_info: - return observation_dict, info - else: - return observation_dict + return {'agent_id': self.env.agent_selection, 'obs': observation} def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: self.env.step(action) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 4b1b3fa50..aba1c4599 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -213,7 +213,7 @@ def reset( if "return_info" in kwargs and kwargs["return_info"]: infos = [r[1] for r in ret_list] - return obs, infos + return obs, infos # type: ignore else: return obs @@ -260,7 +260,7 @@ def step( self.workers[j].send(action[i]) result = [] for j in id: - obs, rew, done, info = self.workers[j].recv() + obs, rew, done, info = self.workers[j].recv() # type: ignore info["env_id"] = j result.append((obs, rew, done, info)) else: @@ -282,7 +282,7 @@ def step( waiting_index = self.waiting_conn.index(conn) self.waiting_conn.pop(waiting_index) env_id = self.waiting_id.pop(waiting_index) - obs, rew, done, info = conn.recv() + obs, rew, done, info = conn.recv() # type: ignore info["env_id"] = env_id result.append((obs, rew, done, info)) self.ready_id.append(env_id) diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 3f905ebe6..3ea46d724 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -14,7 +14,7 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], - np.ndarray] + Tuple[np.ndarray, dict], np.ndarray] self.action_space = self.get_env_attr("action_space") # noqa: B009 self.is_reset = False @@ -40,14 +40,15 @@ def send(self, action: Optional[np.ndarray]) -> None: ) if action is None: self.is_reset = True - self.result = self.reset() # type: ignore + self.result = self.reset() else: self.is_reset = False self.send_action(action) # type: ignore def recv( self - ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[ + np.ndarray, dict], np.ndarray]: # noqa:E125 """Receive result from low-level worker. 
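A rough sketch (not part of the patch) of the send/recv protocol this docstring describes; the worker construction and the `return_info` kwarg are illustrative and assume a gym>=0.23 environment:

```python
import gym

from tianshou.env.worker import DummyEnvWorker

# Drive a single worker by hand, the same way BaseVectorEnv does internally.
worker = DummyEnvWorker(lambda: gym.make("CartPole-v1"))

worker.send(None, return_info=True)        # a None action means "reset"; kwargs go to env.reset
result = worker.recv()
if isinstance(result, tuple) and len(result) == 2:
    obs, info = result                     # new-style reset: (obs, info)
else:
    obs = result                           # old-style reset: observation only

worker.send(worker.action_space.sample())  # a real action means "step"
obs_next, rew, done, info = worker.recv()
```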
If the last "send" function sends a NULL action, it only returns a diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 6fe5ba61c..0961cdc11 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -20,6 +20,8 @@ def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env, key, value) def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + if "seed" in kwargs: + super().seed(kwargs["seed"]) return self.env.reset(**kwargs) @staticmethod diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 038a7fcaf..055fd7a8d 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -36,6 +36,8 @@ def set_env_attr(self, key: str, value: Any) -> None: ray.get(self.env.set_env_attr.remote(key, value)) def reset(self, **kwargs: Any) -> Any: + if "seed" in kwargs: + super().seed(kwargs["seed"]) return ray.get(self.env.reset.remote(**kwargs)) @staticmethod diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index f445f758e..b5336b632 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -197,13 +197,16 @@ def wait( # type: ignore def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: if action is None: + if "seed" in kwargs: + super().seed(kwargs["seed"]) self.parent_remote.send(["reset", kwargs]) else: self.parent_remote.send(["step", action]) def recv( self - ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, dict], np.ndarray]: + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[ + np.ndarray, dict], np.ndarray]: # noqa:E125 result = self.parent_remote.recv() if isinstance(result, tuple): if len(result) == 2: @@ -222,6 +225,8 @@ def recv( return obs def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + if "seed" in kwargs: + super().seed(kwargs["seed"]) self.parent_remote.send(["reset", kwargs]) result = self.parent_remote.recv() From d644280f90920a1060abcc92e476714f1b2ad975 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Fri, 13 May 2022 01:21:50 -0400 Subject: [PATCH 11/27] pettingzoo reset supports return_info --- tianshou/env/pettingzoo_env.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index c406872dc..1722dc563 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union import gym.spaces from pettingzoo.utils.env import AECEnv @@ -55,11 +55,11 @@ def __init__(self, env: BaseWrapper): self.reset() - def reset(self, *args: Any, **kwargs: Any) -> dict: + def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: self.env.reset(*args, **kwargs) - observation = self.env.observe(self.env.agent_selection) + observation, _, _, info = self.env.last(self) if isinstance(observation, dict) and 'action_mask' in observation: - return { + observation_dict = { 'agent_id': self.env.agent_selection, 'obs': observation['observation'], 'mask': @@ -67,13 +67,21 @@ def reset(self, *args: Any, **kwargs: Any) -> dict: } else: if isinstance(self.action_space, gym.spaces.Discrete): - return { + observation_dict = { 'agent_id': self.env.agent_selection, 'obs': observation, 'mask': [True] * self.env.action_space(self.env.agent_selection).n } else: - return {'agent_id': 
self.env.agent_selection, 'obs': observation} + observation_dict = { + 'agent_id': self.env.agent_selection, + 'obs': observation, + } + + if "return_info" in kwargs and kwargs["return_info"]: + return observation_dict, info + else: + return observation_dict def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: self.env.step(action) From 8ffd633187862619e6de27e1bf0d5017dc7f8d19 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Fri, 13 May 2022 01:39:51 -0400 Subject: [PATCH 12/27] small addition --- examples/atari/atari_wrapper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index b89c1d46f..3f75e9e28 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -32,7 +32,10 @@ def __init__(self, env, noop_max=30): def reset(self): self.env.reset() - noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) + if hasattr(self.unwrapped.np_random, "integers"): + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) + else: + noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) for _ in range(noops): obs, _, done, _ = self.env.step(self.noop_action) if done: From 861a1ba2a47f16eb6482deaf75aabd623866d4b2 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Fri, 13 May 2022 17:38:20 -0400 Subject: [PATCH 13/27] fix mypy --- tianshou/data/collector.py | 2 +- tianshou/env/worker/dummy.py | 5 +++-- tianshou/env/worker/subproc.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 21f9707bb..537986cf6 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -64,7 +64,7 @@ def __init__( super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") - self.env = DummyVectorEnv([lambda: env]) + self.env = DummyVectorEnv([lambda: env]) # type: ignore else: self.env = env # type: ignore self.env_num = len(self.env) diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 0961cdc11..58a2fc3c9 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -35,14 +35,15 @@ def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: if action is None: self.result = self.env.reset(**kwargs) else: - self.result = self.env.step(action) + self.result = self.env.step(action) # type: ignore def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: super().seed(seed) try: return self.env.seed(seed) except NotImplementedError: - return self.env.reset(seed=seed) + self.env.reset(seed=seed) + return [seed] # type: ignore def render(self, **kwargs: Any) -> Any: return self.env.render(**kwargs) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index b5336b632..b35d89c3c 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -53,7 +53,7 @@ def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]: assert isinstance(space.spaces, tuple) return tuple([_setup_buf(t) for t in space.spaces]) else: - return ShArray(space.dtype, space.shape) + return ShArray(space.dtype, space.shape) # type: ignore def _worker( From 364b46e249f70b21dd8c3b46dcdcb24fe1162b88 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Sat, 21 May 2022 09:38:07 -0400 Subject: [PATCH 14/27] return info based on the return type of env.reset --- test/base/test_env.py | 14 ++++++++++++++ tianshou/env/venv_wrappers.py 
| 19 +++++++++++++++---- tianshou/env/venvs.py | 13 +++++++++++-- tianshou/env/worker/subproc.py | 14 ++++++++++---- 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index 7d0e93229..59653debe 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -291,6 +291,19 @@ def test_venv_wrapper_envpool(): run_align_norm_obs(raw, train, test, actions) +@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") +def test_venv_wrapper_envpool_gym_reset_return_info(): + num_envs = 4 + env = VectorEnvNormObs( + envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True) + ) + obs, info = env.reset() + assert obs.shape[0] == num_envs + for _, v in info.items(): + if not isinstance(v, dict): + assert v.shape[0] == num_envs + + if __name__ == '__main__': test_venv_norm_obs() test_venv_wrapper_envpool() @@ -298,3 +311,4 @@ def test_venv_wrapper_envpool(): test_vecenv() test_async_env() test_async_check_id() + test_env_reset_optional_kwargs() diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index b475c2a92..3297421c4 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -92,14 +92,25 @@ def reset( id: Optional[Union[int, List[int], np.ndarray]] = None, **kwargs: Any, ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: - if "return_info" in kwargs and kwargs["return_info"]: - obs, info = self.venv.reset(id, **kwargs) + retval = self.venv.reset(id, **kwargs) + has_info = isinstance(retval, + (tuple, list + )) and len(retval) == 2 and isinstance(retval[1], dict) + if has_info: + obs, info = retval else: - obs = self.venv.reset(id) + obs = retval + + if isinstance(obs, tuple): + raise Exception( + "Tuple observation space is not supported. ", + "Please change it to array or dict space", + ) + if self.obs_rms and self.update_obs_rms: self.obs_rms.update(obs) obs = self._norm_obs(obs) - if "return_info" in kwargs and kwargs["return_info"]: + if has_info: return obs, info else: return obs diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index aba1c4599..cdf49584f 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -202,16 +202,25 @@ def reset( self.workers[i].send(None, **kwargs) ret_list = [self.workers[i].recv() for i in id] - if "return_info" in kwargs and kwargs["return_info"]: + has_infos = isinstance(ret_list[0], (tuple, list)) and len( + ret_list[0] + ) == 2 and isinstance(ret_list[0][1], dict) + if has_infos: obs_list = [r[0] for r in ret_list] else: obs_list = ret_list + + if isinstance(obs_list[0], tuple): + raise Exception( + "Tuple observation space is not supported. 
", + "Please change it to array or dict space", + ) try: obs = np.stack(obs_list) except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) - if "return_info" in kwargs and kwargs["return_info"]: + if has_infos: infos = [r[1] for r in ret_list] return obs, infos # type: ignore else: diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index b35d89c3c..ee0ce833c 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -92,14 +92,20 @@ def _encode_obs( obs = None p.send((obs, reward, done, info)) elif cmd == "reset": - if "return_info" in data and data["return_info"]: - obs, info = env.reset(**data) + retval = env.reset(**data) + print(f"type(retval): {type(retval)}") + print(retval) + has_info = isinstance( + retval, (tuple, list) + ) and len(retval) == 2 and isinstance(retval[1], dict) + if has_info: + obs, info = retval else: - obs = env.reset(**data) + obs = retval if obs_bufs is not None: _encode_obs(obs, obs_bufs) obs = None - if "return_info" in data and data["return_info"]: + if has_info: p.send((obs, info)) else: p.send(obs) From d99b5ae48dbfed335681937660c42faf8e3dff1d Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Sat, 21 May 2022 09:40:19 -0400 Subject: [PATCH 15/27] switch tuple observation Exception to TypeError --- tianshou/env/venv_wrappers.py | 2 +- tianshou/env/venvs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index 3297421c4..b52bf5aca 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -102,7 +102,7 @@ def reset( obs = retval if isinstance(obs, tuple): - raise Exception( + raise TypeError( "Tuple observation space is not supported. ", "Please change it to array or dict space", ) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index cdf49584f..83da60732 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -211,7 +211,7 @@ def reset( obs_list = ret_list if isinstance(obs_list[0], tuple): - raise Exception( + raise TypeError( "Tuple observation space is not supported. 
", "Please change it to array or dict space", ) From 88d865af044a5a20bdfd2c02de4a1a2279e10577 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Sat, 21 May 2022 09:43:52 -0400 Subject: [PATCH 16/27] remove debug prints --- tianshou/env/worker/subproc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index ee0ce833c..3de258583 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -93,8 +93,6 @@ def _encode_obs( p.send((obs, reward, done, info)) elif cmd == "reset": retval = env.reset(**data) - print(f"type(retval): {type(retval)}") - print(retval) has_info = isinstance( retval, (tuple, list) ) and len(retval) == 2 and isinstance(retval[1], dict) From 662a68a05363db391b2a50717010f41ae10d7f1f Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Sat, 21 May 2022 23:11:04 -0400 Subject: [PATCH 17/27] check reset_returns_info once --- tianshou/env/venv_wrappers.py | 11 ++++++----- tianshou/env/venvs.py | 11 ++++++----- tianshou/env/worker/subproc.py | 12 +++++++----- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index b52bf5aca..ad4626548 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -93,10 +93,11 @@ def reset( **kwargs: Any, ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: retval = self.venv.reset(id, **kwargs) - has_info = isinstance(retval, - (tuple, list - )) and len(retval) == 2 and isinstance(retval[1], dict) - if has_info: + if not hasattr(self, "reset_returns_info"): + self.reset_returns_info = isinstance( + retval, (tuple, list) + ) and len(retval) == 2 and isinstance(retval[1], dict) + if self.reset_returns_info: obs, info = retval else: obs = retval @@ -110,7 +111,7 @@ def reset( if self.obs_rms and self.update_obs_rms: self.obs_rms.update(obs) obs = self._norm_obs(obs) - if has_info: + if self.reset_returns_info: return obs, info else: return obs diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 83da60732..027cd6161 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -202,10 +202,11 @@ def reset( self.workers[i].send(None, **kwargs) ret_list = [self.workers[i].recv() for i in id] - has_infos = isinstance(ret_list[0], (tuple, list)) and len( - ret_list[0] - ) == 2 and isinstance(ret_list[0][1], dict) - if has_infos: + if not hasattr(self, "reset_returns_info"): + self.reset_returns_info = isinstance(ret_list[0], (tuple, list)) and len( + ret_list[0] + ) == 2 and isinstance(ret_list[0][1], dict) + if self.reset_returns_info: obs_list = [r[0] for r in ret_list] else: obs_list = ret_list @@ -220,7 +221,7 @@ def reset( except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) - if has_infos: + if self.reset_returns_info: infos = [r[1] for r in ret_list] return obs, infos # type: ignore else: diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 3de258583..0cfbd4ce5 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -78,6 +78,7 @@ def _encode_obs( parent.close() env = env_fn_wrapper.data() + reset_returns_info = None try: while True: try: @@ -93,17 +94,18 @@ def _encode_obs( p.send((obs, reward, done, info)) elif cmd == "reset": retval = env.reset(**data) - has_info = isinstance( - retval, (tuple, list) - ) and len(retval) == 2 and isinstance(retval[1], dict) - if has_info: + if reset_returns_info is None: + reset_returns_info = isinstance( + retval, (tuple, list) + ) 
and len(retval) == 2 and isinstance(retval[1], dict) + if reset_returns_info: obs, info = retval else: obs = retval if obs_bufs is not None: _encode_obs(obs, obs_bufs) obs = None - if has_info: + if reset_returns_info: p.send((obs, info)) else: p.send(obs) From 75ecd180f7f38076ad10e36301b367b2cf63268f Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Tue, 24 May 2022 22:35:42 -0400 Subject: [PATCH 18/27] support reset returns info in collector --- test/base/test_collector.py | 27 ++++++++++---- tianshou/data/collector.py | 73 ++++++++++++++++++++++++++----------- 2 files changed, 71 insertions(+), 29 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 9a8d74912..48986c234 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -77,7 +77,8 @@ def single_preprocess_fn(**kwargs): return Batch() -def test_collector(): +@pytest.mark.parametrize("gym_reset_return_info", [False, True]) +def test_collector(gym_reset_return_info): writer = SummaryWriter('log/collector') logger = Logger(writer) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] @@ -86,7 +87,13 @@ def test_collector(): dum = DummyVectorEnv(env_fns) policy = MyPolicy() env = env_fns[0]() - c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn) + c0 = Collector( + policy, + env, + ReplayBuffer(size=100), + logger.preprocess_fn, + gym_reset_return_info=gym_reset_return_info, + ) c0.collect(n_step=3) assert len(c0.buffer) == 3 assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) @@ -151,7 +158,8 @@ def test_collector(): assert c3.buffer.obs.dtype == object -def test_collector_with_async(): +@pytest.mark.parametrize("gym_reset_return_info", [False, True]) +def test_collector_with_async(gym_reset_return_info): env_lens = [2, 3, 4, 5] writer = SummaryWriter('log/async_collector') logger = Logger(writer) @@ -163,8 +171,11 @@ def test_collector_with_async(): policy = MyPolicy() bufsize = 60 c1 = AsyncCollector( - policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), - logger.preprocess_fn + policy, + venv, + VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), + logger.preprocess_fn, + gym_reset_return_info=gym_reset_return_info, ) ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): @@ -619,8 +630,10 @@ def test_collector_with_atari_setting(): if __name__ == '__main__': - test_collector() + test_collector(gym_reset_return_info=True) + test_collector(gym_reset_return_info=False) test_collector_with_dict_state() test_collector_with_ma() test_collector_with_atari_setting() - test_collector_with_async() + test_collector_with_async(gym_reset_return_info=True) + test_collector_with_async(gym_reset_return_info=False) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 537986cf6..20273947f 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -34,13 +34,16 @@ class Collector(object): with corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. Default to False. + :param bool gym_reset_return_info: if set to True, return the info dict when + resetting the environment. The "preprocess_fn" is a function called before the data has been added to the - buffer with batch format. 
It will receive only "obs" and "env_id" when the - collector resets the environment, and will receive six keys "obs_next", "rew", - "done", "info", "policy" and "env_id" in a normal env step. It returns either a - dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples - are in "test/base/test_collector.py". + buffer with batch format. It will receive only "obs", "info" (if + ``gym_reset_return_info = True``), and "env_id" when the collector resets the + environment, and will receive six keys "obs_next", "rew", "done", "info", "policy" + and "env_id" in a normal env step. It returns either a dict or a + :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in + "test/base/test_collector.py". .. note:: @@ -60,6 +63,7 @@ def __init__( buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, exploration_noise: bool = False, + gym_reset_return_info: bool = False, ) -> None: super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): @@ -72,6 +76,7 @@ def __init__( self._assign_buffer(buffer) self.policy = policy self.preprocess_fn = preprocess_fn + self.gym_reset_return_info = gym_reset_return_info self._action_space = self.env.action_space # avoid creating attribute outside __init__ self.reset(False) @@ -126,10 +131,20 @@ def reset_buffer(self, keep_statistics: bool = False) -> None: def reset_env(self) -> None: """Reset all of the environments.""" - obs = self.env.reset() - if self.preprocess_fn: - obs = self.preprocess_fn(obs=obs, - env_id=np.arange(self.env_num)).get("obs", obs) + if self.gym_reset_return_info: + obs, info = self.env.reset(return_info=True) + if self.preprocess_fn: + processed_data = self.preprocess_fn( + obs=obs, info=info, env_id=np.arange(self.env_num) + ) + obs = processed_data.get("obs", obs) + info = processed_data.get("info", info) + self.data.info = info + else: + obs = self.env.reset() + if self.preprocess_fn: + obs = self.preprocess_fn(obs=obs, env_id=np.arange(self.env_num + )).get("obs", obs) self.data.obs = obs def _reset_state(self, id: Union[int, List[int]]) -> None: @@ -143,6 +158,26 @@ def _reset_state(self, id: Union[int, List[int]]) -> None: elif isinstance(state, Batch): state.empty_(id) + def _reset_env_with_ids( + self, local_ids: Union[List[int], np.ndarray], global_ids: Union[List[int], + np.ndarray] + ) -> None: + if self.gym_reset_return_info: + obs_reset, info = self.env.reset(global_ids, return_info=True) + if self.preprocess_fn: + processed_data = self.preprocess_fn( + obs=obs_reset, info=info, env_id=global_ids + ) + obs_reset = processed_data.get("obs", obs_reset) + info = processed_data.get("info", info) + self.data.info[local_ids] = info + else: + obs_reset = self.env.reset(global_ids) + if self.preprocess_fn: + obs_reset = self.preprocess_fn(obs=obs_reset, env_id=global_ids + ).get("obs", obs_reset) + self.data.obs_next[local_ids] = obs_reset + def collect( self, n_step: Optional[int] = None, @@ -288,12 +323,7 @@ def collect( episode_start_indices.append(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
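As context for the partial reset handled here, a minimal sketch (not part of the patch; the task name and env ids are placeholders) of the id-based reset that `_reset_env_with_ids` delegates to:

```python
import gym

from tianshou.env import DummyVectorEnv

venv = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
obs = venv.reset()                  # reset all four envs once at the start
# ... suppose envs 1 and 3 finish their episodes during collection ...
obs_reset = venv.reset(id=[1, 3])   # reset only the finished envs; the others keep running
```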
- obs_reset = self.env.reset(env_ind_global) - if self.preprocess_fn: - obs_reset = self.preprocess_fn( - obs=obs_reset, env_id=env_ind_global - ).get("obs", obs_reset) - self.data.obs_next[env_ind_local] = obs_reset + self._reset_env_with_ids(env_ind_local, env_ind_global) for i in env_ind_local: self._reset_state(i) @@ -364,10 +394,14 @@ def __init__( buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, exploration_noise: bool = False, + gym_reset_return_info: bool = False, ) -> None: # assert env.is_async warnings.warn("Using async setting may collect extra transitions into buffer.") - super().__init__(policy, env, buffer, preprocess_fn, exploration_noise) + super().__init__( + policy, env, buffer, preprocess_fn, exploration_noise, + gym_reset_return_info + ) def reset_env(self) -> None: super().reset_env() @@ -528,12 +562,7 @@ def collect( episode_start_indices.append(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. - obs_reset = self.env.reset(env_ind_global) - if self.preprocess_fn: - obs_reset = self.preprocess_fn( - obs=obs_reset, env_id=env_ind_global - ).get("obs", obs_reset) - self.data.obs_next[env_ind_local] = obs_reset + self._reset_env_with_ids(env_ind_local, env_ind_global) for i in env_ind_local: self._reset_state(i) From c2ff71fb5ae1a5dae99dd677e3e72a274b28566a Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Mon, 6 Jun 2022 01:56:03 -0400 Subject: [PATCH 19/27] dynamically check reset retval in collector --- test/base/test_collector.py | 125 +++++++++++++++++++++------- test/continuous/test_sac_with_il.py | 8 +- tianshou/data/collector.py | 76 +++++++++-------- tianshou/env/venv_wrappers.py | 11 ++- tianshou/env/venvs.py | 11 ++- 5 files changed, 155 insertions(+), 76 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 48986c234..8b5bf3814 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -15,6 +15,11 @@ from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy +try: + import envpool +except ImportError: + envpool = None + if __name__ == '__main__': from env import MyTestEnv, NXEnv else: # pytest @@ -23,7 +28,7 @@ class MyPolicy(BasePolicy): - def __init__(self, dict_state=False, need_state=True): + def __init__(self, dict_state=False, need_state=True, action_shape=None): """ :param bool dict_state: if the observation of the environment is a dict :param bool need_state: if the policy needs the hidden state (for RNN) @@ -31,6 +36,7 @@ def __init__(self, dict_state=False, need_state=True): super().__init__() self.dict_state = dict_state self.need_state = need_state + self.action_shape = action_shape def forward(self, batch, state=None): if self.need_state: @@ -39,8 +45,12 @@ def forward(self, batch, state=None): else: state += 1 if self.dict_state: - return Batch(act=np.ones(len(batch.obs['index'])), state=state) - return Batch(act=np.ones(len(batch.obs)), state=state) + action_shape = self.action_shape if self.action_shape else len( + batch.obs['index'] + ) + return Batch(act=np.ones(action_shape), state=state) + action_shape = self.action_shape if self.action_shape else len(batch.obs) + return Batch(act=np.ones(action_shape), state=state) def learn(self): pass @@ -77,8 +87,8 @@ def single_preprocess_fn(**kwargs): return Batch() -@pytest.mark.parametrize("gym_reset_return_info", [False, True]) -def test_collector(gym_reset_return_info): 
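For orientation, a small sketch of what the two parametrized cases below exercise (assuming `collector` is a Collector already built on such envs): with `gym_reset_kwargs=None` the collector issues a plain `reset()`, while `dict(return_info=True)` is forwarded to the env so that reset returns an `(obs, info)` pair, which the collector now detects from the return value itself:

```python
env = MyTestEnv(size=5)

obs = env.reset()                        # plain reset -> observation only
obs, info = env.reset(return_info=True)  # forwarded kwargs -> (obs, info)

# The same dict is passed through to every reset the collector performs:
collector.collect(n_step=3, gym_reset_kwargs=dict(return_info=True))
```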
+@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)]) +def test_collector(gym_reset_kwargs): writer = SummaryWriter('log/collector') logger = Logger(writer) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] @@ -92,53 +102,93 @@ def test_collector(gym_reset_return_info): env, ReplayBuffer(size=100), logger.preprocess_fn, - gym_reset_return_info=gym_reset_return_info, ) - c0.collect(n_step=3) + c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs) assert len(c0.buffer) == 3 assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) - c0.collect(n_episode=3) + assert np.allclose(c0.buffer.info["key"][:3], 1) + for e in c0.buffer.info["env"][:3]: + assert isinstance(e, MyTestEnv) + assert np.allclose(c0.buffer.info["env_id"][:3], 0) + assert np.allclose(c0.buffer.info["rew"][:3], [0, 1, 0]) + c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs) assert len(c0.buffer) == 8 assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) - c0.collect(n_step=3, random=True) + assert np.allclose(c0.buffer.info["key"][:8], 1) + for e in c0.buffer.info["env"][:8]: + assert isinstance(e, MyTestEnv) + assert np.allclose(c0.buffer.info["env_id"][:8], 0) + assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1]) + c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs) + c1 = Collector( policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4), logger.preprocess_fn ) - c1.collect(n_step=8) + c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs) obs = np.zeros(100) - obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1] - + valid_indices = [0, 1, 25, 26, 50, 51, 75, 76] + obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1] assert np.allclose(c1.buffer.obs[:, 0], obs) assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) - c1.collect(n_episode=4) + keys = np.zeros(100) + keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] + assert np.allclose(c1.buffer.info["key"], keys) + for e in c1.buffer.info["env"][valid_indices]: + assert isinstance(e, MyTestEnv) + env_ids = np.zeros(100) + env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3] + assert np.allclose(c1.buffer.info["env_id"], env_ids) + rews = np.zeros(100) + rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0] + assert np.allclose(c1.buffer.info["rew"], rews) + c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs) assert len(c1.buffer) == 16 + valid_indices = [2, 3, 27, 52, 53, 77, 78, 79] obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] assert np.allclose(c1.buffer.obs[:, 0], obs) assert np.allclose( c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5] ) - c1.collect(n_episode=4, random=True) + keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] + assert np.allclose(c1.buffer.info["key"], keys) + for e in c1.buffer.info["env"][valid_indices]: + assert isinstance(e, MyTestEnv) + env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3] + assert np.allclose(c1.buffer.info["env_id"], env_ids) + rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1] + assert np.allclose(c1.buffer.info["rew"], rews) + c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) + c2 = Collector( policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4), logger.preprocess_fn ) - c2.collect(n_episode=7) + c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs) obs1 = obs.copy() obs1[[4, 5, 28, 
29, 30]] = [0, 1, 0, 1, 2] obs2 = obs.copy() obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3] c2obs = c2.buffer.obs[:, 0] assert np.all(c2obs == obs1) or np.all(c2obs == obs2) - c2.reset_env() + c2.reset_env(gym_reset_kwargs=gym_reset_kwargs) c2.reset_buffer() - assert c2.collect(n_episode=8)['n/ep'] == 8 - obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3] + assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs)['n/ep'] == 8 + valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57] + obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3] assert np.all(c2.buffer.obs[:, 0] == obs) - c2.collect(n_episode=4, random=True) + keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1] + assert np.allclose(c2.buffer.info["key"], keys) + for e in c2.buffer.info["env"][valid_indices]: + assert isinstance(e, MyTestEnv) + env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2] + assert np.allclose(c2.buffer.info["env_id"], env_ids) + rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1] + assert np.allclose(c2.buffer.info["rew"], rews) + c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) # test corner case with pytest.raises(TypeError): @@ -154,12 +204,12 @@ def test_collector(gym_reset_return_info): [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]] ) c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) - c3.collect(n_step=6) + c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs) assert c3.buffer.obs.dtype == object -@pytest.mark.parametrize("gym_reset_return_info", [False, True]) -def test_collector_with_async(gym_reset_return_info): +@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)]) +def test_collector_with_async(gym_reset_kwargs): env_lens = [2, 3, 4, 5] writer = SummaryWriter('log/async_collector') logger = Logger(writer) @@ -175,11 +225,10 @@ def test_collector_with_async(gym_reset_return_info): venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), logger.preprocess_fn, - gym_reset_return_info=gym_reset_return_info, ) ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): - result = c1.collect(n_episode=n_episode) + result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) assert result["n/ep"] >= n_episode # check buffer data, obs and obs_next, env_id for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]): @@ -194,7 +243,7 @@ def test_collector_with_async(gym_reset_return_info): assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data for n_step in tqdm.trange(1, 15, desc="test async n_step"): - result = c1.collect(n_step=n_step) + result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) assert result["n/st"] >= n_step for i in range(4): env_len = i + 2 @@ -629,11 +678,29 @@ def test_collector_with_atari_setting(): assert np.allclose(result2[key], result[key]) +@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") +def test_collector_envpool_gym_reset_return_info(): + envs = envpool.make_gym("Pendulum-v0", num_envs=4, gym_reset_return_info=True) + policy = MyPolicy(action_shape=(len(envs), 1)) + + c0 = Collector( + policy, + envs, + VectorReplayBuffer(len(envs) * 10, len(envs)), + exploration_noise=True + ) + c0.collect(n_step=8) + env_ids = np.zeros(len(envs) * 10) + env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3] + assert np.allclose(c0.buffer.info["env_id"], env_ids) + + if __name__ 
== '__main__': - test_collector(gym_reset_return_info=True) - test_collector(gym_reset_return_info=False) + test_collector(gym_reset_kwargs=None) + test_collector(gym_reset_kwargs=dict(return_info=True)) test_collector_with_dict_state() test_collector_with_ma() test_collector_with_atari_setting() - test_collector_with_async(gym_reset_return_info=True) - test_collector_with_async(gym_reset_return_info=False) + test_collector_with_async(gym_reset_kwargs=None) + test_collector_with_async(gym_reset_kwargs=dict(return_info=True)) + test_collector_envpool_gym_reset_return_info() diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index b65e2d321..a3349f8e5 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -57,10 +57,14 @@ def get_args(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") -def test_sac_with_il(args=get_args()): +@pytest.mark.parametrize("gym_reset_return_info", [False, True]) +def test_sac_with_il(gym_reset_return_info, args=get_args()): # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gym( - args.task, num_envs=args.training_num, seed=args.seed + args.task, + num_envs=args.training_num, + seed=args.seed, + gym_reset_return_info=gym_reset_return_info ) test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed) args.state_shape = env.observation_space.shape or env.observation_space.n diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 20273947f..b33457686 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -22,7 +22,6 @@ class Collector(object): """Collector enables the policy to interact with different types of envs with \ exact number of steps or episodes. - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. @@ -34,24 +33,16 @@ class Collector(object): with corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. Default to False. - :param bool gym_reset_return_info: if set to True, return the info dict when - resetting the environment. - The "preprocess_fn" is a function called before the data has been added to the - buffer with batch format. It will receive only "obs", "info" (if - ``gym_reset_return_info = True``), and "env_id" when the collector resets the - environment, and will receive six keys "obs_next", "rew", "done", "info", "policy" - and "env_id" in a normal env step. It returns either a dict or a - :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in - "test/base/test_collector.py". - + buffer with batch format. It will receive only "obs" and "env_id" when the + collector resets the environment, and will receive six keys "obs_next", "rew", + "done", "info", "policy" and "env_id" in a normal env step. It returns either a + dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples + are in "test/base/test_collector.py". .. note:: - Please make sure the given environment has a time limitation if using n_episode collect option. - .. note:: - In past versions of Tianshou, the replay buffer that was passed to `__init__` was automatically reset. This is not done in the current implementation. 
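A short sketch of the behaviour that note describes (buffer size, step count, `policy`, and a 10-env `train_envs` are illustrative assumptions): the buffer handed to the collector keeps whatever it already holds, so clear it explicitly when that is what you want.

```python
from tianshou.data import Collector, VectorReplayBuffer

buf = VectorReplayBuffer(total_size=20000, buffer_num=10)
# ... buf may still hold transitions from an earlier run ...
collector = Collector(policy, train_envs, buf, exploration_noise=True)
collector.reset_buffer()   # discard the old transitions explicitly
collector.collect(n_step=2000)
```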
""" @@ -63,7 +54,6 @@ def __init__( buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, exploration_noise: bool = False, - gym_reset_return_info: bool = False, ) -> None: super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): @@ -76,7 +66,6 @@ def __init__( self._assign_buffer(buffer) self.policy = policy self.preprocess_fn = preprocess_fn - self.gym_reset_return_info = gym_reset_return_info self._action_space = self.env.action_space # avoid creating attribute outside __init__ self.reset(False) @@ -105,7 +94,11 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: ) self.buffer = buffer - def reset(self, reset_buffer: bool = True) -> None: + def reset( + self, + reset_buffer: bool = True, + gym_reset_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: """Reset the environment, statistics, current data and possibly replay memory. :param bool reset_buffer: if true, reset the replay buffer that is attached @@ -116,7 +109,7 @@ def reset(self, reset_buffer: bool = True) -> None: self.data = Batch( obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={} ) - self.reset_env() + self.reset_env(gym_reset_kwargs) if reset_buffer: self.reset_buffer() self.reset_stat() @@ -129,10 +122,15 @@ def reset_buffer(self, keep_statistics: bool = False) -> None: """Reset the data buffer.""" self.buffer.reset(keep_statistics=keep_statistics) - def reset_env(self) -> None: + def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None: """Reset all of the environments.""" - if self.gym_reset_return_info: - obs, info = self.env.reset(return_info=True) + gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} + retval = self.env.reset(**gym_reset_kwargs) + returns_info = isinstance(retval, (tuple, list)) and len(retval) == 2 and ( + isinstance(retval[1], dict) or isinstance(retval[1][0], dict) + ) + if returns_info: + obs, info = retval if self.preprocess_fn: processed_data = self.preprocess_fn( obs=obs, info=info, env_id=np.arange(self.env_num) @@ -141,7 +139,7 @@ def reset_env(self) -> None: info = processed_data.get("info", info) self.data.info = info else: - obs = self.env.reset() + obs = retval if self.preprocess_fn: obs = self.preprocess_fn(obs=obs, env_id=np.arange(self.env_num )).get("obs", obs) @@ -159,11 +157,18 @@ def _reset_state(self, id: Union[int, List[int]]) -> None: state.empty_(id) def _reset_env_with_ids( - self, local_ids: Union[List[int], np.ndarray], global_ids: Union[List[int], - np.ndarray] + self, + local_ids: Union[List[int], np.ndarray], + global_ids: Union[List[int], np.ndarray], + gym_reset_kwargs: Optional[Dict[str, Any]] = None, ) -> None: - if self.gym_reset_return_info: - obs_reset, info = self.env.reset(global_ids, return_info=True) + gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} + retval = self.env.reset(global_ids, **gym_reset_kwargs) + returns_info = isinstance(retval, (tuple, list)) and len(retval) == 2 and ( + isinstance(retval[1], dict) or isinstance(retval[1][0], dict) + ) + if returns_info: + obs_reset, info = retval if self.preprocess_fn: processed_data = self.preprocess_fn( obs=obs_reset, info=info, env_id=global_ids @@ -172,7 +177,7 @@ def _reset_env_with_ids( info = processed_data.get("info", info) self.data.info[local_ids] = info else: - obs_reset = self.env.reset(global_ids) + obs_reset = retval if self.preprocess_fn: obs_reset = self.preprocess_fn(obs=obs_reset, env_id=global_ids ).get("obs", obs_reset) @@ -185,6 +190,7 @@ def 
collect( random: bool = False, render: Optional[float] = None, no_grad: bool = True, + gym_reset_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Collect a specified number of step or episode. @@ -323,7 +329,9 @@ def collect( episode_start_indices.append(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. - self._reset_env_with_ids(env_ind_local, env_ind_global) + self._reset_env_with_ids( + env_ind_local, env_ind_global, gym_reset_kwargs + ) for i in env_ind_local: self._reset_state(i) @@ -394,17 +402,19 @@ def __init__( buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, exploration_noise: bool = False, - gym_reset_return_info: bool = False, ) -> None: # assert env.is_async warnings.warn("Using async setting may collect extra transitions into buffer.") super().__init__( - policy, env, buffer, preprocess_fn, exploration_noise, - gym_reset_return_info + policy, + env, + buffer, + preprocess_fn, + exploration_noise, ) - def reset_env(self) -> None: - super().reset_env() + def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None: + super().reset_env(gym_reset_kwargs) self._ready_env_ids = np.arange(self.env_num) def collect( diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index ad4626548..bb5e294b6 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -93,11 +93,10 @@ def reset( **kwargs: Any, ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: retval = self.venv.reset(id, **kwargs) - if not hasattr(self, "reset_returns_info"): - self.reset_returns_info = isinstance( - retval, (tuple, list) - ) and len(retval) == 2 and isinstance(retval[1], dict) - if self.reset_returns_info: + reset_returns_info = isinstance( + retval, (tuple, list) + ) and len(retval) == 2 and isinstance(retval[1], dict) + if reset_returns_info: obs, info = retval else: obs = retval @@ -111,7 +110,7 @@ def reset( if self.obs_rms and self.update_obs_rms: self.obs_rms.update(obs) obs = self._norm_obs(obs) - if self.reset_returns_info: + if reset_returns_info: return obs, info else: return obs diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 027cd6161..1f12d3fdf 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -202,11 +202,10 @@ def reset( self.workers[i].send(None, **kwargs) ret_list = [self.workers[i].recv() for i in id] - if not hasattr(self, "reset_returns_info"): - self.reset_returns_info = isinstance(ret_list[0], (tuple, list)) and len( - ret_list[0] - ) == 2 and isinstance(ret_list[0][1], dict) - if self.reset_returns_info: + reset_returns_info = isinstance(ret_list[0], (tuple, list)) and len( + ret_list[0] + ) == 2 and isinstance(ret_list[0][1], dict) + if reset_returns_info: obs_list = [r[0] for r in ret_list] else: obs_list = ret_list @@ -221,7 +220,7 @@ def reset( except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) - if self.reset_returns_info: + if reset_returns_info: infos = [r[1] for r in ret_list] return obs, infos # type: ignore else: From 12cf50fdcdc47c4f0fc35fb13a2bd3a158bee47d Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Mon, 6 Jun 2022 02:04:22 -0400 Subject: [PATCH 20/27] bump gym version to 0.23.1 and fix mypy --- setup.py | 2 +- tianshou/data/collector.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 0ae409056..1f207d847 100644 --- a/setup.py +++ b/setup.py @@ 
-15,7 +15,7 @@ def get_version() -> str: def get_install_requires() -> str: return [ - "gym>=0.15.4", + "gym>=0.23.1", "tqdm", "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 "tensorboard>=2.5.0", diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index b33457686..2f4401102 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -103,6 +103,8 @@ def reset( :param bool reset_buffer: if true, reset the replay buffer that is attached to the collector. + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (no extra keyword arguments) """ # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy @@ -127,7 +129,7 @@ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None: gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} retval = self.env.reset(**gym_reset_kwargs) returns_info = isinstance(retval, (tuple, list)) and len(retval) == 2 and ( - isinstance(retval[1], dict) or isinstance(retval[1][0], dict) + isinstance(retval[1], dict) or isinstance(retval[1][0], dict) # type: ignore ) if returns_info: obs, info = retval @@ -165,7 +167,7 @@ def _reset_env_with_ids( gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} retval = self.env.reset(global_ids, **gym_reset_kwargs) returns_info = isinstance(retval, (tuple, list)) and len(retval) == 2 and ( - isinstance(retval[1], dict) or isinstance(retval[1][0], dict) + isinstance(retval[1], dict) or isinstance(retval[1][0], dict) # type: ignore ) if returns_info: obs_reset, info = retval @@ -206,6 +208,8 @@ def collect( Default to None (no rendering). :param bool no_grad: whether to retain gradient in policy.forward(). Default to True (no gradient retaining). + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (no extra keyword arguments) .. note:: @@ -424,6 +428,7 @@ def collect( random: bool = False, render: Optional[float] = None, no_grad: bool = True, + gym_reset_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Collect a specified number of step or episode with async env setting. @@ -439,6 +444,8 @@ def collect( Default to None (no rendering). :param bool no_grad: whether to retain gradient in policy.forward(). Default to True (no gradient retaining). + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (no extra keyword arguments) .. note:: @@ -572,7 +579,7 @@ def collect( episode_start_indices.append(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first.
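From the caller's side, the new `gym_reset_kwargs` parameter documented above would be used roughly as follows; this is a sketch, assuming `policy` is any BasePolicy instance, and CartPole-v1 with these buffer sizes is only a placeholder setup:

```python
import gym
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

# assumption: `policy` is an already-constructed tianshou policy
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
collector = Collector(policy, envs, VectorReplayBuffer(total_size=400, buffer_num=4))

reset_kwargs = dict(return_info=True)                        # forwarded verbatim to env.reset(...)
collector.reset(gym_reset_kwargs=reset_kwargs)               # resets envs, stats and (by default) the buffer
collector.collect(n_step=8, gym_reset_kwargs=reset_kwargs)   # also applied whenever finished envs are reset mid-collect
```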
- self._reset_env_with_ids(env_ind_local, env_ind_global) + self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) for i in env_ind_local: self._reset_state(i) From 6c60d533623ba568998e0996eb0d88eabb856f53 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Mon, 6 Jun 2022 02:16:49 -0400 Subject: [PATCH 21/27] fix lint check --- tianshou/data/collector.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 2f4401102..298065881 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -127,12 +127,12 @@ def reset_buffer(self, keep_statistics: bool = False) -> None: def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None: """Reset all of the environments.""" gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} - retval = self.env.reset(**gym_reset_kwargs) - returns_info = isinstance(retval, (tuple, list)) and len(retval) == 2 and ( - isinstance(retval[1], dict) or isinstance(retval[1][0], dict) # type: ignore + rval = self.env.reset(**gym_reset_kwargs) + returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( + isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore ) if returns_info: - obs, info = retval + obs, info = rval if self.preprocess_fn: processed_data = self.preprocess_fn( obs=obs, info=info, env_id=np.arange(self.env_num) @@ -141,7 +141,7 @@ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None: info = processed_data.get("info", info) self.data.info = info else: - obs = retval + obs = rval if self.preprocess_fn: obs = self.preprocess_fn(obs=obs, env_id=np.arange(self.env_num )).get("obs", obs) @@ -165,12 +165,12 @@ def _reset_env_with_ids( gym_reset_kwargs: Optional[Dict[str, Any]] = None, ) -> None: gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} - retval = self.env.reset(global_ids, **gym_reset_kwargs) - returns_info = isinstance(retval, (tuple, list)) and len(retval) == 2 and ( - isinstance(retval[1], dict) or isinstance(retval[1][0], dict) # type: ignore + rval = self.env.reset(global_ids, **gym_reset_kwargs) + returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( + isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore ) if returns_info: - obs_reset, info = retval + obs_reset, info = rval if self.preprocess_fn: processed_data = self.preprocess_fn( obs=obs_reset, info=info, env_id=global_ids @@ -179,7 +179,7 @@ def _reset_env_with_ids( info = processed_data.get("info", info) self.data.info[local_ids] = info else: - obs_reset = retval + obs_reset = rval if self.preprocess_fn: obs_reset = self.preprocess_fn(obs=obs_reset, env_id=global_ids ).get("obs", obs_reset) @@ -579,7 +579,9 @@ def collect( episode_start_indices.append(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
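To spell out the heuristic whose local variable is renamed above: the second element of the reset return value is a single info dict for a plain env and a sequence of per-env info dicts for a vectorized env, so both shapes are checked. A standalone sketch, with `env` and `gym_reset_kwargs` as placeholders:

```python
rval = env.reset(**(gym_reset_kwargs or {}))
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
    isinstance(rval[1], dict)        # plain env: (obs, info_dict)
    or isinstance(rval[1][0], dict)  # vectorized env: (obs_batch, [info_dict, ...])
)
if returns_info:
    obs, info = rval
else:
    obs, info = rval, None
```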
- self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) + self._reset_env_with_ids( + env_ind_local, env_ind_global, gym_reset_kwargs + ) for i in env_ind_local: self._reset_state(i) From fe06182f66bf3aec35e8989b2546556589d505e8 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Mon, 6 Jun 2022 02:26:01 -0400 Subject: [PATCH 22/27] undo changes to test_sac_with_il --- test/base/test_collector.py | 10 +++++++--- test/continuous/test_sac_with_il.py | 22 +++++++++------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 8b5bf3814..56c5b1155 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -107,11 +107,15 @@ def test_collector(gym_reset_kwargs): assert len(c0.buffer) == 3 assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) - assert np.allclose(c0.buffer.info["key"][:3], 1) + keys = np.zeros(100) + keys[:3] = 1 + assert np.allclose(c0.buffer.info["key"], keys) for e in c0.buffer.info["env"][:3]: assert isinstance(e, MyTestEnv) - assert np.allclose(c0.buffer.info["env_id"][:3], 0) - assert np.allclose(c0.buffer.info["rew"][:3], [0, 1, 0]) + assert np.allclose(c0.buffer.info["env_id"], 0) + rews = np.zeros(100) + rews[:3] = [0, 1, 0] + assert np.allclose(c0.buffer.info["rew"], rews) c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs) assert len(c0.buffer) == 8 assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index a3349f8e5..a204e55e8 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,5 +1,11 @@ import argparse import os +import sys + +try: + import envpool +except ImportError: + envpool = None import numpy as np import pytest @@ -13,11 +19,6 @@ from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic -try: - import envpool -except ImportError: - envpool = None - def get_args(): parser = argparse.ArgumentParser() @@ -56,15 +57,10 @@ def get_args(): return args -@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") -@pytest.mark.parametrize("gym_reset_return_info", [False, True]) -def test_sac_with_il(gym_reset_return_info, args=get_args()): - # if you want to use python vector env, please refer to other test scripts +@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now") +def test_sac_with_il(args=get_args()): train_envs = env = envpool.make_gym( - args.task, - num_envs=args.training_num, - seed=args.seed, - gym_reset_return_info=gym_reset_return_info + args.task, num_envs=args.training_num, seed=args.seed ) test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed) args.state_shape = env.observation_space.shape or env.observation_space.n From 144e88abab344306f2d0c0709aa415fdee4663a4 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Mon, 6 Jun 2022 02:28:20 -0400 Subject: [PATCH 23/27] doc formatting --- tianshou/data/collector.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 298065881..ce96bcf0b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -22,6 +22,7 @@ class Collector(object): """Collector enables the policy to interact with different types of envs with \ exact number of steps or episodes. 
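A short note on the widened assertions above (an inference, assuming the test's buffer is a `ReplayBuffer(size=100)` as the `np.zeros(100)` expectation suggests): `c0.buffer.info[...]` returns the whole pre-allocated array and unwritten slots stay at zero, so the expectation is built over the full buffer length instead of only the written prefix:

```python
import numpy as np

buffer_size = 100             # assumption: matches the test's ReplayBuffer(size=100)
keys = np.zeros(buffer_size)  # unwritten slots default to 0
keys[:3] = 1                  # only the first three collected transitions set info["key"]
# the test then compares the whole array: np.allclose(c0.buffer.info["key"], keys)
```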
+ :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. @@ -33,16 +34,21 @@ class Collector(object): with corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. Default to False. + The "preprocess_fn" is a function called before the data has been added to the buffer with batch format. It will receive only "obs" and "env_id" when the collector resets the environment, and will receive six keys "obs_next", "rew", "done", "info", "policy" and "env_id" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in "test/base/test_collector.py". + .. note:: + Please make sure the given environment has a time limitation if using n_episode collect option. + .. note:: + In past versions of Tianshou, the replay buffer that was passed to `__init__` was automatically reset. This is not done in the current implementation. """ From 5d4b9a0574fc9e909b44a1ed11fdeee0eaf64f06 Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Wed, 8 Jun 2022 08:15:17 -0400 Subject: [PATCH 24/27] undo changes to test_sac_with_il.py --- test/continuous/test_sac_with_il.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index a204e55e8..a3349f8e5 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,11 +1,5 @@ import argparse import os -import sys - -try: - import envpool -except ImportError: - envpool = None import numpy as np import pytest @@ -19,6 +13,11 @@ from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic +try: + import envpool +except ImportError: + envpool = None + def get_args(): parser = argparse.ArgumentParser() @@ -57,10 +56,15 @@ def get_args(): return args -@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now") -def test_sac_with_il(args=get_args()): +@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") +@pytest.mark.parametrize("gym_reset_return_info", [False, True]) +def test_sac_with_il(gym_reset_return_info, args=get_args()): + # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gym( - args.task, num_envs=args.training_num, seed=args.seed + args.task, + num_envs=args.training_num, + seed=args.seed, + gym_reset_return_info=gym_reset_return_info ) test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed) args.state_shape = env.observation_space.shape or env.observation_space.n From f5eef9cd3d5b8a3aaadc81b9fa562a53b1ce092e Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Wed, 8 Jun 2022 08:16:03 -0400 Subject: [PATCH 25/27] undo changes to test/continuous/test_sac_with_il.py --- test/continuous/test_sac_with_il.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index a3349f8e5..a204e55e8 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,5 +1,11 @@ import argparse import os +import sys + +try: + import envpool +except ImportError: + envpool = None import numpy as np import pytest @@ -13,11 +19,6 @@ from 
tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic -try: - import envpool -except ImportError: - envpool = None - def get_args(): parser = argparse.ArgumentParser() @@ -56,15 +57,10 @@ def get_args(): return args -@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") -@pytest.mark.parametrize("gym_reset_return_info", [False, True]) -def test_sac_with_il(gym_reset_return_info, args=get_args()): - # if you want to use python vector env, please refer to other test scripts +@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now") +def test_sac_with_il(args=get_args()): train_envs = env = envpool.make_gym( - args.task, - num_envs=args.training_num, - seed=args.seed, - gym_reset_return_info=gym_reset_return_info + args.task, num_envs=args.training_num, seed=args.seed ) test_envs = envpool.make_gym(args.task, num_envs=args.test_num, seed=args.seed) args.state_shape = env.observation_space.shape or env.observation_space.n From be7148aaf922660edf11953ea5398d9a8b7f6feb Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Wed, 8 Jun 2022 08:17:33 -0400 Subject: [PATCH 26/27] test/continuous/test_sac_with_il.py --- test/continuous/test_sac_with_il.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index a204e55e8..b65e2d321 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,11 +1,5 @@ import argparse import os -import sys - -try: - import envpool -except ImportError: - envpool = None import numpy as np import pytest @@ -19,6 +13,11 @@ from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic +try: + import envpool +except ImportError: + envpool = None + def get_args(): parser = argparse.ArgumentParser() @@ -57,8 +56,9 @@ def get_args(): return args -@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now") +@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_sac_with_il(args=get_args()): + # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gym( args.task, num_envs=args.training_num, seed=args.seed ) From edfcbcf3017d71132f235c0a5a4681811d7eeb6b Mon Sep 17 00:00:00 2001 From: Yifei Cheng Date: Tue, 21 Jun 2022 22:03:06 -0400 Subject: [PATCH 27/27] undo caching reset_return_info in subproc --- tianshou/env/worker/subproc.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 0cfbd4ce5..8c91b31c9 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -78,7 +78,6 @@ def _encode_obs( parent.close() env = env_fn_wrapper.data() - reset_returns_info = None try: while True: try: @@ -94,10 +93,9 @@ def _encode_obs( p.send((obs, reward, done, info)) elif cmd == "reset": retval = env.reset(**data) - if reset_returns_info is None: - reset_returns_info = isinstance( - retval, (tuple, list) - ) and len(retval) == 2 and isinstance(retval[1], dict) + reset_returns_info = isinstance( + retval, (tuple, list) + ) and len(retval) == 2 and isinstance(retval[1], dict) if reset_returns_info: obs, info = retval else:
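For background, a small sketch of the two reset signatures that the detection logic throughout this series, including the worker above, has to accommodate; it assumes gym>=0.23.1 as pinned in setup.py, and CartPole-v1 is only an example task:

```python
import gym

env = gym.make("CartPole-v1")
obs = env.reset()                         # classic API: observation only
obs, info = env.reset(return_info=True)   # transitional API: (observation, info dict)
```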