[WIP] Hindsight Experience Replay Transform #1819
base: main
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1819
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 19 Unrelated Failures
As of commit 90eef75 with merge base 57139bd:
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following job failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
augmentation_td = TensorDict(
    {
        "observation": sampled_td.get("observation").repeat_interleave(
```
If we keep it a transform we probably need to specify all those tensordict keys ... Not sure what a better alternative would be. Any idea?
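One possible way around listing every key, sketched here with plain dicts of tensors standing in for a TensorDict (an assumption, not the PR's code): repeat every entry along the batch dimension, then overwrite only the relabelled entries. This also carries along any extra metadata automatically. `repeat_all_keys` and the key names below are hypothetical.

```python
import torch


def repeat_all_keys(batch: dict, k: int) -> dict:
    """Repeat every tensor in ``batch`` k times along dim 0.

    Transition i becomes rows i*k .. i*k + k - 1 (the repeat_interleave layout),
    and no key has to be named explicitly.
    """
    return {key: value.repeat_interleave(k, dim=0) for key, value in batch.items()}


# Hypothetical batch of 3 transitions with 2-D goals and an extra metadata key.
batch = {
    "observation": torch.arange(6.0).reshape(3, 2),
    "achieved_goal": torch.arange(6.0).reshape(3, 2),
    "desired_goal": torch.zeros(3, 2),
    "step_count": torch.tensor([0, 1, 2]),
}
augmented = repeat_all_keys(batch, k=2)
# Only the relabelled entry is overwritten; every other key is a plain copy.
augmented["desired_goal"] = torch.randn(6, 2)
```

With a real TensorDict the same per-key idea would apply; the dict version only keeps the sketch free of dependencies beyond PyTorch.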
```python
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
    augmentation_td = self.her_augmentation(tensordict)
    return torch.cat([tensordict, augmentation_td], dim=0)
```
As explained above, it doesn't feel like a transform, as we create a new tensordict and have to combine the original and augmented data before adding them to the replay buffer. Ideally, I think the "augmentations" would be done directly after collection: as a postproc for collectors or, as in this example, as an `inverse_transform` for the replay buffer.
Why not a transform? I think it's pretty neat to use a transform. Who said a transform had to change things in-place?
Our API to modify samples at writing time is to use either a transform or a different writer. If you think this can be achieved with a writer, I'm on board, but I don't think there's anything wrong with the transform.
An advantage of using a writer instead is that it feels more natural (transforms can be used with envs unless specified otherwise, whereas writers are dedicated to RBs).
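To make the transform-versus-writer trade-off concrete, here is a toy buffer in which the augmentation runs at write time inside `extend`, so it is tied to the buffer rather than to an env. `AugmentingBuffer` and its `augment` argument are hypothetical; this is deliberately not TorchRL's `ReplayBuffer`/`Writer` API, whose hooks are not reproduced here.

```python
from typing import Callable, Dict, List, Optional

import torch

Batch = Dict[str, torch.Tensor]


class AugmentingBuffer:
    """Toy FIFO buffer that augments incoming data before storing it."""

    def __init__(self, capacity: int, augment: Optional[Callable[[Batch], Batch]] = None):
        self.capacity = capacity
        self.augment = augment
        self._rows: List[Batch] = []  # one dict per stored transition

    def extend(self, batch: Batch) -> None:
        if self.augment is not None:
            extra = self.augment(batch)
            # Keep the originals and append the augmented copies along dim 0.
            batch = {k: torch.cat([batch[k], extra[k]], dim=0) for k in batch}
        n = next(iter(batch.values())).shape[0]
        for i in range(n):
            self._rows.append({k: v[i] for k, v in batch.items()})
        # Drop the oldest transitions once capacity is exceeded.
        self._rows = self._rows[-self.capacity:]

    def __len__(self) -> int:
        return len(self._rows)

    def sample(self, batch_size: int) -> Batch:
        idx = torch.randint(len(self._rows), (batch_size,)).tolist()
        return {k: torch.stack([self._rows[i][k] for i in idx]) for k in self._rows[0]}


# The "augmentation" here is just a copy, to show the write-time bookkeeping.
buffer = AugmentingBuffer(capacity=10, augment=lambda b: {k: v.clone() for k, v in b.items()})
buffer.extend({"observation": torch.randn(4, 3), "reward": torch.zeros(4)})
```

In this toy, augmented rows count against `capacity`, so a relabelling factor of k shrinks how many distinct real transitions the buffer retains.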
@ahmed-touati suggested we use a sampler for this rather than a transform. I'm not strongly opinionated on the matter, mostly because I need more context on what we're trying to achieve here.
Can you elaborate a bit more on what this transform does, maybe with a bunch of examples?
So HER is mainly used in goal-conditioned RL with sparse reward signals, where the agent has to reach a goal state and only gets a reward (+1) when the goal state is achieved; otherwise it gets no reward. The observation consists of three elements: the observation the agent sees, the state the agent achieved (which could be an x,y,z position), and the goal state the agent should reach (x,y,z). A typical task is a robot that has to reach a goal position. The observation will often include the agent's position as well, but the achieved state is mostly added as a separate element for additional information, which also helps with understanding here.

With a sparse reward function, most trajectories carry no learning signal, since the agent is unlikely to reach the goal position randomly or by pure luck. So let's say you have a real transition (obs, action, reward, done, next obs, achieved_position, goal_position). For this tuple you now sample a new goal_position and compute the reward from this new goal_position and the real achieved_position. You then add both the real transition (obs, action, reward, done, next obs, achieved_position, goal_position) and the HER-augmented transition (obs, action, new_reward, done, new next obs, achieved_position, new_goal_position). The sampling can happen in different ways, but that is not important for now.

However, I think one important point is that we need the reward function. I'm not sure whether we can pass it to the writer/sampler of the buffer, which is why my first thought was a transform. Most of the time the reward function might just be a Euclidean distance, but for other tasks the user may need to provide a more sophisticated reward function.
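The relabelling described above can be sketched in a few lines of PyTorch. This assumes the "future" strategy from the HER paper (a substitute goal is an achieved goal from the same or a later step of the same trajectory), a Euclidean-distance sparse reward with a hypothetical success threshold `eps`, and one trajectory stored as a `[T, goal_dim]` tensor. `sparse_reward` and `her_future_relabel` are illustrative names, not code from this PR.

```python
import torch


def sparse_reward(achieved: torch.Tensor, goal: torch.Tensor, eps: float = 0.05) -> torch.Tensor:
    """1.0 where the achieved goal is within ``eps`` of the goal, else 0.0."""
    return (torch.linalg.vector_norm(achieved - goal, dim=-1) < eps).float()


def her_future_relabel(achieved_goal: torch.Tensor, k: int, generator=None):
    """Relabel every step of one trajectory with k goals from the same or later steps.

    achieved_goal: [T, goal_dim]; row t is the goal achieved after the action at step t.
    Returns (steps [T*k], new_goal [T*k, goal_dim], new_reward [T*k]).
    """
    T = achieved_goal.shape[0]
    # Each step index repeated k times, matching the repeat_interleave layout.
    steps = torch.arange(T).repeat_interleave(k)
    # Uniform integer offset in [0, T - 1 - t], so the goal index lies in [t, T - 1].
    span = (T - steps).float()
    offsets = (torch.rand(T * k, generator=generator) * span).long()
    future = (steps + offsets).clamp(max=T - 1)
    new_goal = achieved_goal[future]
    # The reward is recomputed from the real achieved goal and the substituted goal.
    new_reward = sparse_reward(achieved_goal[steps], new_goal)
    return steps, new_goal, new_reward


torch.manual_seed(0)
achieved = torch.cumsum(torch.ones(5, 2), dim=0)  # rows [1, 1], [2, 2], ..., [5, 5]
steps, goals, rewards = her_future_relabel(achieved, k=4)
```

Indexing the trajectory's other keys with `steps` and concatenating the result with the originals gives the "real + HER-augmented" pair of transitions described above.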
Why not? I would guess that even if it's a complex nn.Module, you can still do pretty much everything with a well-tailored function (at least nothing less than with a transform).
Thanks for the context btw!
Revisiting this, I think it would make much more sense to do it with a writer. We want to augment incoming data with newly sampled goal states and store them all together in the buffer. I think this would generally be a good way to add other data augmentation strategies, with writers instead of transforms. I'm having a closer look at the writer classes right now and will update the code here.
But this would not allow us to stack multiple augmentations on top of each other, so maybe it's not ideal for augmentations.
You could still transform your data before passing it to the writer, but not after.
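A sketch of stacking several augmentations before the data reaches the writer, assuming each augmentation is a plain batch-to-batch function. `compose`, `duplicate`, and `add_noise` are hypothetical names.

```python
from functools import reduce
from typing import Callable, Dict

import torch

Batch = Dict[str, torch.Tensor]
Augmentation = Callable[[Batch], Batch]


def compose(*augmentations: Augmentation) -> Augmentation:
    """Chain batch -> batch augmentations, applied left to right."""

    def pipeline(batch: Batch) -> Batch:
        return reduce(lambda acc, aug: aug(acc), augmentations, batch)

    return pipeline


def duplicate(batch: Batch) -> Batch:
    """Hypothetical augmentation: keep the data and append a copy of it."""
    return {k: torch.cat([v, v.clone()], dim=0) for k, v in batch.items()}


def add_noise(batch: Batch) -> Batch:
    """Hypothetical augmentation: return a new dict with perturbed observations."""
    out = dict(batch)
    out["observation"] = batch["observation"] + 0.01 * torch.randn_like(batch["observation"])
    return out


pre_write = compose(duplicate, add_noise)
result = pre_write({"observation": torch.zeros(3, 2), "reward": torch.ones(3)})
```

Order matters: because `add_noise` runs after `duplicate`, the noise is applied to both the originals and the copies.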
While hindsight experience replay is pretty useful, I think it falls under the category of specialized algorithms rather than building blocks @vmoens
```python
new_goals.append(splitted_achieved_goals[i][ids])

# calculate rewards given new desired goals and old achieved goals
vmap_rewards = torch.vmap(distance_reward_function)
```
I think you want to call `self.reward_function` instead of `distance_reward_function`. Also, maybe the `reward_function` should be a `TensorDictModule` so that it can be more easily customized for a given environment. There is `torchrl.modules.VmapModule` for wrapping `TensorDictModule`s with vmap.
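A sketch of the first suggestion: store the user's reward function on the object and vmap *that*, rather than a hard-coded `distance_reward_function`. To avoid asserting `VmapModule`'s exact signature, this uses a plain per-sample callable with `torch.vmap` (PyTorch 2.0+). `HERRelabeler` and `distance_reward` are hypothetical names, and the `eps` threshold is an assumption.

```python
from typing import Callable

import torch

RewardFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def distance_reward(achieved: torch.Tensor, goal: torch.Tensor, eps: float = 0.05) -> torch.Tensor:
    """Per-sample sparse reward: 1.0 inside the eps-ball around the goal, else 0.0."""
    return (torch.linalg.vector_norm(achieved - goal) < eps).float()


class HERRelabeler:
    """Hypothetical container that keeps a user-supplied reward function."""

    def __init__(self, reward_function: RewardFn = distance_reward):
        self.reward_function = reward_function
        # vmap maps the per-sample function over the leading (batch) dimension.
        self._batched_reward = torch.vmap(self.reward_function)

    def rewards(self, achieved: torch.Tensor, goals: torch.Tensor) -> torch.Tensor:
        return self._batched_reward(achieved, goals)


achieved = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
goals = torch.tensor([[0.0, 0.01], [0.0, 0.0]])
default_rewards = HERRelabeler().rewards(achieved, goals)
# A custom dense reward can be swapped in without touching the relabelling code.
dense = HERRelabeler(lambda a, g: -torch.linalg.vector_norm(a - g))
dense_rewards = dense.rewards(achieved, goals)
```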
```python
cat_rewards = torch.cat(rewards).reshape(b, t, self.samples, -1).squeeze(-1)
cat_new_goals = torch.cat(new_goals).reshape(b, t, self.samples, -1)

augmentation_td = TensorDict(
```
I think the `augmentation_td` should still keep the other metadata related to the state, rather than selecting only the keys observation, action, terminated, truncated, ...
Not sure about that one.
Are you planning to continue working on this PR? If not, I'd be happy to help out and finish it myself. Additionally, @vmoens, do you now have a clearer sense of whether this should be implemented as a transform, writer, or sampler? It would be great to make sure we're aligned on the best approach before moving forward. Looking forward to your thoughts!
I already have a complete version of it in a gist, but I have not made a PR for it yet. Maybe you can have a look and provide feedback. I talked with Alexandre on Discord long ago about the implementation, and he liked it. Any feedback? @vmoens https://gist.github.com/dtsaras/f321aed253a64e4849ce95bd232d1635
I really like the modular approach you've taken! That said, I have a few questions about some details of the implementation. From what I understand, the reward function would be implemented as a […] Could you share your thoughts on the reasoning behind including the […]
You are correct: the `HERRewardTransform` would be responsible for assigning the rewards to all the intermediate states. While it doesn't have to be its own special transform, it has to be "something" that reassigns the rewards to the intermediate states of a trajectory. The reason I chose a Transform rather than a callable is to use the TorchRL API. For example, the user does not need to reimplement the discounted reward function, since `Reward2GoTransform` exists, and multiple transforms can be nicely composed into one. The `HERSubgoalSampler` I have implemented only contains […] Maybe I should create a PR so we can work on it together and you can make changes.
P.S. You can join the Discord server as well; maybe more people can provide feedback there.
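For reference, the discounted reward-to-go mentioned above follows the recursion G_t = r_t + γ·G_{t+1}, computed backwards along a trajectory once the HER rewards have been reassigned. This is a plain-loop sketch, not the actual `Reward2GoTransform` implementation; `reward_to_go` is a hypothetical helper.

```python
import torch


def reward_to_go(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
    """Discounted reward-to-go G_t = r_t + gamma * G_{t+1} along the last dim."""
    out = torch.empty_like(rewards)
    running = torch.zeros_like(rewards[..., -1])
    # Walk backwards so each step accumulates the discounted future rewards.
    for t in range(rewards.shape[-1] - 1, -1, -1):
        running = rewards[..., t] + gamma * running
        out[..., t] = running
    return out


# Sparse HER-style rewards for one relabelled trajectory: goal reached at the last step.
relabelled = torch.tensor([0.0, 0.0, 0.0, 1.0])
returns = reward_to_go(relabelled, gamma=0.5)
```

This sketch assumes each row is a single, complete trajectory; splitting at episode boundaries is left out.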
I'm sorry, but I'm having a bit of trouble following your reasoning here. Could you elaborate on why it can't be a […]
Ah, I see, sorry for the confusion! What I meant is that, as long as we implement all the sampling methods outlined in the paper, it seems unlikely that we'd need another one. This makes me think we could move the subgoal sampling logic directly into the […] I understand that random sampling didn't yield good results, but I thought it might be valuable to include it for the sake of completeness, since it was mentioned in the original paper.
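The HER paper (Andrychowicz et al., 2017) describes four substitute-goal strategies: "final", "future", "episode", and "random". Below is a compact sketch of all four as index samplers, assuming per-episode tensors where index `t` holds the goal achieved after the action at step `t`. `sample_goal_indices` is a hypothetical helper, not the gist's `HERSubgoalSampler`.

```python
import torch


def sample_goal_indices(strategy: str, t: int, T: int, k: int, buffer_size: int = 0) -> torch.Tensor:
    """Indices of achieved goals to reuse as substitute goals for step t.

    "final", "future" and "episode" index into the current episode (length T);
    "random" indexes into everything stored so far (size buffer_size).
    """
    if strategy == "final":
        return torch.full((k,), T - 1, dtype=torch.long)
    if strategy == "future":
        # Same or later step, since row t is already the post-action achieved goal.
        return torch.randint(t, T, (k,))
    if strategy == "episode":
        return torch.randint(0, T, (k,))
    if strategy == "random":
        if buffer_size <= 0:
            raise ValueError("'random' needs a positive buffer_size")
        return torch.randint(0, buffer_size, (k,))
    raise ValueError(f"unknown strategy: {strategy!r}")


torch.manual_seed(0)
future_idx = sample_goal_indices("future", t=3, T=10, k=100)
```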
That sounds like a great idea.
Thanks for letting me know! I didn't realize there was a Discord server; I've just joined.
Description
Adds Hindsight Experience Replay (HER) Transform
Motivation and Context
This is the first draft of the HER transform. However, I am not sure whether it should be a `Transform` or whether we should create an extra `Augmentation` class, since we are not transforming a single element in the tensordict but augmenting existing collected data. This could be interesting for future "data augmentation strategies", which I think we do not have so far.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!