- New features:
  - Added `jax.P`, which is an alias for `jax.sharding.PartitionSpec`.
  - Added `jax.tree.reduce_associative`.
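
A minimal, version-tolerant sketch of the two additions. It assumes JAX 0.7 or newer for `jax.P` and `jax.tree.reduce_associative`; the `getattr` fallbacks to `PartitionSpec` and `jax.tree.reduce` are an assumption added only so the snippet also runs on older releases. The dictionary `params` is illustrative.

```python
import operator

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec

# jax.P is an alias of PartitionSpec on JAX >= 0.7; fall back on older releases.
P = getattr(jax, "P", PartitionSpec)
# Shard array axis 0 over a mesh axis named "data"; replicate axis 1.
spec = P("data", None)
print(spec == PartitionSpec("data", None))  # True

# reduce_associative combines all leaves with an associative binary op.
# Because the op is associative, the grouping of the reduction does not
# change the result, so a left fold is a valid fallback on older JAX.
reduce_assoc = getattr(jax.tree, "reduce_associative", None)
if reduce_assoc is None:
    reduce_assoc = jax.tree.reduce

params = {"a": jnp.ones(3), "b": jnp.full(3, 2.0), "c": jnp.full(3, 3.0)}
total = reduce_assoc(operator.add, params)
print(total)  # [6. 6. 6.]
```

Only use `reduce_associative` with operations that really are associative (addition, max, elementwise product); for non-associative operations the result may differ from a sequential fold.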

- Breaking changes:
  - JAX is migrating from GSPMD to Shardy by default. See the
    migration guide for more information.
  - JAX autodiff is switching to using direct linearization by default (instead of
    implementing linearization via JVP and partial eval).
    See the migration guide for more information.
  - `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
  - `jax.jit` now requires `fun` to be passed by position, and additional
    arguments to be passed by keyword. Doing otherwise will result in an error
    starting in v0.7.x. This raised a `DeprecationWarning` in v0.6.x.
  - The minimum Python version is now 3.11. 3.11 will remain the minimum
    supported version until July 2026.
  - Layout API renames:
    - `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been
      renamed to `Format`, `.format`, `.input_formats` and `.output_formats`.
    - `DeviceLocalLayout` and `.device_local_layout` have been renamed to
      `Layout` and `.layout`.
  - The `jax.experimental.shard` module has been deleted and all of its APIs
    have been moved to the `jax.sharding` endpoint. Use `jax.sharding.reshard`,
    `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` instead of their
    experimental endpoints.
  - `lax.infeed` and `lax.outfeed` were removed, after being deprecated in
    JAX 0.6. The `transfer_to_infeed` and `transfer_from_outfeed` methods were
    also removed from the `Device` objects.
  - The `jax.extend.core.primitives.pjit_p` primitive has been renamed to
    `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`.
    This affects the string representations of jaxprs. The same primitive is no
    longer exported from the `jax.experimental.pjit` module.
  - The (undocumented) function `jax.extend.backend.add_clear_backends_callback`
    has been removed. Users should use `jax.extend.backend.register_backend_cache`
    instead.
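
The `jax.jit`-related breaking changes can be exercised together in a short sketch. It assumes a recent JAX; the functions `scale`, `scale_jit` and `scale_dec` are hypothetical names for illustration. Reading only `.shape` and `.dtype` from `out_info`, and accepting both `"jit"` and `"pjit"` as the primitive name, are choices made here so the same code runs on releases before and after these changes.

```python
from functools import partial

import jax
import jax.numpy as jnp


def scale(x, factor):
    return x * factor


# Pass the function positionally and every jit option by keyword.
scale_jit = jax.jit(scale, static_argnames="factor")


# Decorator form: bind the options by keyword with functools.partial.
@partial(jax.jit, static_argnames="factor")
def scale_dec(x, factor):
    return x * factor


x = jnp.arange(4.0)
print(scale_jit(x, factor=2.0))  # [0. 2. 4. 6.]
print(scale_dec(x, factor=3.0))  # [0. 3. 6. 9.]

# out_info describes the lowered outputs. Per this release it yields
# jax.ShapeDtypeStruct values instead of jax.stages.OutInfo; both expose
# .shape and .dtype, so code reading only those attributes keeps working.
out = scale_jit.lower(x, factor=2.0).out_info
print(out.shape, out.dtype)  # (4,) float32

# The jit primitive is named "jit" in new jaxprs and "pjit" in older ones;
# a check that must support both releases can accept either name.
jaxpr = jax.make_jaxpr(lambda v: scale_jit(v, factor=2.0))(x)
jit_eqns = [e for e in jaxpr.eqns if e.primitive.name in ("jit", "pjit")]
print(len(jit_eqns))  # 1
```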

- Deprecations:
  - {obj}`jax.dlpack.SUPPORTED_DTYPES` is deprecated; please use the new
    `jax.dlpack.is_supported_dtype` function.
  - `jax.scipy.special.sph_harm` has been deprecated following a similar
    deprecation in SciPy; use `jax.scipy.special.sph_harm_y` instead.
  - From {mod}`jax.interpreters.xla`, the previously deprecated symbols
    `abstractify` and `pytype_aval_mappings` have been removed.
  - `jax.interpreters.xla.canonicalize_dtype` is deprecated. For
    canonicalizing dtypes, prefer `jax.dtypes.canonicalize_dtype`.
    For checking whether an object is a valid JAX input, prefer
    `jax.core.valid_jaxtype`.
  - From {mod}`jax.core`, the previously deprecated symbols `AxisName`,
    `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`,
    `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been
    removed.
  - From {mod}`jax.lib.xla_client`, the previously deprecated symbols
    `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version`
    have been removed.
  - `jax.extend.ffi` was removed after being deprecated in v0.5.0.
    Use {mod}`jax.ffi` instead.
  - `jax.lib.xla_bridge.get_compile_options` is deprecated, and replaced by
    `jax.extend.backend.get_compile_options`.