[WIP] Hindsight Experience Replay Transform #1819
@@ -46,6 +46,7 @@
)
from torchrl.envs.utils import _sort_keys, _update_during_reset, step_mdp
from torchrl.objectives.value.functional import reward2go
from torchrl.objectives.value.utils import _get_num_per_traj, _split_and_pad_sequence

try:
    from torchvision.transforms.functional import center_crop
@@ -5600,7 +5601,6 @@ def reset_key(self, value):
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        _reset = _get_reset(self.reset_key, tensordict)
        for in_key in self.in_keys:
            buffer_name = self._buffer_name(in_key)
@@ -6686,7 +6686,6 @@ def _step(
        raise RuntimeError("BurnInTransform can only be appended to a ReplayBuffer.")

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        if self.burn_in == 0:
            return tensordict
@@ -6796,3 +6795,211 @@ def _reset(
        with _set_missing_tolerance(self, True):
            tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

class HERTransform(Transform):
    """Hindsight Experience Replay (HER) transform.

    This transform is used in reinforcement learning algorithms that employ
    Hindsight Experience Replay (HER). HER is a technique that allows an agent
    to learn from failed experiences by replaying them with different goals.

    Args:
        samples (Optional[Union[int, torch.Tensor]]): The number of augmented samples
            to generate for each original sample. Defaults to 4.
        generation_type (str): The type of goal generation to use. Can be one of
            "future", "random", or "final". Defaults to "future".
        achieved_goal_key (Optional[NestedKey]): The key to access the achieved goal
            in the input tensordict. Defaults to "achieved_goal".
        desired_goal_key (Optional[NestedKey]): The key to access the desired goal
            in the input tensordict. Defaults to "desired_goal".
        reward_key (Optional[NestedKey]): The key to access the reward in the output
            tensordict. Defaults to "reward".
        reward_function (Optional[callable]): The reward function to use for calculating
            the rewards of augmented samples. Defaults to None, in which case
            `distance_reward_function` is used.

    Attributes:
        ENV_ERR (str): The error message raised when the transform is applied to
            a collector or an environment.

    """

    ENV_ERR = (
        "The HERTransform is only an inverse transform and can "
        "only be applied to the replay buffer and not to the collector or the environment."
    )

    def __init__(
        self,
        samples: Optional[Union[int, torch.Tensor]] = 4,
        generation_type: str = "future",
        achieved_goal_key: Optional[NestedKey] = "achieved_goal",
        desired_goal_key: Optional[NestedKey] = "desired_goal",
        reward_key: Optional[NestedKey] = "reward",
        reward_function: Optional[callable] = None,
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.achieved_goal_key = achieved_goal_key
        self.desired_goal_key = desired_goal_key
        self.reward_key = reward_key
        self.generation_type = generation_type

        if reward_function is None:
            self.reward_function = distance_reward_function
        else:
            self.reward_function = reward_function

        if not isinstance(samples, torch.Tensor):
            samples = torch.tensor(samples)

        self.register_buffer("samples", samples)

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        augmentation_td = self.her_augmentation(tensordict)
        return torch.cat([tensordict, augmentation_td], dim=0)

    def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor:
        return self.her_augmentation(tensordict)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return tensordict

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        raise ValueError(self.ENV_ERR)

    def her_augmentation(self, sampled_td: TensorDictBase):
        if len(sampled_td.shape) == 1:
            sampled_td = sampled_td.unsqueeze(0)
        b, t = sampled_td.shape
        trajectories = _get_num_per_traj(sampled_td.get("terminated"))
        splitted_td = _split_and_pad_sequence(sampled_td, trajectories)
        splitted_achieved_goals = splitted_td.get(self.achieved_goal_key)

        # get indices for each trajectory
        idxs = self.generate_sample_idxs(trajectories)

        # create new goals based on idxs
        new_goals = []
        for i, ids in enumerate(idxs):
            new_goals.append(splitted_achieved_goals[i][ids])

        # calculate rewards given new desired goals and old achieved goals
        vmap_rewards = torch.vmap(distance_reward_function)

Review comment: I think you wanna call `self.reward_function` here.

        rewards = []
        for ach, des in zip(splitted_achieved_goals, new_goals):
            rewards.append(vmap_rewards(ach[: des.shape[0], :], des))

        cat_rewards = torch.cat(rewards).reshape(b, t, self.samples, -1).squeeze(-1)
        cat_new_goals = torch.cat(new_goals).reshape(b, t, self.samples, -1)

        augmentation_td = TensorDict(
            {
                "observation": sampled_td.get("observation").repeat_interleave(
                    self.samples, dim=0
                ),
                "action": sampled_td.get("action").repeat_interleave(
                    self.samples, dim=0
                ),
                "terminated": sampled_td.get("terminated").repeat_interleave(
                    self.samples, dim=0
                ),
                "truncated": sampled_td.get("truncated").repeat_interleave(
                    self.samples, dim=0
                ),
                self.achieved_goal_key: sampled_td.get(
                    self.achieved_goal_key
                ).repeat_interleave(self.samples, dim=0),
            },
            batch_size=(b * self.samples, t),
        )

Review comment on the hard-coded keys: If we keep it a transform we probably need to specify all those tensordict keys ... Not sure what a better alternative would be. Any idea?

        augmentation_td.set(self.reward_key, cat_rewards.transpose(1, 2).flatten(0, 1))
        augmentation_td.set(
            self.desired_goal_key, cat_new_goals.transpose(1, 2).flatten(0, 1)
        )

        return augmentation_td

    def generate_future_idxs(self, traj_lens):
        def generate_for_single_traj_len(traj_len):
            idxs = []
            for i in range(traj_len - 1):
                idxs.append(
                    torch.randint(low=i + 1, high=traj_len, size=(1, self.samples))
                )
            # the last step has no future steps, so reuse the final index
            idxs.append(torch.full((1, self.samples), fill_value=traj_len - 1))
            return torch.cat(idxs)

        return [generate_for_single_traj_len(traj_len) for traj_len in traj_lens]

    def generate_random_idxs(self, traj_lens):
        def generate_for_single_traj_len(traj_len):
            idxs = []
            for _ in range(traj_len):
                idxs.append(torch.randint(low=0, high=traj_len, size=(1, self.samples)))
            return torch.cat(idxs)

        return [generate_for_single_traj_len(traj_len) for traj_len in traj_lens]

    def generate_final_idx(self, traj_lens):
        def generate_for_single_traj_len(traj_len):
            return torch.full((traj_len, self.samples), fill_value=traj_len - 1)

        return [generate_for_single_traj_len(traj_len) for traj_len in traj_lens]

    def generate_sample_idxs(self, trajectories):
        if self.generation_type == "future":
            idxs = self.generate_future_idxs(trajectories)
        elif self.generation_type == "random":
            idxs = self.generate_random_idxs(trajectories)
        elif self.generation_type == "final":
            idxs = self.generate_final_idx(trajectories)
        else:
            raise ValueError(f"Invalid generation type: {self.generation_type}")
        return idxs

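For reference, a small hypothetical illustration of what the three index-generation strategies produce for a single trajectory; the trajectory length and sample count below are made up for the example and are not part of this PR:

# Hypothetical sanity check on the index generators (traj_len=4, samples=2).
t = HERTransform(samples=2)
future = t.generate_future_idxs([4])[0]   # row i holds indices drawn from (i, 4); the last row is [3, 3]
random_ = t.generate_random_idxs([4])[0]  # row i holds indices drawn from [0, 4)
final = t.generate_final_idx([4])[0]      # every row is [3, 3]
# each result has shape (traj_len, samples) == (4, 2)
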
def distance_torch(a, b):
    """Calculate the Euclidean distance between two tensors.

    Args:
        a (torch.Tensor): The first tensor.
        b (torch.Tensor): The second tensor.

    Returns:
        torch.Tensor: The Euclidean distance between the two tensors.
    """
    return torch.linalg.vector_norm(a - b, dim=-1)

def distance_reward_function(
    achieved_goal: torch.Tensor,
    desired_goal: torch.Tensor,
    threshold: float = 0.05,
    reward_type: str = "sparse",
) -> torch.Tensor:
    """Calculate the distance-based reward for a given achieved goal and desired goal.

    Args:
        achieved_goal (torch.Tensor): The achieved goal.
        desired_goal (torch.Tensor): The desired goal.
        threshold (float, optional): The threshold value for determining success. Defaults to 0.05.
        reward_type (str, optional): The type of reward to use. Can be "sparse" or "dense". Defaults to "sparse".

    Returns:
        torch.Tensor: The distance-based reward.

    """
    d = distance_torch(achieved_goal, desired_goal)
    if reward_type == "sparse":
        return -(d > threshold).float()
    else:
        return -d.float()
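
To make the intended write-time usage concrete, here is a minimal sketch of attaching the transform as an inverse transform on a replay buffer, mirroring how Reward2GoTransform is used in torchrl. The storage size, batch shapes, and key contents are assumptions for illustration, not part of this PR:

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

her = HERTransform(samples=4, generation_type="future")
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), transform=her)

# A toy batch of two trajectories of length 5 (all shapes are made up).
data = TensorDict(
    {
        "observation": torch.randn(2, 5, 8),
        "action": torch.randn(2, 5, 2),
        "achieved_goal": torch.randn(2, 5, 3),
        "desired_goal": torch.randn(2, 5, 3),
        "reward": torch.zeros(2, 5),
        "terminated": torch.zeros(2, 5, dtype=torch.bool),
        "truncated": torch.zeros(2, 5, dtype=torch.bool),
    },
    batch_size=[2, 5],
)
data["terminated"][:, -1] = True  # mark the end of each trajectory

# On extend, _inv_call concatenates the original batch with the
# HER-augmented one before anything is written to storage.
rb.extend(data)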
Review comment: As explained above, it doesn't feel like a transform, as we create a new tensordict and have to combine original and augmented data before adding them to the replay buffer. I think ideally the "augmentations" would be done directly after collection: either as a postproc for collectors or, as in the example here, as an inverse transform for the replay buffer.

Reply: Why not a transform? I think it's pretty neat to use a transform. Who said a transform had to change things in-place? Our API to modify samples at writing time is to use either a transform or a different writer. If you think this can be achieved with a writer I'm on board. But I don't think there's anything wrong with the transform. An advantage of using a writer instead is that it feels more natural (transforms can be used with envs unless specified otherwise, writers are dedicated to RBs).
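
To make the writer alternative from the last comment concrete, here is a hedged sketch: a hypothetical HERWriter that relabels incoming batches before delegating to torchrl's RoundRobinWriter. The class, its constructor argument, and the choice of extend as the override point are assumptions, not an API proposed in this PR:

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.writers import RoundRobinWriter

class HERWriter(RoundRobinWriter):
    """Hypothetical writer performing HER relabeling at write time."""

    def __init__(self, her_transform, **kwargs):
        super().__init__(**kwargs)
        self.her_transform = her_transform

    def extend(self, data):
        # Relabel first, then write originals and augmented samples together.
        augmented = self.her_transform.her_augmentation(data)
        return super().extend(torch.cat([data, augmented], dim=0))

# Usage sketch: the writer replaces the inverse transform on the buffer.
rb = ReplayBuffer(storage=LazyTensorStorage(1000), writer=HERWriter(HERTransform()))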