Which dataclass members become HLO parameters? #30251
-
Below is a toy example that has a dataclass containing 2 members, but my function only uses one of them. As part of compiling down to HLO, the unused member is eliminated. How can I programmatically tell which parameters the HLO module wants? Motivation: I'm working on using XLA AOT for C++ inference, so in C++ I need to know exactly what inputs to pass in.
The output below is identical whether `fn()` above sums `foo.bar` or `foo.baz`:
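Here is a minimal sketch of a setup like the one described. It assumes a dataclass `Foo` with array members `bar` and `baz` (the shapes are placeholders) registered as a pytree with `jax.tree_util.register_dataclass`, which requires a reasonably recent JAX. The hypothetical helper `count_main_args` counts the arguments of the `@main` function in the lowered StableHLO text. It parses printed MLIR, which is not a stable format, so treat it as illustration only.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Foo:
    bar: jax.Array
    baz: jax.Array


# Register Foo as a pytree so jax.jit can flatten it into its array leaves.
jax.tree_util.register_dataclass(Foo, data_fields=["bar", "baz"], meta_fields=[])


def fn(foo: Foo) -> jax.Array:
    # Only foo.bar is read; foo.baz never reaches the computation.
    return jnp.sum(foo.bar)


def count_main_args(stablehlo_text: str) -> int:
    """Hypothetical helper: count arguments of the StableHLO @main function.

    Relies on the printed signature of @main, which is not a stable interface.
    """
    main_line = next(line for line in stablehlo_text.splitlines() if "@main(" in line)
    return main_line.count("%arg")


foo = Foo(bar=jnp.ones(3), baz=jnp.zeros(3))
lowered = jax.jit(fn).lower(foo)

# With the default keep_unused=False, JAX prunes the unused leaf, so @main
# should take a single argument whether fn sums foo.bar or foo.baz.
num_args = count_main_args(lowered.as_text())
print("StableHLO @main arguments:", num_args)
```

Because the surviving argument is simply `%arg0` in both cases, the signature alone cannot tell you which member was kept.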
-
I don't know how to get this information in the `compiler_ir` representation, but if you print `.compile().as_text()` then you get an HLO representation that includes the Python provenance of each variable:

```python
import jax

compiled = jax.jit(fn).lower(foo).compile()
print(compiled.as_text())
```

(Notice the
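For the AOT use case in the question, one option not raised in this thread is to pass `keep_unused=True` to `jax.jit`. This disables JAX's pruning of unused arguments, so the module takes one parameter per pytree leaf. `jax.tree_util.tree_flatten_with_path` then gives a name for each slot. The sketch below reuses a reconstructed version of the question's `Foo`/`fn` setup with placeholder shapes. It assumes a JAX version that provides `jax.tree_util.register_dataclass`. It also uses a hypothetical `count_main_args` helper that parses the printed StableHLO signature, which is not a stable format.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Foo:
    bar: jax.Array
    baz: jax.Array


jax.tree_util.register_dataclass(Foo, data_fields=["bar", "baz"], meta_fields=[])


def fn(foo: Foo) -> jax.Array:
    return jnp.sum(foo.bar)  # foo.baz is still unused


def count_main_args(stablehlo_text: str) -> int:
    """Hypothetical helper: count arguments of the StableHLO @main function."""
    main_line = next(line for line in stablehlo_text.splitlines() if "@main(" in line)
    return main_line.count("%arg")


foo = Foo(bar=jnp.ones(3), baz=jnp.zeros(3))

# keep_unused=True asks jax.jit not to prune unused arguments, so every
# pytree leaf (including foo.baz) stays a parameter of the module.
lowered = jax.jit(fn, keep_unused=True).lower(foo)
compiled = lowered.compile()
num_args = count_main_args(lowered.as_text())

# Parameters follow JAX's pytree flattening order; key paths name each slot.
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(foo)
param_names = ["foo" + jax.tree_util.keystr(path) for path, _ in leaves_with_paths]
for index, name in enumerate(param_names):
    print(f"parameter {index}: {name}")
```

The trade-off is that unused inputs are still transferred to the device on every call. In exchange, the C++ caller can pass every leaf in flattening order without having to work out which ones survived pruning.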