Description
We started getting a pspec error when running Maxdiffusion with newer JAX versions such as 0.6:
```
  File "/opt/maxdiffusion/src/maxdiffusion/train_flux.py", line 36, in train
    trainer.start_training()
  File "/opt/maxdiffusion/src/maxdiffusion/trainers/flux_trainer.py", line 138, in start_training
    p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/maxdiffusion/src/maxdiffusion/trainers/flux_trainer.py", line 309, in compile_train_step
    p_train_step = p_train_step.lower(train_states[FLUX_STATE_KEY], dummy_batch, train_rngs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: pspec PartitionSpec(('data', 'fsdp', 'tensor'), None) contains a manual axes ('data', 'fsdp', 'tensor') of mesh which is not allowed. If you are using a with_sharding_constraint under a shard_map, only use the mesh axis in PartitionSpec which are not manual.
```
This is likely caused by an API change in newer JAX releases, and Maxdiffusion may need to be updated accordingly.
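For reference, here is a minimal sketch of the placement the error message asks for: any `with_sharding_constraint` that names the `('data', 'fsdp', 'tensor')` axes is applied at the `jit` level, outside the `shard_map` body, where those axes are still automatic rather than manual. This is not Maxdiffusion's code; the function names (`per_shard`, `step`), the single-device mesh, and the assumption that the offending constraint sits inside a `shard_map` body are all hypothetical and only meant to illustrate the rule.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Newer JAX exposes shard_map at the top level; older releases only have
# the experimental module. Both accept mesh/in_specs/out_specs keywords.
try:
    from jax import shard_map
except ImportError:
    from jax.experimental.shard_map import shard_map

# Hypothetical mesh with the same axis names as in the traceback.
# A single device is used here so the sketch runs anywhere (e.g. on CPU).
AXES = ("data", "fsdp", "tensor")
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1, 1), AXES)
batch_spec = P(AXES, None)  # same spec as in the error message


def per_shard(x):
    # Inside shard_map, all three mesh axes are manual, so this body does
    # NOT call with_sharding_constraint with a spec that names them.
    return x * 2.0


@jax.jit
def step(x):
    # Outside shard_map the axes are automatic, so constraining here is allowed.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, batch_spec))
    y = shard_map(per_shard, mesh=mesh, in_specs=(batch_spec,), out_specs=batch_spec)(x)
    return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, batch_spec))


x = jnp.arange(8.0).reshape(4, 2)
out = step(x)
```

Moving the constraint (or dropping it, if the `shard_map` specs already fix the layout) is one possible workaround; the actual fix depends on where Maxdiffusion applies the constraint.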