
Conversation

@BowenBao (Collaborator) commented Apr 20, 2023

Stack from ghstack (oldest at bottom):

Summary

Todo

  • Training vs eval in dynamo_export
    We are effectively exporting all models in training mode by
    default, but for the sake of this export we are only interested in eval mode.
    The question is: should we call model.eval() in dynamo_export?
    Tests with models containing batch norm fail 'functionalization' in training mode.
    We are explicitly calling model.eval() for these models for now (see the sketch below).
  • Merge the decomp and functionalize passes. Both call into make_fx.
    Merging potentially improves performance, but it is unclear
    whether it would change behavior.

Fixes #99662. (For the functionalization issue. Still need missing op support.)
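For illustration, a minimal sketch of the eval-mode workaround described in the Todo above, assuming `torch.onnx.dynamo_export` as the entry point; the `TinyBN` model and input shape are hypothetical:

```python
import torch

class TinyBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

model = TinyBN()
# Batch norm fails 'functionalization' in training mode, so the model is
# switched to eval mode explicitly before exporting.
model.eval()
export_output = torch.onnx.dynamo_export(model, torch.randn(1, 3, 8, 8))
```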

@pytorch-bot added the "release notes: onnx" label Apr 20, 2023
@pytorch-bot (bot) commented Apr 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99667

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 3fc55a6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

BowenBao added a commit that referenced this pull request Apr 20, 2023
ghstack-source-id: b15d961
Pull Request resolved: #99667
Summary
- Previously this was required by `tracing_mode=symbolic` for `dynamic` tracing.
  That argument will be dropped by #99555.
- The later decomposition pass will do graph lowering, so this step is duplicated.
- Functionalization currently cannot work properly on an aten-level graph,
  so it must happen before lowering & decompositions (see the sketch below).

[ghstack-poisoned]
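As a hedged illustration of that ordering (functionalize first, then lower/decompose), a minimal sketch using `torch.func.functionalize` with `make_fx`; the toy function `f` is made up and this is not the exporter's actual pass code:

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x.clone()
    y.add_(1)  # in-place mutation that functionalization rewrites out-of-place
    return y

# Functionalize the pre-decomposition callable first, *then* trace and
# lower it to an aten-level FX graph.
gm = make_fx(functionalize(f))(torch.randn(3))
print(gm.graph)  # the traced graph contains aten.add, not aten.add_
```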
Summary
- Previously this was required by `tracing_mode=symbolic` for `dynamic` tracing.
  That argument will be dropped by #99555.
- The later decomposition pass will do graph lowering, so this step is duplicated.
- Functionalization currently cannot work properly on an aten-level graph,
  so it must happen before lowering & decompositions.
- Introduce the `ReplaceInplacePostFunctionalization` pass to replace inplace variant ops with their outplace versions (see the sketch below).
  These ops are created by aten graph lowering and decomposition after functionalization. They
  won't perform any real mutation, as mutation is expected to have been handled by functionalization.

Workaround to unblock #99662.

[ghstack-poisoned]
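A hedged sketch of what a pass like `ReplaceInplacePostFunctionalization` could look like; this guesses at the mechanics rather than reproducing the PR's implementation, and the helper name `replace_inplace_with_outplace` is invented:

```python
import torch
import torch.fx

def replace_inplace_with_outplace(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Rewrite trailing-underscore aten ops to their out-of-place variants."""
    for node in gm.graph.nodes:
        if node.op != "call_function" or not isinstance(node.target, torch._ops.OpOverload):
            continue
        schema = node.target._schema
        op_name = schema.name.split("::")[-1]  # e.g. "add_" from "aten::add_"
        if not op_name.endswith("_"):
            continue
        outplace_packet = getattr(torch.ops.aten, op_name[:-1], None)
        if outplace_packet is None:
            continue
        new_target = getattr(outplace_packet, schema.overload_name or "default", None)
        if new_target is not None:
            # Post-functionalization the inplace op no longer aliases its
            # input, so the out-of-place variant computes the same value.
            node.target = new_target
    gm.graph.lint()
    gm.recompile()
    return gm
```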
BowenBao added a commit that referenced this pull request Apr 21, 2023
ghstack-source-id: aa29f45
Pull Request resolved: #99667
@BowenBao BowenBao marked this pull request as ready for review April 21, 2023 01:52
@BowenBao BowenBao requested a review from abock as a code owner April 21, 2023 01:52
BowenBao added a commit that referenced this pull request Apr 21, 2023
ghstack-source-id: e25861a
Pull Request resolved: #99667
@BowenBao added the "module: onnx", "topic: new features", and "ciflow/trunk" labels Apr 21, 2023
@BowenBao changed the title from "[ONNX] Drop 'aten_graph' arg for 'DynamoExporter'" to "[ONNX] Drop 'aten_graph' arg for 'DynamoExporter'; Apply workaround in 'Functionalize' pass" Apr 29, 2023
Summary
- Previously this was required by and entangled with `tracing_mode=symbolic` for `dynamic` tracing.
  That is resolved by #99555 and its follow-ups.
- The later decomposition pass will do graph lowering, so this step is duplicated.
- Updated `Functionalization` to work around #99774 (comment)

Todo
- Training vs eval in dynamo_export
  We are effectively exporting all models in training mode by
  default, but for the sake of this export we are only interested in eval mode.
  The question is: should we call `model.eval()` in `dynamo_export`?
  Tests with models containing batch norm fail 'functionalization' in training mode.
  We are explicitly calling `model.eval()` for these models for now.
- Merge the decomp and functionalize passes. Both call into `make_fx` (see the sketch below).
  Merging potentially improves performance, but it is unclear
  whether it would change behavior.

Workaround to unblock #99662.

[ghstack-poisoned]
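On that merge idea, a speculative sketch of fusing the two passes into a single `make_fx` trace; `core_aten_decompositions` is a stand-in decomposition table here, and this is untested against the exporter's real pipeline:

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import core_aten_decompositions

def functionalize_and_decompose(fn, *args):
    # One traced graph instead of two make_fx round-trips: functionalize
    # the callable and apply decompositions within the same trace.
    return make_fx(
        functionalize(fn),
        decomposition_table=core_aten_decompositions(),
    )(*args)

gm = functionalize_and_decompose(lambda x: x.clone().relu_(), torch.randn(4))
print(gm.graph)
```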
BowenBao added a commit that referenced this pull request Apr 29, 2023
ghstack-source-id: 78e3950
Pull Request resolved: #99667
BowenBao added a commit that referenced this pull request May 1, 2023
ghstack-source-id: 594171b
Pull Request resolved: #99667
BowenBao added a commit that referenced this pull request May 3, 2023
ghstack-source-id: 278c37b
Pull Request resolved: #99667
@titaiwangms (Collaborator)

This seems to be getting complicated. You might consider moving 'Functionalize' ahead of dropping 'aten_graph' in the title.

@BowenBao changed the title from "[ONNX] Drop 'aten_graph' arg for 'DynamoExporter'; Apply workaround in 'Functionalize' pass" to "[ONNX] Update 'Functionalize' pass to support pre-decomp graph; Drop 'aten_graph' arg for 'DynamoExporter'" May 3, 2023
@BowenBao (Collaborator, Author) commented May 3, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@facebook-github-bot facebook-github-bot deleted the gh/BowenBao/234/head branch June 8, 2023 14:28
for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
    if isinstance(input_functional, torch.Tensor):
        # Propagate pending functionalization updates, then unwrap the
        # functional wrapper back to a regular tensor.
        torch._sync(input_functional)
        inpt_new = torch._from_functional_tensor(input_functional)
Collaborator

Why is this inpt_new assigned?
