-
Notifications
You must be signed in to change notification settings - Fork 1.2k
implement TD3+BC for offline RL #660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import argparse | ||
import datetime | ||
import os | ||
import pprint | ||
|
||
import gym | ||
import numpy as np | ||
import torch | ||
from torch.utils.tensorboard import SummaryWriter | ||
|
||
from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer | ||
from tianshou.data import Collector | ||
from tianshou.env import SubprocVectorEnv, VectorEnvNormObs | ||
from tianshou.exploration import GaussianNoise | ||
from tianshou.policy import TD3BCPolicy | ||
from tianshou.trainer import offline_trainer | ||
from tianshou.utils import TensorboardLogger, WandbLogger | ||
from tianshou.utils.net.common import Net | ||
from tianshou.utils.net.continuous import Actor, Critic | ||
|
||
|
||
def get_args():
    """Build and parse the command-line interface of the TD3+BC D4RL example.

    Options fall into three groups: experiment / dataset / optimisation,
    TD3+BC hyper-parameters, and evaluation / logging / checkpointing.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Each entry is (flag, keyword arguments for ``add_argument``). The list
    # order is the registration order, which is also the ``--help`` order.
    spec = [
        # experiment, dataset and optimisation
        ("--task", dict(type=str, default="HalfCheetah-v2")),
        ("--seed", dict(type=int, default=0)),
        ("--expert-data-task", dict(type=str, default="halfcheetah-expert-v2")),
        ("--buffer-size", dict(type=int, default=1000000)),
        ("--hidden-sizes", dict(type=int, nargs="*", default=[256, 256])),
        ("--actor-lr", dict(type=float, default=3e-4)),
        ("--critic-lr", dict(type=float, default=3e-4)),
        ("--epoch", dict(type=int, default=200)),
        ("--step-per-epoch", dict(type=int, default=5000)),
        ("--n-step", dict(type=int, default=3)),
        ("--batch-size", dict(type=int, default=256)),
        # TD3+BC hyper-parameters
        ("--alpha", dict(type=float, default=2.5)),
        ("--exploration-noise", dict(type=float, default=0.1)),
        ("--policy-noise", dict(type=float, default=0.2)),
        ("--noise-clip", dict(type=float, default=0.5)),
        ("--update-actor-freq", dict(type=int, default=2)),
        ("--tau", dict(type=float, default=0.005)),
        ("--gamma", dict(type=float, default=0.99)),
        ("--norm-obs", dict(type=int, default=1)),
        # evaluation, logging and checkpointing
        ("--eval-freq", dict(type=int, default=1)),
        ("--test-num", dict(type=int, default=10)),
        ("--logdir", dict(type=str, default="log")),
        ("--render", dict(type=float, default=1 / 35)),
        ("--device", dict(type=str, default=default_device)),
        ("--resume-path", dict(type=str, default=None)),
        ("--resume-id", dict(type=str, default=None)),
        (
            "--logger",
            dict(
                type=str,
                default="tensorboard",
                choices=["tensorboard", "wandb"],
            ),
        ),
        ("--wandb-project", dict(type=str, default="offline_d4rl.benchmark")),
        (
            "--watch",
            dict(
                default=False,
                action="store_true",
                help="watch the play of pre-trained policy only",
            ),
        ),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in spec:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
||
|
||
def test_td3_bc():
    """Train TD3+BC on a D4RL dataset, or replay a saved policy with ``--watch``.

    Builds a deterministic actor and twin critics and wraps them in
    ``TD3BCPolicy``. Then it does one of two things:

    * default: runs ``offline_trainer`` on the D4RL buffer named by
      ``--expert-data-task``, saving the best policy to ``<log_path>/policy.pth``;
    * ``--watch``: loads a checkpoint and renders one episode.

    In both modes it finishes by evaluating the policy for ``--test-num``
    episodes on the vectorised test environments and printing the mean
    reward and episode length.

    When ``--norm-obs`` is set, observation statistics are computed from the
    offline dataset in BOTH modes. The policy is trained on normalised
    observations, so every environment it acts in must apply the same
    normalisation.
    """
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    if args.norm_obs:
        # Statistics are frozen (update_obs_rms=False); they are supplied from
        # the offline dataset further below via set_obs_rms.
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # actor network
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Actor(
        net_a,
        action_shape=args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # critic networks: Q(s, a), with state and action concatenated at the input
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    policy = TD3BCPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        policy_noise=args.policy_noise,
        update_actor_freq=args.update_actor_freq,
        noise_clip=args.noise_clip,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "td3_bc"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    # Dataset and observation-normalisation statistics. The buffer is needed
    # for training, and also in --watch mode when --norm-obs is set: the
    # policy was trained on normalised observations, so evaluating or
    # watching it on observations not normalised to the dataset statistics
    # would give it inputs from a different distribution than it saw in
    # training.
    replay_buffer = None
    obs_rms = None
    if not args.watch or args.norm_obs:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        if args.norm_obs:
            replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
            test_envs.set_obs_rms(obs_rms)

    def save_best_fn(policy):
        """Checkpoint ``policy`` to ``<log_path>/policy.pth``."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch():
        """Load a saved policy and render one episode with it."""
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")
        if not os.path.isfile(args.resume_path):
            # log_path embeds this run's timestamp, and nothing is saved in
            # watch mode, so the fallback path only exists by coincidence.
            raise FileNotFoundError(
                f"No policy checkpoint found at {args.resume_path}; "
                "pass --resume-path pointing to a previous run's policy.pth."
            )

        policy.load_state_dict(
            torch.load(args.resume_path, map_location=torch.device("cpu"))
        )
        policy.eval()
        watch_env = env
        if obs_rms is not None:
            # Apply the same frozen normalisation the policy was trained with.
            watch_env = VectorEnvNormObs(
                SubprocVectorEnv([lambda: gym.make(args.task)]),
                update_obs_rms=False,
            )
            watch_env.set_obs_rms(obs_rms)
        collector = Collector(policy, watch_env)
        collector.collect(n_episode=1, render=args.render)

    if not args.watch:
        # trainer
        result = offline_trainer(
            policy,
            replay_buffer,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.test_num,
            args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        )
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}")
|
||
|
||
if __name__ == "__main__":
    # Script entry point: train TD3+BC on D4RL, or replay a checkpoint when
    # --watch is given. Importing this module only defines functions.
    test_td3_bc()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.