211 changes: 209 additions & 2 deletions torchrl/envs/transforms/transforms.py
@@ -46,6 +46,7 @@
)
from torchrl.envs.utils import _sort_keys, _update_during_reset, step_mdp
from torchrl.objectives.value.functional import reward2go
from torchrl.objectives.value.utils import _get_num_per_traj, _split_and_pad_sequence

try:
from torchvision.transforms.functional import center_crop
@@ -5600,7 +5601,6 @@ def reset_key(self, value):
def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:

_reset = _get_reset(self.reset_key, tensordict)
for in_key in self.in_keys:
buffer_name = self._buffer_name(in_key)
@@ -6686,7 +6686,6 @@ def _step(
raise RuntimeError("BurnInTransform can only be appended to a ReplayBuffer.")

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

if self.burn_in == 0:
return tensordict

@@ -6796,3 +6795,211 @@ def _reset(
with _set_missing_tolerance(self, True):
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset


class HERTransform(Transform):
"""Hindsight Experience Replay (HER) transform.

This transform is used in reinforcement learning algorithms that employ
Hindsight Experience Replay (HER). HER allows an agent to learn from failed
experiences by replaying them with different goals, typically goals that
were actually achieved later in the same trajectory.

Args:
samples (Optional[Union[int, torch.Tensor]]): The number of augmented samples
to generate for each original sample. Defaults to 4.
generation_type (str): The type of goal generation to use. Can be one of
"future", "random", or "final". Defaults to "future".
achieved_goal_key (Optional[NestedKey]): The key to access the achieved goal
in the input tensor dictionary. Defaults to "achieved_goal".
desired_goal_key (Optional[NestedKey]): The key to access the desired goal
in the input tensor dictionary. Defaults to "desired_goal".
reward_key (Optional[NestedKey]): The key to access the reward in the output
tensor dictionary. Defaults to "reward".
reward_function (Optional[callable]): The reward function to use for calculating
the rewards of augmented samples. Defaults to None, in which case the
`distance_reward_function` is used.

Attributes:
ENV_ERR (str): The error message to raise when the transform is applied to
the collector or the environment.

"""

ENV_ERR = (
"The Reward2GoTransform is only an inverse transform and can "
"only be applied to the replay buffer and not to the collector or the environment."
)

def __init__(
self,
samples: Optional[Union[int, torch.Tensor]] = 4,
generation_type: str = "future",
achieved_goal_key: Optional[NestedKey] = "achieved_goal",
desired_goal_key: Optional[NestedKey] = "desired_goal",
reward_key: Optional[NestedKey] = "reward",
reward_function: Optional[callable] = None,
):
super().__init__(
in_keys=None,
in_keys_inv=None,
out_keys_inv=None,
)
self.achieved_goal_key = achieved_goal_key
self.desired_goal_key = desired_goal_key
self.reward_key = reward_key
self.generation_type = generation_type

if reward_function is None:
self.reward_function = distance_reward_function
else:
self.reward_function = reward_function

if not isinstance(samples, torch.Tensor):
samples = torch.tensor(samples)

self.register_buffer("samples", samples)

def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
augmentation_td = self.her_augmentation(tensordict)
return torch.cat([tensordict, augmentation_td], dim=0)
Contributor Author:
As explained above, it doesn't feel like a transform, since we create a new tensordict and have to combine the original and augmented data before adding them to the replay buffer. I think ideally the "augmentations" would be done directly after collection, either as a postproc for collectors or, as in the example here, as an inverse transform for the replay buffer.

Collaborator:
Why not a transform? I think it's pretty neat to use a transform. Who said a transform had to change things in-place?
Our API to modify samples at writing time is to use either a transform or a different writer. If you think this can be achieved with a writer I'm on board. But I don't think there's anything wrong with the transform.

An advantage of using a writer instead is that it feels more natural (transforms can be used with envs unless specified otherwise, writers are dedicated to RBs).
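
For reference, a minimal sketch of the write-time usage being discussed (illustrative only, not part of this PR): the transform is handed to the replay buffer so that its inverse call runs when data is extended; collected_trajectories is a hypothetical [B, T] rollout TensorDict with achieved/desired goal keys.

from torchrl.data import LazyTensorStorage, ReplayBuffer

t = HERTransform(samples=4, generation_type="future")
rb = ReplayBuffer(storage=LazyTensorStorage(100_000), transform=t)
# the transform's inverse call augments the batch as it is written to the buffer
rb.extend(collected_trajectories)  # hypothetical rollout data, not defined here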


def _inv_apply_transform(self, tensordict: TensorDictBase) -> TensorDictBase:
return self.her_augmentation(tensordict)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return tensordict

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
raise ValueError(self.ENV_ERR)

def her_augmentation(self, sampled_td: TensorDictBase):
if len(sampled_td.shape) == 1:
sampled_td = sampled_td.unsqueeze(0)
b, t = sampled_td.shape
trajectories = _get_num_per_traj(sampled_td.get("terminated"))
splitted_td = _split_and_pad_sequence(sampled_td, trajectories)
splitted_achieved_goals = splitted_td.get(self.achieved_goal_key)

# get indices for each trajectory
idxs = self.generate_sample_idxs(trajectories)

# create new goals based on the sampled idxs
new_goals = []
for i, ids in enumerate(idxs):
new_goals.append(splitted_achieved_goals[i][ids])

# calculate rewards given new desired goals and old achieved goals
vmap_rewards = torch.vmap(distance_reward_function)
@dtsaras (Apr 12, 2024):
I think you want to call self.reward_function instead of distance_reward_function. Also, maybe the reward_function should be a TensorDictModule so that it can be more easily customized for a given environment. There is torchrl.modules.VmapModule for wrapping TensorDictModules with vmap.
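
A rough sketch of that suggestion (illustrative, not part of this PR; the exact VmapModule constructor arguments should be double-checked):

from tensordict.nn import TensorDictModule
from torchrl.modules import VmapModule

# wrap the plain reward function so it reads and writes tensordict keys
reward_module = TensorDictModule(
    distance_reward_function,
    in_keys=["achieved_goal", "desired_goal"],
    out_keys=["reward"],
)
# vmap over the leading (trajectory) dimension instead of calling torch.vmap by hand
vmapped_reward = VmapModule(reward_module, 0)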

rewards = []
for ach, des in zip(splitted_achieved_goals, new_goals):
rewards.append(vmap_rewards(ach[: des.shape[0], :], des))

cat_rewards = torch.cat(rewards).reshape(b, t, self.samples, -1).squeeze(-1)
cat_new_goals = torch.cat(new_goals).reshape(b, t, self.samples, -1)

augmentation_td = TensorDict(
Review comment:
I think the augmentation_td should still keep other metadata related to the state, rather than selecting only the keys: observation, action, terminated, truncated, ...

{
"observation": sampled_td.get("observation").repeat_interleave(
Contributor Author:
If we keep it as a transform we probably need to specify all those tensordict keys ... Not sure what a better alternative would be. Any ideas?

self.samples, dim=0
),
"action": sampled_td.get("action").repeat_interleave(
self.samples, dim=0
),
"terminated": sampled_td.get("terminated").repeat_interleave(
self.samples, dim=0
),
"truncated": sampled_td.get("truncated").repeat_interleave(
self.samples, dim=0
),
self.achieved_goal_key: sampled_td.get(
self.achieved_goal_key
).repeat_interleave(self.samples, dim=0),
},
batch_size=(b * self.samples, t),
)

augmentation_td.set(self.reward_key, cat_rewards.transpose(1, 2).flatten(0, 1))
augmentation_td.set(
self.desired_goal_key, cat_new_goals.transpose(1, 2).flatten(0, 1)
)

return augmentation_td

def generate_future_idxs(self, traj_lens):
def generate_for_single_traj_len(traj_len):
idxs = []
for i in range(traj_len - 1):
idxs.append(
torch.randint(low=i + 1, high=traj_len, size=(1, self.samples))
)
# the last step has no future index to sample from; fall back to the final index
idxs.append(torch.full((1, self.samples), fill_value=traj_len - 1))
return torch.cat(idxs)

return [generate_for_single_traj_len(traj_len) for traj_len in traj_lens]

def generate_random_idxs(self, traj_lens):
def generate_for_single_traj_len(traj_len):
idxs = []
for _ in range(traj_len):
idxs.append(torch.randint(low=0, high=traj_len, size=(1, self.samples)))
return torch.cat(idxs)

return [generate_for_single_traj_len(traj_len) for traj_len in traj_lens]

def generate_final_idx(self, traj_lens):
def generate_for_single_traj_len(traj_len):
return torch.full((traj_len, self.samples), fill_value=traj_len - 1)

return [generate_for_single_traj_len(traj_len) for traj_len in traj_lens]

def generate_sample_idxs(self, trajectories):
if self.generation_type == "future":
idxs = self.generate_future_idxs(trajectories)

elif self.generation_type == "random":
idxs = self.generate_random_idxs(trajectories)

elif self.generation_type == "final":
idxs = self.generate_final_idx(trajectories)
else:
raise ValueError("Invalid generation type")
return idxs
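
To make the generation strategies concrete, here is a small self-contained sketch (illustrative, not part of this PR) of what the "future" strategy produces for a single trajectory of length 4 with samples=2: each step draws relabelling indices from strictly later steps, and the last step falls back to the final index.

import torch

traj_len, samples = 4, 2
rows = [torch.randint(i + 1, traj_len, (1, samples)) for i in range(traj_len - 1)]
rows.append(torch.full((1, samples), traj_len - 1))
future_idxs = torch.cat(rows)  # shape (4, 2), e.g. tensor([[1, 3], [2, 3], [3, 3], [3, 3]])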


def distance_torch(a, b):
"""Calculate the Euclidean distance between two tensors.

Args:
a (torch.Tensor): The first tensor.
b (torch.Tensor): The second tensor.

Returns:
torch.Tensor: The Euclidean distance between the two tensors.
"""
return torch.linalg.vector_norm(a - b, dim=-1)


def distance_reward_function(
achieved_goal: torch.Tensor,
desired_goal: torch.Tensor,
threshold: float = 0.05,
reward_type: str = "sparse",
) -> torch.Tensor:
"""Calculates the distance-based reward for a given achieved goal and desired goal.

Args:
achieved_goal (torch.Tensor): The achieved goal.
desired_goal (torch.Tensor): The desired goal.
threshold (float, optional): The threshold value for determining success. Defaults to 0.05.
reward_type (str, optional): The type of reward to use. Can be "sparse" or "dense". Defaults to "sparse".

Returns:
torch.Tensor: The distance-based reward.

"""
d = distance_torch(achieved_goal, desired_goal)
if reward_type == "sparse":
return -(d > threshold).float()
else:
return -d.float()
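
A quick sanity check of the reward function above (illustrative values, not part of this PR):

import torch

achieved = torch.tensor([[0.00, 0.00], [1.00, 1.00]])
desired = torch.tensor([[0.00, 0.03], [0.00, 0.00]])
distance_reward_function(achieved, desired)  # sparse: tensor([-0., -1.]) -- 0.03 is within the 0.05 threshold
distance_reward_function(achieved, desired, reward_type="dense")  # dense: roughly tensor([-0.0300, -1.4142])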