diff --git a/examples/offline/README.md b/examples/offline/README.md index 1ac98ec91..04c42686e 100644 --- a/examples/offline/README.md +++ b/examples/offline/README.md @@ -10,25 +10,30 @@ We provide implementation of BCQ and CQL algorithm for continuous control. ### Train -Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset. +Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer`, and set it as the `buffer` parameter of `offline_trainer`. `d4rl_bcq.py` is an example of offline RL using the d4rl dataset. -To train an agent with BCQ algorithm: +## Results -```bash -python offline_bcq.py --task halfcheetah-expert-v1 -``` +### IL (Imitation Learning, a.k.a. Behavior Cloning) -After 1M steps: +| Environment | Dataset | IL | Parameters | +| --------------------- | --------------------- | --------------- | -------------------------------------------------------- | +| HalfCheetah-v2 | halfcheetah-expert-v2 | 11355.31 | `python3 d4rl_il.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` | +| HalfCheetah-v2 | halfcheetah-medium-v2 | 5098.16 | `python3 d4rl_il.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` | -![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png) +### BCQ -`halfcheetah-expert-v1` is a mujoco environment. The setting of hyperparameters are similar to the off-policy algorithms in mujoco environment. +| Environment | Dataset | BCQ | Parameters | +| --------------------- | --------------------- | --------------- | -------------------------------------------------------- | +| HalfCheetah-v2 | halfcheetah-expert-v2 | 11509.95 | `python3 d4rl_bcq.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` | +| HalfCheetah-v2 | halfcheetah-medium-v2 | 5147.43 | `python3 d4rl_bcq.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` | -## Results +### CQL -| Environment | BCQ | -| --------------------- | --------------- | -| halfcheetah-expert-v1 | 10624.0 ± 181.4 | +| Environment | Dataset | CQL | Parameters | +| --------------------- | --------------------- | --------------- | -------------------------------------------------------- | +| HalfCheetah-v2 | halfcheetah-expert-v2 | 2864.37 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` | +| HalfCheetah-v2 | halfcheetah-medium-v2 | 6505.41 | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` | ## Discrete control @@ -42,14 +47,23 @@ To running CQL algorithm on Atari, you need to do the following things: - Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error); - Train offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5`.
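The offline step shared by `atari_il.py`, `atari_bcq.py`, `atari_cql.py`, and `atari_crr.py` amounts to loading the saved buffer and handing it to `offline_trainer`. A minimal sketch, assuming a `policy` and `test_envs` built as in the scripts below; the helper name `run_atari_offline` and its default values are illustrative only:

```python
import pickle

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.trainer import offline_trainer


def run_atari_offline(policy, test_envs, buffer_path, epoch=12, batch_size=32):
    # 1M-transition Atari buffers are saved as .hdf5; smaller ones may be pickled
    if buffer_path.endswith(".hdf5"):
        buffer = VectorReplayBuffer.load_hdf5(buffer_path)
    else:
        with open(buffer_path, "rb") as f:
            buffer = pickle.load(f)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # positional arguments follow the scripts: epochs, gradient steps per epoch,
    # test episodes, batch size
    return offline_trainer(policy, buffer, test_collector, epoch, 10000, 10, batch_size)
```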
+### IL + +We test our IL implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps): + +| Task | Online QRDQN | Behavioral | IL | parameters | +| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.0 (epoch 5) | `python3 atari_il.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 121.9 (epoch 12, could be higher) | `python3 atari_il.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` | + ### BCQ We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): | Task | Online QRDQN | Behavioral | BCQ | parameters | | ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` | +| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` | ### CQL @@ -57,8 +71,8 @@ We test our CQL implementation on two example tasks (different from author's ver | Task | Online QRDQN | Behavioral | CQL | parameters | | ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | +| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | We reduce the size of the offline data to 10% and 1% of the above and get: @@ -66,15 +80,15 @@ Buffer size 100000: | Task | Online QRDQN | Behavioral | CQL | parameters | | ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` | +| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` | Buffer size 10000: | Task | Online QRDQN | Behavioral | CQL | parameters | | ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` | -| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` | +| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` | +| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` | ### CRR @@ -82,7 +96,7 @@ We test our CRR implementation on two example tasks (different from author's ver | Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters | | ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | +| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task PongNoFrameskip-v4 --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps.
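For the continuous-control results above, `d4rl_bcq.py`, `d4rl_cql.py`, and `d4rl_il.py` build their offline buffer directly from a d4rl dataset. A minimal sketch of that conversion, assuming d4rl is installed and using the default dataset name from the scripts as an example:

```python
import d4rl  # noqa: F401  (importing d4rl registers its environments in gym)
import gym

from tianshou.data import Batch, ReplayBuffer


def load_d4rl_buffer(expert_data_task: str = "halfcheetah-expert-v2") -> ReplayBuffer:
    dataset = d4rl.qlearning_dataset(gym.make(expert_data_task))
    dataset_size = dataset["rewards"].size
    buffer = ReplayBuffer(dataset_size)
    # add one transition per buffer entry, as the d4rl scripts below do
    for i in range(dataset_size):
        buffer.add(
            Batch(
                obs=dataset["observations"][i],
                act=dataset["actions"][i],
                rew=dataset["rewards"][i],
                done=dataset["terminals"][i],
                obs_next=dataset["next_observations"][i],
            )
        )
    return buffer  # pass this as the `buffer` argument of offline_trainer
```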
diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 83865d18d..4ccb52df1 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + import argparse import datetime import os @@ -9,12 +11,11 @@ from torch.utils.tensorboard import SummaryWriter from examples.atari.atari_network import DQN -from examples.atari.atari_wrapper import wrap_deepmind +from examples.atari.atari_wrapper import make_atari_env from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import ShmemVectorEnv from tianshou.policy import DiscreteBCQPolicy from tianshou.trainer import offline_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor @@ -33,12 +34,21 @@ def get_args(): parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update-per-epoch", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) parser.add_argument("--test-num", type=int, default=10) - parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--scale-obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, @@ -56,35 +66,24 @@ def get_args(): return args -def make_atari_env(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack) - - -def make_atari_env_watch(args): - return wrap_deepmind( +def test_discrete_bcq(args=get_args()): + # envs + env, _, test_envs = make_atari_env( args.task, + args.seed, + 1, + args.test_num, + scale=args.scale_obs, frame_stack=args.frames_stack, - episode_life=False, - clip_rewards=False ) - - -def test_discrete_bcq(args=get_args()): - # envs - env = make_atari_env(args) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - # make environments - test_envs = ShmemVectorEnv( - [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] - ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) - test_envs.seed(args.seed) # model feature_net = DQN( *args.state_shape, args.action_shape, device=args.device, features_only=True @@ -118,9 +117,9 @@ def test_discrete_bcq(args=get_args()): # buffer assert os.path.exists(args.load_buffer_name), \ "Please run atari_dqn.py first to get expert's data buffer." 
- if args.load_buffer_name.endswith('.pkl'): + if args.load_buffer_name.endswith(".pkl"): buffer = pickle.load(open(args.load_buffer_name, "rb")) - elif args.load_buffer_name.endswith('.hdf5'): + elif args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: print(f"Unknown buffer format: {args.load_buffer_name}") @@ -130,16 +129,29 @@ def test_discrete_bcq(args=get_args()): test_collector = Collector(policy, test_envs, exploration_noise=True) # log - log_path = os.path.join( - args.logdir, args.task, 'bcq', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' - ) + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "bcq" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer, update_interval=args.log_interval) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): return False diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 22ef7b253..2515002e4 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + import argparse import datetime import os @@ -9,12 +11,11 @@ from torch.utils.tensorboard import SummaryWriter from examples.atari.atari_network import QRDQN -from examples.atari.atari_wrapper import wrap_deepmind +from examples.atari.atari_wrapper import make_atari_env from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import ShmemVectorEnv from tianshou.policy import DiscreteCQLPolicy from tianshou.trainer import offline_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger def get_args(): @@ -24,19 +25,28 @@ def get_args(): parser.add_argument("--eps-test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument('--num-quantiles', type=int, default=200) + parser.add_argument("--num-quantiles", type=int, default=200) parser.add_argument("--n-step", type=int, default=1) parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--min-q-weight", type=float, default=10.) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update-per-epoch", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) parser.add_argument("--test-num", type=int, default=10) - parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--scale-obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, @@ -54,35 +64,24 @@ def get_args(): return args -def make_atari_env(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack) - - -def make_atari_env_watch(args): - return wrap_deepmind( +def test_discrete_cql(args=get_args()): + # envs + env, _, test_envs = make_atari_env( args.task, + args.seed, + 1, + args.test_num, + scale=args.scale_obs, frame_stack=args.frames_stack, - episode_life=False, - clip_rewards=False ) - - -def test_discrete_cql(args=get_args()): - # envs - env = make_atari_env(args) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - # make environments - test_envs = ShmemVectorEnv( - [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] - ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) - test_envs.seed(args.seed) # model net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) @@ -103,9 +102,9 @@ def test_discrete_cql(args=get_args()): # buffer assert os.path.exists(args.load_buffer_name), \ "Please run atari_qrdqn.py first to get expert's data buffer." - if args.load_buffer_name.endswith('.pkl'): + if args.load_buffer_name.endswith(".pkl"): buffer = pickle.load(open(args.load_buffer_name, "rb")) - elif args.load_buffer_name.endswith('.hdf5'): + elif args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: print(f"Unknown buffer format: {args.load_buffer_name}") @@ -115,16 +114,29 @@ def test_discrete_cql(args=get_args()): test_collector = Collector(policy, test_envs, exploration_noise=True) # log - log_path = os.path.join( - args.logdir, args.task, 'cql', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' - ) + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "cql" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer, update_interval=args.log_interval) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): return False diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 0214bf2f5..a249928e3 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + import argparse import datetime import os @@ -9,12 +11,11 @@ from torch.utils.tensorboard import SummaryWriter from 
examples.atari.atari_network import DQN -from examples.atari.atari_wrapper import wrap_deepmind +from examples.atari.atari_wrapper import make_atari_env from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import ShmemVectorEnv from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic @@ -33,12 +34,21 @@ def get_args(): parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update-per-epoch", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) parser.add_argument("--test-num", type=int, default=10) - parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--scale-obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, @@ -56,35 +66,24 @@ def get_args(): return args -def make_atari_env(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack) - - -def make_atari_env_watch(args): - return wrap_deepmind( +def test_discrete_crr(args=get_args()): + # envs + env, _, test_envs = make_atari_env( args.task, + args.seed, + 1, + args.test_num, + scale=args.scale_obs, frame_stack=args.frames_stack, - episode_life=False, - clip_rewards=False ) - - -def test_discrete_crr(args=get_args()): - # envs - env = make_atari_env(args) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - # make environments - test_envs = ShmemVectorEnv( - [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] - ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) - test_envs.seed(args.seed) # model feature_net = DQN( *args.state_shape, args.action_shape, device=args.device, features_only=True @@ -123,9 +122,9 @@ def test_discrete_crr(args=get_args()): # buffer assert os.path.exists(args.load_buffer_name), \ "Please run atari_qrdqn.py first to get expert's data buffer." 
- if args.load_buffer_name.endswith('.pkl'): + if args.load_buffer_name.endswith(".pkl"): buffer = pickle.load(open(args.load_buffer_name, "rb")) - elif args.load_buffer_name.endswith('.hdf5'): + elif args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: print(f"Unknown buffer format: {args.load_buffer_name}") @@ -135,16 +134,29 @@ def test_discrete_crr(args=get_args()): test_collector = Collector(policy, test_envs, exploration_noise=True) # log - log_path = os.path.join( - args.logdir, args.task, 'crr', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' - ) + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "crr" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer, update_interval=args.log_interval) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): return False diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py new file mode 100644 index 000000000..b71303d7e --- /dev/null +++ b/examples/offline/atari_il.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 + +import argparse +import datetime +import os +import pickle +import pprint + +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from examples.atari.atari_network import DQN +from examples.atari.atari_wrapper import make_atari_env +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.policy import ImitationPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger, WandbLogger + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--update-per-epoch", type=int, default=10000) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
+ parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only" + ) + parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument( + "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5" + ) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + args = parser.parse_known_args()[0] + return args + + +def test_il(args=get_args()): + # envs + env, _, test_envs = make_atari_env( + args.task, + args.seed, + 1, + args.test_num, + scale=args.scale_obs, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # model + net = DQN(*args.state_shape, args.action_shape, device=args.device).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy = ImitationPolicy(net, optim, action_space=env.action_space) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # buffer + assert os.path.exists(args.load_buffer_name), \ + "Please run atari_qrdqn.py first to get expert's data buffer." 
+ if args.load_buffer_name.endswith('.pkl'): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + elif args.load_buffer_name.endswith('.hdf5'): + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) + else: + print(f"Unknown buffer format: {args.load_buffer_name}") + exit(0) + + # collector + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "il" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards): + return False + + # watch agent's performance + def watch(): + print("Setup test envs ...") + policy.eval() + test_envs.seed(args.seed) + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + pprint.pprint(result) + rew = result["rews"].mean() + print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + + if args.watch: + watch() + exit(0) + + result = offline_trainer( + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + ) + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_il(get_args()) diff --git a/examples/offline/offline_bcq.py b/examples/offline/d4rl_bcq.py similarity index 64% rename from examples/offline/offline_bcq.py rename to examples/offline/d4rl_bcq.py index e488489e2..38da5c104 100644 --- a/examples/offline/offline_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 + import argparse import datetime import os @@ -10,36 +11,38 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Batch, Collector, ReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.policy import BCQPolicy from tianshou.trainer import offline_trainer -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, Critic, Perturbation def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='halfcheetah-expert-v1') - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--buffer-size', type=int, default=1000000) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[400, 300]) - parser.add_argument('--actor-lr', type=float, default=1e-3) - parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument("--task", type=str, default="HalfCheetah-v2") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--expert-data-task", type=str, default="halfcheetah-expert-v2" + ) + parser.add_argument("--buffer-size", type=int, default=1000000) + parser.add_argument("--hidden-sizes", type=int, nargs="*", 
default=[256, 256]) + parser.add_argument("--actor-lr", type=float, default=1e-3) + parser.add_argument("--critic-lr", type=float, default=1e-3) parser.add_argument("--start-timesteps", type=int, default=10000) - parser.add_argument('--epoch', type=int, default=200) - parser.add_argument('--step-per-epoch', type=int, default=5000) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--batch-size', type=int, default=256) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=1 / 35) - - parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[750, 750]) + parser.add_argument("--epoch", type=int, default=200) + parser.add_argument("--step-per-epoch", type=int, default=5000) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=1 / 35) + + parser.add_argument("--vae-hidden-sizes", type=int, nargs="*", default=[512, 512]) # default to 2 * action_dim - parser.add_argument('--latent-dim', type=int) + parser.add_argument("--latent-dim", type=int) parser.add_argument("--gamma", default=0.99) parser.add_argument("--tau", default=0.005) # Weighting for Clipped Double Q-learning in BCQ @@ -47,14 +50,22 @@ def get_args(): # Max perturbation hyper-parameter for BCQ parser.add_argument("--phi", default=0.05) parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], ) - parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") parser.add_argument( - '--watch', + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only', + action="store_true", + help="watch the play of pre-trained policy only", ) return parser.parse_args() @@ -74,10 +85,6 @@ def test_bcq(): args.action_dim = args.action_shape[0] print("Max_action", args.max_action) - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)] - ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)] @@ -85,7 +92,6 @@ def test_bcq(): # seed np.random.seed(args.seed) torch.manual_seed(args.seed) - train_envs.seed(args.seed) test_envs.seed(args.seed) # model @@ -166,38 +172,47 @@ def test_bcq(): print("Loaded agent from: ", args.resume_path) # collector - if args.training_num > 1: - buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) - else: - buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs) - train_collector.collect(n_step=args.start_timesteps, random=True) + # log - t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") - log_file = 
f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' - log_path = os.path.join(args.logdir, args.task, 'bcq', log_file) + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "bcq" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch(): if args.resume_path is None: - args.resume_path = os.path.join(log_path, 'policy.pth') + args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict( - torch.load(args.resume_path, map_location=torch.device('cpu')) + torch.load(args.resume_path, map_location=torch.device("cpu")) ) policy.eval() collector = Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: - dataset = d4rl.qlearning_dataset(env) - dataset_size = dataset['rewards'].size + dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task)) + dataset_size = dataset["rewards"].size print("dataset_size", dataset_size) replay_buffer = ReplayBuffer(dataset_size) @@ -205,11 +220,11 @@ def watch(): for i in range(dataset_size): replay_buffer.add( Batch( - obs=dataset['observations'][i], - act=dataset['actions'][i], - rew=dataset['rewards'][i], - done=dataset['terminals'][i], - obs_next=dataset['next_observations'][i], + obs=dataset["observations"][i], + act=dataset["actions"][i], + rew=dataset["rewards"][i], + done=dataset["terminals"][i], + obs_next=dataset["next_observations"][i], ) ) print("dataset loaded") @@ -234,8 +249,8 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") -if __name__ == '__main__': +if __name__ == "__main__": test_bcq() diff --git a/examples/offline/offline_cql.py b/examples/offline/d4rl_cql.py similarity index 63% rename from examples/offline/offline_cql.py rename to examples/offline/d4rl_cql.py index f494200f2..952737aab 100644 --- a/examples/offline/offline_cql.py +++ b/examples/offline/d4rl_cql.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 + import argparse import datetime import os @@ -10,32 +11,35 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Batch, Collector, ReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.policy import CQLPolicy from tianshou.trainer import offline_trainer -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='halfcheetah-medium-v1') - parser.add_argument('--seed', type=int, default=0) - 
parser.add_argument('--buffer-size', type=int, default=1000000) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256]) - parser.add_argument('--actor-lr', type=float, default=1e-4) - parser.add_argument('--critic-lr', type=float, default=3e-4) - parser.add_argument('--alpha', type=float, default=0.2) - parser.add_argument('--auto-alpha', default=True, action='store_true') - parser.add_argument('--alpha-lr', type=float, default=1e-4) - parser.add_argument('--cql-alpha-lr', type=float, default=3e-4) + parser.add_argument("--task", type=str, default="HalfCheetah-v2") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--expert-data-task", type=str, default="halfcheetah-expert-v2" + ) + parser.add_argument("--buffer-size", type=int, default=1000000) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor-lr", type=float, default=1e-4) + parser.add_argument("--critic-lr", type=float, default=3e-4) + parser.add_argument("--alpha", type=float, default=0.2) + parser.add_argument("--auto-alpha", default=True, action="store_true") + parser.add_argument("--alpha-lr", type=float, default=1e-4) + parser.add_argument("--cql-alpha-lr", type=float, default=3e-4) parser.add_argument("--start-timesteps", type=int, default=10000) - parser.add_argument('--epoch', type=int, default=200) - parser.add_argument('--step-per-epoch', type=int, default=5000) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument("--epoch", type=int, default=200) + parser.add_argument("--step-per-epoch", type=int, default=5000) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--temperature", type=float, default=1.0) @@ -45,19 +49,26 @@ def get_args(): parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--eval-freq", type=int, default=1) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=1 / 35) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=1 / 35) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], ) - parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") parser.add_argument( - '--watch', + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only', + action="store_true", + help="watch the play of pre-trained policy only", ) return parser.parse_args() @@ -77,10 +88,6 @@ def test_cql(): args.action_dim = args.action_shape[0] print("Max_action", args.max_action) - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in 
range(args.training_num)] - ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)] @@ -88,7 +95,6 @@ def test_cql(): # seed np.random.seed(args.seed) torch.manual_seed(args.seed) - train_envs.seed(args.seed) test_envs.seed(args.seed) # model @@ -161,38 +167,47 @@ def test_cql(): print("Loaded agent from: ", args.resume_path) # collector - if args.training_num > 1: - buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) - else: - buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs) - train_collector.collect(n_step=args.start_timesteps, random=True) + # log - t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") - log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql' - log_path = os.path.join(args.logdir, args.task, 'cql', log_file) + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "cql" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch(): if args.resume_path is None: - args.resume_path = os.path.join(log_path, 'policy.pth') + args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict( - torch.load(args.resume_path, map_location=torch.device('cpu')) + torch.load(args.resume_path, map_location=torch.device("cpu")) ) policy.eval() collector = Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: - dataset = d4rl.qlearning_dataset(env) - dataset_size = dataset['rewards'].size + dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task)) + dataset_size = dataset["rewards"].size print("dataset_size", dataset_size) replay_buffer = ReplayBuffer(dataset_size) @@ -200,11 +215,11 @@ def watch(): for i in range(dataset_size): replay_buffer.add( Batch( - obs=dataset['observations'][i], - act=dataset['actions'][i], - rew=dataset['rewards'][i], - done=dataset['terminals'][i], - obs_next=dataset['next_observations'][i], + obs=dataset["observations"][i], + act=dataset["actions"][i], + rew=dataset["rewards"][i], + done=dataset["terminals"][i], + obs_next=dataset["next_observations"][i], ) ) print("dataset loaded") @@ -229,8 +244,8 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") -if __name__ == '__main__': +if __name__ == "__main__": test_cql() diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py new file mode 100644 index 000000000..208e71ef7 --- /dev/null +++ b/examples/offline/d4rl_il.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 + +import argparse +import datetime +import os +import pprint + 
+import d4rl +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Batch, Collector, ReplayBuffer +from tianshou.env import SubprocVectorEnv +from tianshou.policy import ImitationPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger, WandbLogger +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import Actor + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="HalfCheetah-v2") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--expert-data-task", type=str, default="halfcheetah-expert-v2" + ) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--epoch", type=int, default=200) + parser.add_argument("--step-per-epoch", type=int, default=5000) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=1 / 35) + parser.add_argument("--gamma", default=0.99) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + return parser.parse_args() + + +def test_il(): + args = get_args() + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] # float + print("device:", args.device) + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + + args.state_dim = args.state_shape[0] + args.action_dim = args.action_shape[0] + print("Max_action", args.max_action) + + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + + # model + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + ) + actor = Actor( + net, + action_shape=args.action_shape, + max_action=args.max_action, + device=args.device + ).to(args.device) + optim = torch.optim.Adam(actor.parameters(), lr=args.lr) + + policy = ImitationPolicy( + actor, + optim, + action_space=env.action_space, + action_scaling=True, + action_bound_method="clip" + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + test_collector = Collector(policy, test_envs) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "il" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = 
os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def watch(): + if args.resume_path is None: + args.resume_path = os.path.join(log_path, "policy.pth") + + policy.load_state_dict( + torch.load(args.resume_path, map_location=torch.device("cpu")) + ) + policy.eval() + collector = Collector(policy, env) + collector.collect(n_episode=1, render=1 / 35) + + if not args.watch: + dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task)) + dataset_size = dataset["rewards"].size + + print("dataset_size", dataset_size) + replay_buffer = ReplayBuffer(dataset_size) + + for i in range(dataset_size): + replay_buffer.add( + Batch( + obs=dataset["observations"][i], + act=dataset["actions"][i], + rew=dataset["rewards"][i], + done=dataset["terminals"][i], + obs_next=dataset["next_observations"][i], + ) + ) + print("dataset loaded") + # trainer + result = offline_trainer( + policy, + replay_buffer, + test_collector, + args.epoch, + args.step_per_epoch, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + ) + pprint.pprint(result) + else: + watch() + + # Let's watch its performance! + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") + + +if __name__ == "__main__": + test_il() diff --git a/examples/offline/results/bcq/halfcheetah-expert-v1_reward.png b/examples/offline/results/bcq/halfcheetah-expert-v1_reward.png deleted file mode 100644 index 5afa6a3ad..000000000 Binary files a/examples/offline/results/bcq/halfcheetah-expert-v1_reward.png and /dev/null differ diff --git a/examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg b/examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg deleted file mode 100644 index 87ede75ed..000000000 --- a/examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg +++ /dev/null @@ -1 +0,0 @@ -1e+32e+33e+34e+35e+36e+37e+38e+39e+31e+40100k200k300k400k500k600k700k800k900k1M1.1M \ No newline at end of file diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 1bc7991aa..e094692d6 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -24,7 +24,9 @@ class RayEnvWorker(EnvWorker): """Ray worker used in RayVectorEnv.""" def __init__(self, env_fn: Callable[[], gym.Env]) -> None: - self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn()) + self.env = ray.remote(_SetAttrWrapper).options( # type: ignore + num_cpus=0 + ).remote(env_fn()) super().__init__(env_fn) def get_env_attr(self, key: str) -> Any: @@ -54,7 +56,7 @@ def send(self, action: Optional[np.ndarray]) -> None: def recv( self ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: - return ray.get(self.result) + return ray.get(self.result) # type: ignore def seed(self, seed: Optional[int] = None) -> List[int]: super().seed(seed)