pspec error when running Maxdiffusion with new JAX versions #204

@hx89

Description

We started getting a pspec error when running Maxdiffusion with newer JAX versions such as 0.6:

  File "/opt/maxdiffusion/src/maxdiffusion/train_flux.py", line 36, in train
    trainer.start_training()
  File "/opt/maxdiffusion/src/maxdiffusion/trainers/flux_trainer.py", line 138, in start_training
    p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/maxdiffusion/src/maxdiffusion/trainers/flux_trainer.py", line 309, in compile_train_step
    p_train_step = p_train_step.lower(train_states[FLUX_STATE_KEY], dummy_batch, train_rngs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: pspec PartitionSpec(('data', 'fsdp', 'tensor'), None) contains a manual axes ('data', 'fsdp', 'tensor') of mesh which is not allowed. If you are using a with_sharding_constraint under a shard_map, only use the mesh axis in PartitionSpec which are not manual.

This is likely caused by an API change in newer JAX versions, and Maxdiffusion may need to be updated accordingly.
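
The error message itself suggests the shape of a fix: inside a shard_map, a with_sharding_constraint may only name mesh axes that are not manual. A minimal sketch of that rule, assuming the constraint's PartitionSpec entries are available as a plain tuple, is a helper that drops the manual axes from each entry. `strip_manual_axes` is a hypothetical name, not a Maxdiffusion or JAX function, and whether this (rather than skipping the constraint inside the shard_map body) is the right change for Maxdiffusion is an open question.

```python
# Hypothetical helper: not part of Maxdiffusion or JAX.
from typing import Iterable, List, Sequence, Tuple, Union

# One PartitionSpec entry: None (dimension not sharded), a single mesh axis
# name, or a tuple of axis names that jointly shard one dimension.
AxisEntry = Union[None, str, Tuple[str, ...]]


def strip_manual_axes(
    entries: Sequence[AxisEntry], manual_axes: Iterable[str]
) -> Tuple[AxisEntry, ...]:
    """Remove manual mesh axes (e.g. those bound by shard_map) from spec entries.

    An entry whose axes are all manual becomes None; a tuple entry left with a
    single axis collapses to that axis name.
    """
    manual = set(manual_axes)
    result: List[AxisEntry] = []
    for entry in entries:
        if entry is None:
            result.append(None)
        elif isinstance(entry, (tuple, list)):
            kept = tuple(axis for axis in entry if axis not in manual)
            if not kept:
                result.append(None)
            elif len(kept) == 1:
                result.append(kept[0])
            else:
                result.append(kept)
        else:
            # A single axis name: keep it only if it is not manual.
            result.append(None if entry in manual else entry)
    return tuple(result)


# The spec from the traceback, with all three axes manual inside shard_map:
print(strip_manual_axes((("data", "fsdp", "tensor"), None), {"data", "fsdp", "tensor"}))
# (None, None)

# If only "data" were manual, the other two axes would survive:
print(strip_manual_axes((("data", "fsdp", "tensor"), None), {"data"}))
# (('fsdp', 'tensor'), None)
```

With JAX, the result could be wrapped back into a spec as `PartitionSpec(*strip_manual_axes(tuple(spec), manual_axes))`, assuming the spec's entries can be iterated. When every axis in the spec is manual, as in the traceback above, the stripped spec is fully replicated and the constraint likely carries no useful sharding information, so not applying it under shard_map may be the simpler change.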
