Hindsight Experience Replay as a replay buffer #753
Merged
23 commits
- 3d596a7 init her replaybuffer (Juno-T)
- c0d938e debug test her (Juno-T)
- b4d4182 add goal env wrapper, test (Juno-T)
- f0480a3 update HER to use dict obs, and test (Juno-T)
- da2f8b3 debug typehint, add example (Juno-T)
- 7fc49b3 Update README.md (Juno-T)
- 4dcb98d add docstring (Juno-T)
- 274b6e6 Merge branch 'master' into master (Trinkle23897)
- 1071dab correct wrapper test (Juno-T)
- 5aba126 Merge branch 'master' into master (Trinkle23897)
- 63e65c0 update from feedback (Juno-T)
- 9d52936 add doc (Juno-T)
- 6ee7d08 fix indices calculation, add test (Juno-T)
- 7802451 update doc, uncomment tests (Juno-T)
- 808760a reorganize (Juno-T)
- 3a98932 debug her (Juno-T)
- 4a88d76 refactor example (Juno-T)
- b08fac7 format (Juno-T)
- 1798b90 add HER section (Juno-T)
- 72b4ab4 Merge pull request #1 from thu-ml/master (Juno-T)
- 9082db7 make linter happy (nuance1979)
- 20ba3ab make mypy happy (nuance1979)
- aaf0f73 add to word list (nuance1979)
```text
@@ -52,6 +52,7 @@ mujoco
jit
nstep
preprocess
preprocessing
repo
ReLU
namespace
```
@@ -0,0 +1,228 @@

```python
#!/usr/bin/env python3

import argparse
import datetime
import os
import pprint

import gym
import numpy as np
import torch
import wandb
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import (
    Collector,
    HERReplayBuffer,
    HERVectorReplayBuffer,
    ReplayBuffer,
    VectorReplayBuffer,
)
from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated
from tianshou.exploration import GaussianNoise
from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net, get_dict_state_decorator
from tianshou.utils.net.continuous import Actor, Critic
```
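The imports include `get_dict_state_decorator`. The script later uses it so that `Net`, `Actor` and `Critic` can accept the dict observation (`observation`, `achieved_goal`, `desired_goal`) returned by the goal-based Fetch env. As a rough mental model only, such a decorator must turn the dict into one flat vector with a fixed key order. The NumPy sketch below shows that flattening step. `flatten_dict_obs` and `flat_dim` are hypothetical helpers, not tianshou's implementation, and they assume every per-key shape is 1-D.

```python
from typing import Dict, Sequence

import numpy as np


def flatten_dict_obs(obs: Dict[str, np.ndarray], keys: Sequence[str]) -> np.ndarray:
    """Concatenate the arrays stored under `keys` along the last axis.

    Hypothetical helper illustrating how a goal-based dict observation can be
    fed to a flat-input network; this is not tianshou's actual code.
    """
    # A fixed key order keeps the flat layout identical across calls.
    parts = [np.asarray(obs[k], dtype=np.float32) for k in keys]
    # Works for a single observation (d_k,) or a batch (B, d_k).
    return np.concatenate(parts, axis=-1)


def flat_dim(state_shape: Dict[str, tuple], keys: Sequence[str]) -> int:
    # Flat input size is the sum of the per-key sizes (1-D shapes assumed).
    return int(sum(np.prod(state_shape[k]) for k in keys))
```

Concatenating in the same key order as the `keys` argument is the important design point: the network only ever sees positions, so the order must never change between training and evaluation.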
```python
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="FetchReach-v3")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--buffer-size", type=int, default=100000)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    parser.add_argument("--actor-lr", type=float, default=1e-3)
    parser.add_argument("--critic-lr", type=float, default=3e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--exploration-noise", type=float, default=0.1)
    parser.add_argument("--start-timesteps", type=int, default=25000)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--step-per-epoch", type=int, default=5000)
    parser.add_argument("--step-per-collect", type=int, default=1)
    parser.add_argument("--update-per-step", type=int, default=1)
    parser.add_argument("--n-step", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument(
        "--replay-buffer", type=str, default="her", choices=["normal", "her"]
    )
    parser.add_argument("--her-horizon", type=int, default=50)
    parser.add_argument("--her-future-k", type=int, default=8)
    parser.add_argument("--training-num", type=int, default=1)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="HER-benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser.parse_args()
```
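`--her-horizon` and `--her-future-k` configure the HER buffer. The name `future_k` presumably refers to the "future" relabeling strategy from the HER paper. Under that strategy, each sampled transition keeps its original goal with probability 1/(k+1). Otherwise its goal is replaced by the goal actually achieved at a uniformly chosen later step of the same episode. The sketch below illustrates this for a single stored episode. The function name, the array layout (`achieved_next[i]` is the goal achieved right after transition `i`) and the k/(k+1) relabel probability are assumptions for illustration. tianshou's `HERReplayBuffer` may interpret `future_k` differently and tracks episode boundaries itself via `horizon`.

```python
import numpy as np


def sample_future_goals(
    achieved_next: np.ndarray,
    desired: np.ndarray,
    t: np.ndarray,
    future_k: int,
    rng: np.random.Generator,
):
    """Relabel goals for transitions `t` of ONE episode ("future" strategy).

    achieved_next: (T, g) array, goal achieved right after transition i.
    desired:       (T, g) array, original desired goal of transition i.
    t:             int array of sampled transition indices in [0, T).
    Returns the (possibly relabeled) goals, the relabel mask and chosen steps.
    """
    T = achieved_next.shape[0]
    n = t.shape[0]
    # Assumed convention: k relabeled samples per original one -> p = k / (k + 1).
    future_p = future_k / (future_k + 1.0)
    relabel = rng.random(n) < future_p
    # Uniform future step t' in [t, T - 1] for every sampled transition.
    future_t = t + (rng.random(n) * (T - t)).astype(int)
    goals = desired[t].copy()
    goals[relabel] = achieved_next[future_t[relabel]]
    return goals, relabel, future_t
```

Under this assumed convention, `future_k=8` (the script's default) would relabel roughly 8 of every 9 sampled transitions. Rewards for the relabeled transitions then have to be recomputed against the new goals.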
```python
def make_fetch_env(task, training_num, test_num):
    env = TruncatedAsTerminated(gym.make(task))
    train_envs = ShmemVectorEnv(
        [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)]
    )
    test_envs = ShmemVectorEnv(
        [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)]
    )
    return env, train_envs, test_envs
```
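`make_fetch_env` wraps every env in `TruncatedAsTerminated`. Judging by its name, the wrapper reports a time-limit truncation as a termination. This matters for Fetch tasks, which are typically bounded only by a time limit. Below is a minimal duck-typed sketch of that behavior, assuming the gym >= 0.26 five-tuple `step` API. `TruncatedAsTerminatedSketch` is a hypothetical stand-in so the example runs without gym; a real wrapper would subclass `gym.Wrapper` instead.

```python
from typing import Any, Dict, Tuple


class TruncatedAsTerminatedSketch:
    """Hypothetical stand-in: folds `truncated` into `terminated` on every step.

    Duck-typed so it runs without gym; a real wrapper would subclass gym.Wrapper.
    """

    def __init__(self, env: Any) -> None:
        self.env = env

    def reset(self, **kwargs: Any) -> Tuple[Any, Dict[str, Any]]:
        return self.env.reset(**kwargs)

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
        obs, rew, terminated, truncated, info = self.env.step(action)
        # A time-limit cut now also counts as a terminal transition.
        return obs, rew, terminated or truncated, truncated, info

    def __getattr__(self, name: str) -> Any:
        # Forward everything else (e.g. compute_reward) to the wrapped env.
        return getattr(self.env, name)
```

Attribute forwarding matters here: the script later calls `env.compute_reward` on the wrapped env, so the wrapper must not hide the inner env's methods.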
```python
def test_ddpg(args=get_args()):
    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "ddpg"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
        logger.wandb_run.config.setdefaults(vars(args))
        args = argparse.Namespace(**wandb.config)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    env, train_envs, test_envs = make_fetch_env(
        args.task, args.training_num, args.test_num
    )
    args.state_shape = {
        'observation': env.observation_space['observation'].shape,
        'achieved_goal': env.observation_space['achieved_goal'].shape,
        'desired_goal': env.observation_space['desired_goal'].shape,
    }
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    args.exploration_noise = args.exploration_noise * args.max_action
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # model
    dict_state_dec, flat_state_shape = get_dict_state_decorator(
        state_shape=args.state_shape,
        keys=['observation', 'achieved_goal', 'desired_goal']
    )
    net_a = dict_state_dec(Net)(
        flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device
    )
    actor = dict_state_dec(Actor)(
        net_a, args.action_shape, max_action=args.max_action, device=args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    net_c = dict_state_dec(Net)(
        flat_state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
    policy = DDPGPolicy(
        actor,
        actor_optim,
        critic,
        critic_optim,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        estimation_step=args.n_step,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    def compute_reward_fn(ag: np.ndarray, g: np.ndarray):
        return env.compute_reward(ag, g, {})

    if args.replay_buffer == "normal":
        if args.training_num > 1:
            buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
        else:
            buffer = ReplayBuffer(args.buffer_size)
    else:
        if args.training_num > 1:
            buffer = HERVectorReplayBuffer(
                args.buffer_size,
                len(train_envs),
                compute_reward_fn=compute_reward_fn,
                horizon=args.her_horizon,
                future_k=args.her_future_k,
            )
        else:
            buffer = HERReplayBuffer(
                args.buffer_size,
                compute_reward_fn=compute_reward_fn,
                horizon=args.her_horizon,
                future_k=args.her_future_k,
            )
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.start_timesteps, random=True)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    if not args.watch:
        # trainer
        result = offpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.step_per_collect,
            args.test_num,
            args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
            update_per_step=args.update_per_step,
            test_in_train=False,
        )
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == "__main__":
    test_ddpg()
```
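Inside `test_ddpg`, `compute_reward_fn(ag, g)` simply forwards to `env.compute_reward(ag, g, {})` and is handed to `HERReplayBuffer` / `HERVectorReplayBuffer`. Presumably this lets the buffer recompute rewards after swapping in relabeled goals. Because that recomputation would run on whole sampled batches, the function should be vectorized. The sketch below mimics a Fetch-style sparse reward: 0 within `distance_threshold` of the goal, -1 otherwise. The 0.05 threshold and the function names are assumptions; the wrapped env's own `compute_reward` remains the authoritative definition.

```python
import numpy as np


def sparse_goal_reward(
    achieved_goal: np.ndarray, goal: np.ndarray, distance_threshold: float = 0.05
) -> np.ndarray:
    """Fetch-style sparse reward (assumed): 0 if close enough to the goal, else -1.

    Accepts single goals of shape (g,) or batches of shape (B, g).
    """
    d = np.linalg.norm(achieved_goal - goal, axis=-1)
    return -(d > distance_threshold).astype(np.float32)


def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray:
    # Same (ag, g) -> reward signature the script passes to the HER buffers.
    return sparse_goal_reward(ag, g)
```

Relabeling a transition's goal with the goal it actually achieved makes `compute_reward_fn(ag, ag)` return 0 (success) for every row. This is why hindsight relabeling yields informative rewards even when the original goals are rarely reached.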
add in docs/index.rst
updated: 9d52936