Which dataclass members become HLO parameters? #30251

Answered by jakevdp
staticimport asked this question in Q&A

I don't know how to get this information from the compiler_ir representation, but if you print .compile().as_text(), you get an HLO representation that includes the Python provenance of each variable:

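import jax

# `fn` and `foo` are the function and dataclass instance from the question (not shown here).
compiled = jax.jit(fn).lower(foo).compile()
print(compiled.as_text())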
HloModule jit_fn, is_scheduled=true, entry_computation_layout={(s32[10]{0})->s32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

%region_0.3 (Arg_0.4: s32[], Arg_1.5: s32[]) -> s32[] {
  %Arg_0.4 = s32[] parameter(0), metadata={op_name="jit(fn)/jit(main)/reduce_sum"}
  %Arg_1.5 = s32[] parameter(1), metadata={op_name="jit(fn)/jit(main)/reduce_sum"}
  ROOT %add.6 =…
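Here is a minimal, self-contained sketch of the same idea. It assumes a recent JAX that provides `jax.tree_util.register_dataclass`. The `Foo` class, its fields `x`, `y`, and `scale`, and `fn` are hypothetical stand-ins for the question's code, which is not shown here. The sketch first lists the dataclass's pytree leaves, which are the candidates for HLO parameters. It then compiles and prints the HLO so you can see which leaves actually survive. Static (meta) fields never become parameters. By default, `jax.jit` may also drop leaves that the function does not use, and passing `keep_unused=True` keeps them.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Foo:
    # Hypothetical dataclass standing in for the one in the question.
    x: jax.Array  # data field: a pytree leaf, traced by jit
    y: jax.Array  # data field, but never used by `fn` below
    scale: int    # meta field: static, part of the tree structure, not a leaf


# Register Foo as a pytree: data_fields become leaves, meta_fields become
# static metadata that is baked into the traced program.
jax.tree_util.register_dataclass(
    Foo, data_fields=["x", "y"], meta_fields=["scale"]
)


def fn(foo: Foo) -> jax.Array:
    # Uses x and the static scale; ignores y entirely.
    return jnp.sum(foo.x) * foo.scale


foo = Foo(x=jnp.arange(10, dtype=jnp.int32), y=jnp.ones(3), scale=2)

# 1. Which members are pytree leaves, i.e. candidate HLO parameters?
leaf_paths = [
    jax.tree_util.keystr(path)
    for path, _ in jax.tree_util.tree_flatten_with_path(foo)[0]
]
print(leaf_paths)

# 2. Which leaves survive into the compiled executable (default: unused ones may be pruned)?
compiled = jax.jit(fn).lower(foo).compile()
print(compiled.as_text())

# 3. With keep_unused=True, the unused leaf `y` stays as an entry parameter.
compiled_all = jax.jit(fn, keep_unused=True).lower(foo).compile()
print(compiled_all.as_text())
```

In the printed module header, `entry_computation_layout` lists one shape per surviving entry parameter. For `compiled`, only `x`'s `s32[10]` should appear. `compiled_all` should also list `y`'s `f32[3]`. The static `scale` does not appear as a parameter in either executable.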

Answer selected by staticimport