kurtamohler
Contributor

Description

Avoid reshaping inputs to DreamerActorLoss.

Motivation and Context

Follow-up to #2494

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

pytorch-bot bot commented Oct 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2496

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 6 Unrelated Failures

As of commit bca6b79 with merge base d894358:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Oct 15, 2024
-loss_td, fake_data = loss_module(tensordict)
+# NOTE: Input is reshaped because GRUCell (which is part of the
+# RSSMPrior module in `mb_env`) expects input to be either 1D or 2D
+loss_td, fake_data = loss_module(tensordict.reshape(-1))
Contributor Author

@kurtamohler kurtamohler Oct 15, 2024

I'm not sure if there is a better way to fix this test. I suppose it could be possible to just reshape the direct input to the GRUCell?

Collaborator

If we need to reshape, we should reshape, but another option here would be to use vmap, like:

if tensordict.ndim > 1:
    loss_td, fake_data = vmap(loss_module, (0,))(tensordict)

(GRU works with vmap as long as you are using the Python-only version in torchrl.modules.)

Contributor Author

@kurtamohler kurtamohler Oct 16, 2024

If I try using VmapModule, I get this error:

  File "/home/endoplasm/develop/torchrl-1/test/test_cost.py", line 10338, in test_dreamer_actor
    loss_td, fake_data = VmapModule(loss_module, (0,))(tensordict)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/endoplasm/develop/torchrl-1/torchrl/modules/tensordict_module/common.py", line 454, in __init__
    self.in_keys = module.in_keys
                   ^^^^^^^^^^^^^^
  File "/home/endoplasm/develop/torchrl-1/torchrl/objectives/common.py", line 441, in __getattr__
    return super().__getattr__(item)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/endoplasm/miniconda/envs/torchrl-1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
    raise AttributeError(
AttributeError: 'DreamerActorLoss' object has no attribute 'in_keys'

I'll probably just leave the reshape for now, but I would like to understand this.
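
The traceback shows the failure happening in the wrapper's constructor, which reads `module.in_keys` before any forward call is made. A minimal sketch of that mechanism, using hypothetical stand-in classes (`FakeLoss`, `FakeVmapWrapper`) rather than the real torchrl classes:

```python
class FakeLoss:
    """Hypothetical stand-in for a loss module that defines no `in_keys`."""

    def __getattr__(self, item):
        # Python calls __getattr__ only after normal attribute lookup fails,
        # which is also how torch.nn.Module ends up raising AttributeError.
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{item}'"
        )


class FakeVmapWrapper:
    """Hypothetical stand-in for a wrapper that copies keys at construction."""

    def __init__(self, module):
        # The attribute is read here, so a module without `in_keys`
        # fails at wrap time rather than at call time.
        self.in_keys = module.in_keys


try:
    FakeVmapWrapper(FakeLoss())
    error_message = None
except AttributeError as err:
    error_message = str(err)

print(error_message)
```

Under this reading, the wrapped module's forward pass is never reached, so the error says nothing about whether the loss itself would work under vmap.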

Indeed, if I try to access loss_module.in_keys directly, I also get the above error. But I can access the in_keys of the actor model and world model within the loss module:

print(loss_module.actor_model.in_keys)
print(loss_module.model_based_env.world_model.in_keys)
['state', 'belief']
['state', 'belief', 'action']

So I'm wondering what would be the right way to make VmapModule and DreamerActorLoss compatible? Would we want to add an in_keys attribute to DreamerActorLoss that returns a combined list of the keys in the actor model and world model?
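
One way the combined list could be built is sketched below. This is only an illustration of the idea, not the torchrl implementation: `merge_in_keys` is a hypothetical helper, and the two input lists are the ones printed above. Whether the real loss would need further keys (for example from other sub-modules) is not settled in this thread.

```python
def merge_in_keys(*key_lists):
    """Concatenate key lists, dropping duplicates but keeping first-seen order."""
    merged = []
    for keys in key_lists:
        for key in keys:
            # Membership test on a list also works for nested (tuple) keys.
            if key not in merged:
                merged.append(key)
    return merged


# Values copied from the printed output of the two sub-modules.
actor_in_keys = ["state", "belief"]
world_model_in_keys = ["state", "belief", "action"]

combined = merge_in_keys(actor_in_keys, world_model_in_keys)
print(combined)  # ['state', 'belief', 'action']
```

Keeping first-seen order rather than using a `set` makes the result deterministic, which matters if callers compare key lists in tests.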

Collaborator

Oh yeah DreamerActorLoss should have in_keys!
All losses should. Dreamer hasn't received much love lately as you can see.
Let's take care of that in a separate PR then

Collaborator

vmoens commented Oct 16, 2024

The Dreamer implementation (in examples workflow) is failing

@kurtamohler kurtamohler force-pushed the dreamer-avoid-reshape-0 branch from 6009810 to bca6b79 October 16, 2024 23:08
@kurtamohler
Contributor Author

> The Dreamer implementation (in examples workflow) is failing

Should be fixed now

Collaborator

@vmoens vmoens left a comment

LGTM thanks!

@vmoens vmoens merged commit a27514c into pytorch:main Oct 18, 2024
71 of 80 checks passed

Labels

CLA Signed, Objectives
