Can't call torch.compile inside of a custom op #151328

@zou3519

Description

import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foo(Tensor x) -> Tensor")


def inner(x):
    return x.sin().cos()

def foo_impl(x):
    return torch.compile(inner, fullgraph=True)(x)  # nested torch.compile inside the op's implementation

lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

@torch.compile(fullgraph=True)  # outer compile: while tracing f, Dynamo runs mylib.foo on FakeTensors
def f(x):
    return torch.ops.mylib.foo.default(x)

x = torch.randn(3)
f(x)
"""
File ~/dev/misc_cpu11/pt-misc_cpu11/torch/_subclasses/meta_utils.py:894, in MetaConverter.meta_tensor(self, t, shape_env, callback_, source, symbolic_context)
    886     source = ConstantSource(
    887         f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
    888     )
    890 # This indicates you set no_dispatch() before calling into this
    891 # function.  This is an error: we may be creating fake tensors and
    892 # will perform operations on them which need fake tensor mode to
    893 # be active.  You will segfault if you are in a no_dispatch() block.
--> 894 assert not torch._C._dispatch_tls_local_exclude_set().has(
    895     torch._C.DispatchKey.Python
    896 )
    897 self.arg_cnt += 1
    899 # When we make as_strided calls, we end up generating a guard
    900 # that the new as_strided tensor is in bounds for the old storage
    901 # for the base (since as_strided calls can "bust" out of their
   (...)
    921 # as we allocate variables, and we do need to register guards for
    922 # these cases.

TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function mylib.foo.default(*(FakeTensor(..., size=(3,)),), **{}): got AssertionError('\n\nfrom user code:\n   File "<ipython-input-2-9e7ce20b02c0>", line 8, in inner\n    return x.sin().cos()\n\nSet TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you\'re reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"\n')

from user code:
   File "<ipython-input-2-9e7ce20b02c0>", line 17, in f
    return torch.ops.mylib.foo.default(x)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
"""

The motivation is that we want the custom op to be backed by a torch.compile implementation.

cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @eellison @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames @bdhirsh

Metadata

Labels

dynamo-triage-jan2025
feature
high priority
module: dynamo
module: fakeTensor
module: pt2-dispatcher
oncall: pt2
triaged

