
Hindsight Experience Replay as a replay buffer #753


Merged

merged 23 commits into from Oct 30, 2022
1 change: 1 addition & 0 deletions README.md
@@ -39,6 +39,7 @@
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)
- [Hindsight Experience Replay (HER)](https://arxiv.org/pdf/1707.01495.pdf)
Collaborator


add in docs/index.rst

Contributor Author

@Juno-T Juno-T Oct 5, 2022


updated: 9d52936


Here are Tianshou's other features:

25 changes: 25 additions & 0 deletions docs/api/tianshou.data.rst
@@ -30,6 +30,14 @@ PrioritizedReplayBuffer
:undoc-members:
:show-inheritance:

HERReplayBuffer
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: tianshou.data.HERReplayBuffer
:members:
:undoc-members:
:show-inheritance:

ReplayBufferManager
~~~~~~~~~~~~~~~~~~~

@@ -46,6 +54,15 @@ PrioritizedReplayBufferManager
:undoc-members:
:show-inheritance:


HERReplayBufferManager
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: tianshou.data.HERReplayBufferManager
:members:
:undoc-members:
:show-inheritance:

VectorReplayBuffer
~~~~~~~~~~~~~~~~~~

@@ -62,6 +79,14 @@ PrioritizedVectorReplayBuffer
:undoc-members:
:show-inheritance:

HERVectorReplayBuffer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: tianshou.data.HERVectorReplayBuffer
:members:
:undoc-members:
:show-inheritance:

CachedReplayBuffer
~~~~~~~~~~~~~~~~~~

1 change: 1 addition & 0 deletions docs/index.rst
@@ -40,6 +40,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
* :class:`~tianshou.data.HERReplayBuffer` `Hindsight Experience Replay <https://arxiv.org/pdf/1707.01495.pdf>`_

Here are Tianshou's other features:

1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -52,6 +52,7 @@ mujoco
jit
nstep
preprocess
preprocessing
repo
ReLU
namespace
13 changes: 13 additions & 0 deletions examples/mujoco/README.md
@@ -20,6 +20,7 @@ Supported algorithms are listed below:
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32)
- [Hindsight Experience Replay (HER)](https://arxiv.org/abs/1707.01495)

## EnvPool

@@ -304,6 +305,18 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
1. All shared hyperparameters are exactly the same as TRPO, considering how similar these two algorithms are.
2. We found different games in Mujoco may require quite different `actor-step-size`: Reacher/Swimmer are insensitive to step-size in range (0.1~1.0), while InvertedDoublePendulum / InvertedPendulum / Humanoid are quite sensitive to step size, and even 0.1 is too large. Other games may require `actor-step-size` in range (0.1~0.4), but aren't that sensitive in general.

## Others

### HER
| Environment | DDPG without HER | DDPG with HER |
| :--------------------: | :--------------: | :--------------: |
| FetchReach | -49.9±0.2 | **-17.6±21.7** |

#### Hints for HER
1. The HER technique was proposed for solving task-based environments, so it cannot be compared with the non-task-based mujoco benchmarks above. The environment used in this evaluation is ``FetchReach-v3``, which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics). A minimal buffer-construction sketch follows these hints.
2. Simple hyperparameter optimization was done for both settings, DDPG with and without HER. However, since *DDPG without HER* failed in every experiment, the best hyperparameters found for *DDPG with HER* are used in the evaluation of both settings.
3. The scores are the mean reward ± 1 standard deviation over 16 seeds. The minimum reward for ``FetchReach-v3`` is -50, which implies that *DDPG without HER* performs no better than a random policy. Although *DDPG with HER* achieves a better mean reward, its standard deviation is quite high because, in this setting, the agent either fails completely (-50 reward) or successfully learns the task (reward close to 0). In other words, the agent learned successfully in about 70% of the 16 seeds.
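
For reference, here is a minimal sketch of the HER buffer construction, distilled from `examples/mujoco/fetch_her_ddpg.py` (hyperparameter values are the script's defaults; the DDPG policy, collector, and trainer setup are omitted):

```python
import gym
import numpy as np

from tianshou.data import HERReplayBuffer
from tianshou.env import TruncatedAsTerminated

# FetchReach-v3 is provided by Gymnasium-Robotics (extra installation required).
env = TruncatedAsTerminated(gym.make("FetchReach-v3"))


def compute_reward_fn(achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
    # HER relabels goals when sampling, so it needs the environment's own
    # reward function to recompute rewards for the substituted goals.
    return env.compute_reward(achieved_goal, desired_goal, {})


buffer = HERReplayBuffer(
    100_000,  # --buffer-size
    compute_reward_fn=compute_reward_fn,
    horizon=50,  # --her-horizon; FetchReach episodes are 50 steps long
    future_k=8,  # --her-future-k; the k of the "future" goal-relabelling strategy
)
```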

## Note

<a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
Expand Down
228 changes: 228 additions & 0 deletions examples/mujoco/fetch_her_ddpg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
#!/usr/bin/env python3

import argparse
import datetime
import os
import pprint

import gym
import numpy as np
import torch
import wandb
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import (
Collector,
HERReplayBuffer,
HERVectorReplayBuffer,
ReplayBuffer,
VectorReplayBuffer,
)
from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated
from tianshou.exploration import GaussianNoise
from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net, get_dict_state_decorator
from tianshou.utils.net.continuous import Actor, Critic


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="FetchReach-v3")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
parser.add_argument("--actor-lr", type=float, default=1e-3)
parser.add_argument("--critic-lr", type=float, default=3e-3)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--exploration-noise", type=float, default=0.1)
parser.add_argument("--start-timesteps", type=int, default=25000)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--step-per-epoch", type=int, default=5000)
parser.add_argument("--step-per-collect", type=int, default=1)
parser.add_argument("--update-per-step", type=int, default=1)
parser.add_argument("--n-step", type=int, default=1)
parser.add_argument("--batch-size", type=int, default=512)
parser.add_argument(
"--replay-buffer", type=str, default="her", choices=["normal", "her"]
)
parser.add_argument("--her-horizon", type=int, default=50)
parser.add_argument("--her-future-k", type=int, default=8)
parser.add_argument("--training-num", type=int, default=1)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="HER-benchmark")
parser.add_argument(
"--watch",
default=False,
action="store_true",
help="watch the play of pre-trained policy only",
)
return parser.parse_args()


def make_fetch_env(task, training_num, test_num):
env = TruncatedAsTerminated(gym.make(task))
train_envs = ShmemVectorEnv(
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)]
)
test_envs = ShmemVectorEnv(
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)]
)
return env, train_envs, test_envs


def test_ddpg(args=get_args()):
# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "ddpg"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
logger.wandb_run.config.setdefaults(vars(args))
args = argparse.Namespace(**wandb.config)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

env, train_envs, test_envs = make_fetch_env(
args.task, args.training_num, args.test_num
)
args.state_shape = {
'observation': env.observation_space['observation'].shape,
'achieved_goal': env.observation_space['achieved_goal'].shape,
'desired_goal': env.observation_space['desired_goal'].shape,
}
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
args.exploration_noise = args.exploration_noise * args.max_action
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# model
dict_state_dec, flat_state_shape = get_dict_state_decorator(
state_shape=args.state_shape,
keys=['observation', 'achieved_goal', 'desired_goal']
)
net_a = dict_state_dec(Net)(
flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device
)
actor = dict_state_dec(Actor)(
net_a, args.action_shape, max_action=args.max_action, device=args.device
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c = dict_state_dec(Net)(
flat_state_shape,
action_shape=args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device)
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor,
actor_optim,
critic,
critic_optim,
tau=args.tau,
gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
estimation_step=args.n_step,
action_space=env.action_space,
)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
def compute_reward_fn(ag: np.ndarray, g: np.ndarray):
return env.compute_reward(ag, g, {})

if args.replay_buffer == "normal":
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
else:
if args.training_num > 1:
buffer = HERVectorReplayBuffer(
args.buffer_size,
len(train_envs),
compute_reward_fn=compute_reward_fn,
horizon=args.her_horizon,
future_k=args.her_future_k,
)
else:
buffer = HERReplayBuffer(
args.buffer_size,
compute_reward_fn=compute_reward_fn,
horizon=args.her_horizon,
future_k=args.her_future_k,
)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.start_timesteps, random=True)

def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

if not args.watch:
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
)
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == "__main__":
test_ddpg()
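
One non-obvious piece of the script is `get_dict_state_decorator`, which lets the plain `Net`/`Actor`/`Critic` modules consume the dict observations returned by goal-based environments. The snippet below is only a rough, hypothetical illustration of the underlying idea (it is not tianshou's implementation; `flatten_dict_obs` and the shapes are invented for the example): the three observation fields are flattened and concatenated so that an MLP sees a single vector.

```python
import numpy as np

# Illustrative shapes roughly matching FetchReach: a 10-dim observation vector
# plus 3-dim achieved and desired goals.
state_shape = {"observation": (10,), "achieved_goal": (3,), "desired_goal": (3,)}
keys = ("observation", "achieved_goal", "desired_goal")

# Flat input dimension for the MLP: 10 + 3 + 3 = 16.
flat_state_dim = sum(int(np.prod(state_shape[k])) for k in keys)


def flatten_dict_obs(obs: dict) -> np.ndarray:
    """Concatenate the dict observation's fields into one flat vector."""
    return np.concatenate(
        [np.asarray(obs[k], dtype=np.float32).reshape(-1) for k in keys]
    )


obs = {k: np.zeros(shape, dtype=np.float32) for k, shape in state_shape.items()}
assert flatten_dict_obs(obs).shape == (flat_state_dim,)
```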
45 changes: 45 additions & 0 deletions test/base/env.py
@@ -166,3 +166,48 @@ def step(self, action):
for i in range(self.size):
self.graph.nodes[i]["data"] = next_graph_state[i]
return self._encode_obs(), 1.0, 0, 0, {}


class MyGoalEnv(MyTestEnv):

def __init__(self, *args, **kwargs):
assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, \
"dict_state / recurse_state not supported"
super().__init__(*args, **kwargs)
obs, _ = super().reset(state=0)
obs, _, _, _, _ = super().step(1)
self._goal = obs * self.size
super_obsv = self.observation_space
self.observation_space = gym.spaces.Dict(
{
'observation': super_obsv,
'achieved_goal': super_obsv,
'desired_goal': super_obsv,
}
)

def reset(self, *args, **kwargs):
obs, info = super().reset(*args, **kwargs)
new_obs = {
'observation': obs,
'achieved_goal': obs,
'desired_goal': self._goal
}
return new_obs, info

def step(self, *args, **kwargs):
obs_next, rew, terminated, truncated, info = super().step(*args, **kwargs)
new_obs_next = {
'observation': obs_next,
'achieved_goal': obs_next,
'desired_goal': self._goal
}
return new_obs_next, rew, terminated, truncated, info

def compute_reward_fn(
self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: dict
) -> np.ndarray:
axis = -1
if self.array_state:
axis = (-3, -2, -1)
return (achieved_goal == desired_goal).all(axis=axis)
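
A hypothetical usage sketch (not part of this PR's test suite): MyGoalEnv can be wired to the new HERReplayBuffer in the same way `fetch_her_ddpg.py` wires up FetchReach. The `size` keyword below assumes MyTestEnv's first constructor argument sets the episode length.

```python
from tianshou.data import HERReplayBuffer

env = MyGoalEnv(size=5)  # assumption: MyTestEnv(size, ...) sets the episode length


def compute_reward_fn(achieved_goal, desired_goal):
    # HER recomputes rewards for relabelled goals through this callback.
    return env.compute_reward_fn(achieved_goal, desired_goal, {})


buf = HERReplayBuffer(
    20,
    compute_reward_fn=compute_reward_fn,
    horizon=5,
    future_k=2,
)

obs, info = env.reset()
assert set(obs) == {"observation", "achieved_goal", "desired_goal"}
```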