From 2622f608f471f4c4cac7454a78beb359d4c37f6f Mon Sep 17 00:00:00 2001
From: Jiayi Weng
Date: Wed, 8 Sep 2021 11:05:03 -0400
Subject: [PATCH 1/3] fix logger.write error in atari script

---
 .gitignore                      | 1 +
 examples/atari/atari_c51.py     | 3 ++-
 examples/atari/atari_dqn.py     | 3 ++-
 examples/atari/atari_fqf.py     | 3 ++-
 examples/atari/atari_iqn.py     | 3 ++-
 examples/atari/atari_qrdqn.py   | 3 ++-
 examples/atari/atari_rainbow.py | 6 ++++--
 7 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index e9510a1df..fd72be398 100644
--- a/.gitignore
+++ b/.gitignore
@@ -148,3 +148,4 @@ MUJOCO_LOG.TXT
 *.pkl
 *.hdf5
 wandb/
+videos/
diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py
index 291fb7007..5a15b1e4b 100644
--- a/examples/atari/atari_c51.py
+++ b/examples/atari/atari_c51.py
@@ -141,7 +141,8 @@ def train_fn(epoch, env_step):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})

     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index c9f74af8c..f62b23951 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -135,7 +135,8 @@ def train_fn(epoch, env_step):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})

     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py
index 4629bede2..aa94fd4a8 100644
--- a/examples/atari/atari_fqf.py
+++ b/examples/atari/atari_fqf.py
@@ -158,7 +158,8 @@ def train_fn(epoch, env_step):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})

     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py
index d0e7773d0..86b943f55 100644
--- a/examples/atari/atari_iqn.py
+++ b/examples/atari/atari_iqn.py
@@ -153,7 +153,8 @@ def train_fn(epoch, env_step):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})

     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index 23a7966eb..08857aba4 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -137,7 +137,8 @@ def train_fn(epoch, env_step):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})

     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py
index b131cce5f..8719017b0 100644
--- a/examples/atari/atari_rainbow.py
+++ b/examples/atari/atari_rainbow.py
@@ -174,7 +174,8 @@ def train_fn(epoch, env_step):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
         if not args.no_priority:
             if env_step <= args.beta_anneal_step:
                 beta = args.beta - env_step / args.beta_anneal_step * \
@@ -182,7 +183,8 @@ def train_fn(epoch, env_step):
             else:
                 beta = args.beta_final
             buffer.set_beta(beta)
-            logger.write('train/beta', env_step, beta)
+            if env_step % 1000 == 0:
+                logger.write("train/env_step", env_step, {"train/beta": beta})

     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
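Note on the patch above: Tianshou's logger interface changed from write(key, step, value) to write(step_type, step, data), where data is a dict mapping scalar names to values, so the old single-scalar call in the examples raised an error. A minimal sketch of the new convention, outside any training loop (the log directory and the env_step/eps values are made up for illustration):

    from torch.utils.tensorboard import SummaryWriter
    from tianshou.utils import TensorboardLogger

    # wrap a TensorBoard SummaryWriter; "log/demo" is an illustrative path
    logger = TensorboardLogger(SummaryWriter("log/demo"))

    env_step, eps = 1000, 0.05  # illustrative values
    # new signature: step-type key, step value, dict of scalars to record
    if env_step % 1000 == 0:  # the patch also throttles writes to every 1000 steps
        logger.write("train/env_step", env_step, {"train/eps": eps})

The modulo guard matters because train_fn runs at every collect step; without it, every step would emit a TensorBoard event.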
From d378dd1652a4a8623cc9cbb00fb7367f05460a63 Mon Sep 17 00:00:00 2001
From: Jiayi Weng
Date: Wed, 8 Sep 2021 11:14:49 -0400
Subject: [PATCH 2/3] update

---
 examples/vizdoom/vizdoom_c51.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py
index bb3a1f207..300689aed 100644
--- a/examples/vizdoom/vizdoom_c51.py
+++ b/examples/vizdoom/vizdoom_c51.py
@@ -144,7 +144,8 @@ def train_fn(epoch, env_step):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})

     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

From 3758a1f13e35b90e26b1a49a2660e1ab8c535aaa Mon Sep 17 00:00:00 2001
From: Jiayi Weng
Date: Wed, 8 Sep 2021 11:50:46 -0400
Subject: [PATCH 3/3] SubprocVectorEnv -> ShmemVectorEnv in atari

---
 examples/atari/atari_bcq.py     | 4 ++--
 examples/atari/atari_c51.py     | 6 +++---
 examples/atari/atari_cql.py     | 4 ++--
 examples/atari/atari_crr.py     | 4 ++--
 examples/atari/atari_dqn.py     | 6 +++---
 examples/atari/atari_fqf.py     | 6 +++---
 examples/atari/atari_iqn.py     | 6 +++---
 examples/atari/atari_qrdqn.py   | 6 +++---
 examples/atari/atari_rainbow.py | 6 +++---
 examples/vizdoom/vizdoom_c51.py | 6 +++---
 10 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py
index 1be441013..ec89243b4 100644
--- a/examples/atari/atari_bcq.py
+++ b/examples/atari/atari_bcq.py
@@ -11,7 +11,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_bcq(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py
index 5a15b1e4b..dcd1911dc 100644
--- a/examples/atari/atari_c51.py
+++ b/examples/atari/atari_c51.py
@@ -9,7 +9,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import C51Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -75,10 +75,10 @@ def test_c51(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
diff --git a/examples/atari/atari_cql.py b/examples/atari/atari_cql.py
index db4e33a9a..685e006db 100644
--- a/examples/atari/atari_cql.py
+++ b/examples/atari/atari_cql.py
@@ -11,7 +11,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DiscreteCQLPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger
@@ -76,7 +76,7 @@ def test_discrete_cql(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
diff --git a/examples/atari/atari_crr.py b/examples/atari/atari_crr.py
index 06cde415b..8905c7e58 100644
--- a/examples/atari/atari_crr.py
+++ b/examples/atari/atari_crr.py
@@ -11,7 +11,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DiscreteCRRPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_crr(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index f62b23951..67a44d002 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -9,7 +9,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -72,10 +72,10 @@ def test_dqn(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py
index aa94fd4a8..99f8957c4 100644
--- a/examples/atari/atari_fqf.py
+++ b/examples/atari/atari_fqf.py
@@ -9,7 +9,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import FQFPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_fqf(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py
index 86b943f55..532d59482 100644
--- a/examples/atari/atari_iqn.py
+++ b/examples/atari/atari_iqn.py
@@ -9,7 +9,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import IQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_iqn(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index 08857aba4..af5d78e3f 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -9,7 +9,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import QRDQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -73,10 +73,10 @@ def test_qrdqn(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py
index 8719017b0..4e1a78ced 100644
--- a/examples/atari/atari_rainbow.py
+++ b/examples/atari/atari_rainbow.py
@@ -10,7 +10,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import RainbowPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -85,10 +85,10 @@ def test_rainbow(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py
index 300689aed..53eafae20 100644
--- a/examples/vizdoom/vizdoom_c51.py
+++ b/examples/vizdoom/vizdoom_c51.py
@@ -9,7 +9,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import C51Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -72,13 +72,13 @@ def test_c51(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [
             lambda: Env(args.cfg_path, args.frames_stack, args.res)
             for _ in range(args.training_num)
         ]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [
             lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
             for _ in range(min(os.cpu_count() - 1, args.test_num))
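Note on patch 3: ShmemVectorEnv is a drop-in replacement for SubprocVectorEnv (both take a list of environment factories), but it returns observations to the main process through shared-memory buffers instead of pickling them over pipes, which saves copying for large image observations such as stacked Atari frames. A minimal usage sketch under gym-era semantics; the PongNoFrameskip-v4 task and the worker count of 4 are placeholders for the wrapped envs the example scripts actually build:

    import gym
    from tianshou.env import ShmemVectorEnv

    # one factory per worker; observations come back via shared buffers
    envs = ShmemVectorEnv(
        [lambda: gym.make("PongNoFrameskip-v4") for _ in range(4)]
    )
    obs = envs.reset()  # batched observations, one row per worker
    envs.close()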