Releases: jax-ml/jax
JAX v0.7.0
- New features:
  - Added `jax.P`, which is an alias for `jax.sharding.PartitionSpec`.
  - Added `jax.tree.reduce_associative`.
- Breaking changes:
  - JAX is migrating from GSPMD to Shardy by default. See the migration guide for more information.
  - JAX autodiff is switching to using direct linearization by default (instead of implementing linearization via JVP and partial eval). See the migration guide for more information.
  - `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
  - `jax.jit` now requires `fun` to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in an error starting in v0.7.x. This raised a `DeprecationWarning` in v0.6.x.
  - The minimum Python version is now 3.11. 3.11 will remain the minimum supported version until July 2026.
  - Layout API renames:
    - `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been renamed to `Format`, `.format`, `.input_formats` and `.output_formats`.
    - `DeviceLocalLayout` and `.device_local_layout` have been renamed to `Layout` and `.layout`.
  - The `jax.experimental.shard` module has been deleted and all of its APIs have been moved to the `jax.sharding` endpoint. Use `jax.sharding.reshard`, `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` instead of their experimental endpoints.
  - `lax.infeed` and `lax.outfeed` were removed, after being deprecated in JAX 0.6. The `transfer_to_infeed` and `transfer_from_outfeed` methods were also removed from the `Device` objects.
  - The `jax.extend.core.primitives.pjit_p` primitive has been renamed to `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`. This affects the string representations of jaxprs. The same primitive is no longer exported from the `jax.experimental.pjit` module.
  - The (undocumented) function `jax.extend.backend.add_clear_backends_callback` has been removed. Users should use `jax.extend.backend.register_backend_cache` instead.
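
The `jax.jit` calling convention above can be sketched as follows; `power` and its argument `n` are illustrative names, not part of JAX. The function is the only positional argument and options such as `static_argnames` are keywords; passing those options positionally is what now raises an error. The same rule applies to decorator use via `functools.partial`.

```python
import functools

import jax
import jax.numpy as jnp


def power(x, n):
    # n is a plain Python int, treated as static (compile-time) by jit below.
    return x ** n


# fun is passed by position; every other option is passed by keyword.
power_jit = jax.jit(power, static_argnames="n")


# Decorator form: bind the keyword options with functools.partial.
@functools.partial(jax.jit, static_argnames=("n",))
def power_deco(x, n):
    return x ** n


x = jnp.arange(4.0)
cubes = power_jit(x, n=3)     # values 0, 1, 8, 27
squares = power_deco(x, n=2)  # values 0, 1, 4, 9
```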
- Deprecations:
  - `jax.dlpack.SUPPORTED_DTYPES` is deprecated; please use the new `jax.dlpack.is_supported_dtype` function.
  - `jax.scipy.special.sph_harm` has been deprecated following a similar deprecation in SciPy; use `jax.scipy.special.sph_harm_y` instead.
  - From `jax.interpreters.xla`, the previously deprecated symbols `abstractify` and `pytype_aval_mappings` have been removed.
  - `jax.interpreters.xla.canonicalize_dtype` is deprecated. For canonicalizing dtypes, prefer `jax.dtypes.canonicalize_dtype`. For checking whether an object is a valid JAX input, prefer `jax.core.valid_jaxtype`.
  - From `jax.core`, the previously deprecated symbols `AxisName`, `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`, `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been removed.
  - From `jax.lib.xla_client`, the previously deprecated symbols `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version` have been removed.
  - `jax.extend.ffi` was removed after being deprecated in v0.5.0. Use `jax.ffi` instead.
  - `jax.lib.xla_bridge.get_compile_options` is deprecated, and replaced by `jax.extend.backend.get_compile_options`.
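
A short sketch of the suggested replacements for `jax.interpreters.xla.canonicalize_dtype`, assuming the default configuration with `jax_enable_x64` disabled (with x64 enabled, 64-bit dtypes are kept as-is).

```python
import numpy as np

from jax import core, dtypes

# Preferred over jax.interpreters.xla.canonicalize_dtype.
f64_canon = dtypes.canonicalize_dtype(np.float64)  # float32 when x64 is disabled
i64_canon = dtypes.canonicalize_dtype(np.int64)    # int32 when x64 is disabled

# Preferred way to check whether a value is a valid JAX input.
array_ok = core.valid_jaxtype(np.ones(3, dtype=np.float32))
string_ok = core.valid_jaxtype("not an array")
```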
JAX v0.6.2
- New features:
  - Added `jax.tree.broadcast`, which implements a pytree prefix broadcasting helper.
- Changes
  - The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.
JAX v0.6.1
- New features:
  - Added `jax.lax.axis_size`, which returns the size of the mapped axis given its name.
- Changes
  - Additional checking for the versions of CUDA package dependencies was re-enabled, having been accidentally disabled in a previous release.
  - JAX nightly packages are now published to artifact registry. To install these packages, see the JAX installation guide.
  - `jax.sharding.PartitionSpec` no longer inherits from a tuple.
  - `jax.ShapeDtypeStruct` is now immutable. Please use the `.update` method to update your `ShapeDtypeStruct` instead of doing in-place updates.
- Deprecations
  - `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated, and will be removed in JAX v0.7.0.
JAX v0.6.0
- Breaking changes
  - `jax.numpy.array` no longer accepts `None`. This behavior has been deprecated since November 2023 and is now removed.
  - Removed the `config.jax_data_dependent_tracing_fallback` config option, which was added temporarily in v0.4.36 to allow users to opt out of the new "stackless" tracing machinery.
  - Removed the `config.jax_eager_pmap` config option.
  - Disallow the calling of `lower` and `trace` AOT APIs on the result of `jax.jit` if there have been subsequent wrappers applied. Previously this worked, but silently ignored the wrappers. The workaround is to apply `jax.jit` last among the wrappers, and similarly for `jax.pmap`. See #27873.
  - The `cuda12_pip` extra for `jax` has been removed; use `pip install jax[cuda12]` instead.
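
A sketch of the recommended wrapper order, with `jax.jit` applied outermost so that the AOT `.lower()` call sees every transformation. The `loss` function and array shapes are illustrative.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    return jnp.sum((w * x) ** 2)


xs = jnp.arange(6.0).reshape(3, 2)
w = jnp.ones(2)

# jax.jit goes last (outermost) among the wrappers.
per_example_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

lowered = per_example_grad.lower(w, xs)  # AOT: lower for these argument shapes
compiled = lowered.compile()
grads = compiled(w, xs)  # shape (3, 2): one gradient of loss per row of xs
```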
- Changes
  - The minimum CuDNN version is v9.8.
  - JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain supported.
  - JAX package extras now use dashes instead of underscores, to align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]` to install JAX, run `pip install jax[cuda12-local]` instead.
  - `jax.jit` now requires `fun` to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in a `DeprecationWarning` in v0.6.X, and an error starting in v0.7.X.
- Deprecations
  - `jax.tree_util.build_tree` is deprecated. Use `jax.tree.unflatten` instead.
  - Implemented host callback handlers for CPU and GPU devices using XLA's FFI and removed existing CPU/GPU handlers using XLA's custom call.
  - All APIs in `jax.lib.xla_extension` are now deprecated.
  - `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, which were accidental exports, have been removed. If needed, they are available from `jax.extend.mlir`.
  - `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by `jax.ffi` should be used instead.
  - The deprecated use of `jax.ffi.ffi_call` with inline arguments is no longer supported. `jax.ffi.ffi_call` now unconditionally returns a callable.
  - The following exports in `jax.lib.xla_client` are deprecated: `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`, `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`, `OpSharding`, `Traceback`.
  - The following internal APIs in `jax.util` are deprecated: `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`, `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`, `toposort`, `unzip2`, `wrap_name`, and `wraps`.
  - `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX `Array` directly to the `from_dlpack` function of another framework. If you need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an array.
  - `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0.
  - Several previously-deprecated APIs have been removed, including:
    - From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`, `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, `ops`, `register_custom_call_target`, `shape_from_pyval`, `Shape`, `XlaComputation`.
    - From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`.
    - From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and `jax.tree_unflatten`. Replacements can be found in `jax.tree` or `jax.tree_util`.
    - From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`, `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most have no public replacement, though a few are available at `jax.extend.core`.
    - The `vectorized` argument to `jax.pure_callback` and `jax.ffi.ffi_call`. Use the `vmap_method` parameter instead.
JAX v0.5.3
- New Features
  - Added an `allow_negative_indices` option to `jax.lax.dynamic_slice`, `jax.lax.dynamic_update_slice` and related functions. The default is true, matching the current behavior. If set to false, JAX does not need to emit code clamping negative indices, which improves code size.
  - Added a `replace` option to `jax.random.categorical` to enable sampling without replacement.
JAX v0.5.2
Patch release of 0.5.1
- Bug fixes
  - Fixes TPU metric logging and `tpu-info`, which was broken in 0.5.1.
JAX v0.5.1
- New Features
  - Added an experimental `jax.experimental.custom_dce.custom_dce` decorator to support customizing the behavior of opaque functions under JAX-level dead code elimination (DCE). See #25956 for more details.
  - Added low-level reduction APIs in `jax.lax`: `jax.lax.reduce_sum`, `jax.lax.reduce_prod`, `jax.lax.reduce_max`, `jax.lax.reduce_min`, `jax.lax.reduce_and`, `jax.lax.reduce_or`, and `jax.lax.reduce_xor`.
  - `jax.lax.linalg.qr` and `jax.scipy.linalg.qr` now support column-pivoting on CPU and GPU. See #20282 and #25955 for more details.
- Changes
  - `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as environment variables. Before, they could only be specified via `jax.config` or flags.
  - `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning multi-process CPU communication works out-of-the-box.
  - The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package. This package may safely be removed if it is present on your machine; JAX now uses `libtpu` instead.
- Deprecations
  - The internal function `linear_util.wrap_init` and the constructor `core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For a limited time, a `DeprecationWarning` is printed if `jax.extend.linear_util.wrap_init` is used without debugging info. A downstream effect of this is that several other internal functions need debug info. This change does not affect public APIs. See #26480 for more detail.
- Bug fixes
  - TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer (from around 17s to around 8s). If not already set, you may need to enable transparent hugepages in your VM image (`sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'`). We hope to improve this further in future releases.
  - The persistent compilation cache no longer writes the access time file if `JAX_COMPILATION_CACHE_MAX_SIZE` is unset or set to -1, i.e. if the LRU eviction policy isn't enabled. This should improve performance when using the cache with large-scale network storage.
JAX v0.5.0
As of this release, JAX now uses effort-based versioning.
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.
- Breaking changes
  - Enable `jax_threefry_partitionable` by default (see the update note).
  - This release drops support for Mac x86 wheels. Mac ARM of course remains supported. For a recent discussion, see #22936.

    Two key factors motivated this decision:
    - The Mac x86 build (only) has a number of test failures and crashes. We would rather ship no release than a broken one.
    - Mac x86 hardware is end-of-life and cannot be easily obtained for developers at this point. So it is difficult for us to fix this kind of problem even if we wanted to.

    We are open to re-adding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again.
- Changes:
  - The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025.
  - The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum supported version until June 2025.
  - `jax.numpy.einsum` now defaults to `optimize='auto'` rather than `optimize='optimal'`. This avoids exponentially-scaling trace-time in the case of many arguments (#25214).
  - `jax.numpy.linalg.solve` no longer supports batched 1D arguments on the right hand side. To recover the previous behavior in these cases, use `solve(a, b[..., None]).squeeze(-1)`.
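
A sketch of the suggested `solve` workaround, together with an explicit request for the previous `einsum` default. The batch of 8 systems and the `3.0 * jnp.eye(3)` shift (added only to keep the random matrices well conditioned) are illustrative.

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.key(0))
a = jax.random.normal(k1, (8, 3, 3)) + 3.0 * jnp.eye(3)  # batch of 3x3 systems
b = jax.random.normal(k2, (8, 3))                        # batch of 1D right-hand sides

# Add a trailing axis so each b is a (3, 1) matrix, solve, then drop that axis.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)

# einsum now defaults to optimize='auto'; the old default can still be requested.
c = jnp.einsum("ij,jk,kl->il", a[0], a[1], a[2], optimize="optimal")
```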
- New Features
  - `jax.numpy.fft.fftn`, `jax.numpy.fft.rfftn`, `jax.numpy.fft.ifftn`, and `jax.numpy.fft.irfftn` now support transforms in more than 3 dimensions, which was previously the limit. See #25606 for more details.
  - Support added for user-defined state in the FFI via the new `jax.ffi.register_ffi_type_id` function.
  - The AOT lowering `.as_text()` method now supports the `debug_info` option to include debugging information, e.g., source location, in the output.
- Deprecations
  - From `jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings` are now deprecated, having been replaced by symbols of the same name in `jax.core`.
  - `jax.scipy.special.lpmn` and `jax.scipy.special.lpmn_values` are deprecated, following their deprecation in SciPy v1.15.0. There are no plans to replace these deprecated functions with new APIs.
  - The `jax.extend.ffi` submodule was moved to `jax.ffi`, and the previous import path is deprecated.
- Deletions
  - The `jax_enable_memories` flag has been deleted and the behavior of that flag is on by default.
  - From `jax.lib.xla_client`, the previously-deprecated `Device` and `XlaRuntimeError` symbols have been removed; instead use `jax.Device` and `jax.errors.JaxRuntimeError` respectively.
  - The `jax.experimental.array_api` module has been removed after being deprecated in JAX v0.4.32. Since that release, `jax.numpy` supports the array API directly.
JAX v0.4.38
- Changes:
  - `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added as shortcuts of the corresponding `tree_util` functions.
- Deprecations
  - A number of APIs in the internal `jax.core` namespace have been deprecated. Most were no-ops, were little-used, or can be replaced by APIs of the same name in `jax.extend.core`; see the documentation for `jax.extend` for information on the compatibility guarantees of these semi-public extensions.
  - Several previously-deprecated APIs have been removed, including:
    - from `jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and `non_negative_dim`.
    - from `jax.lib.xla_bridge`: `xla_client` and `default_backend`.
    - from `jax.lib.xla_client`: `_xla` and `bfloat16`.
    - from `jax.numpy`: `round_`.
- New Features
  - `jax.export.export` can be used for device-polymorphic export with shardings constructed with `jax.sharding.AbstractMesh`. See the jax.export documentation.
  - Added `jax.lax.split`. This is a primitive version of `jax.numpy.split`, added because it yields a more compact transpose during automatic differentiation.
JAX v0.4.37
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
- Bug fixes
  - Fixed a bug where `jit` would error if an argument was named `f` (#25329).
  - Fixed a bug that would throw an `index out of range` error in `jax.lax.while_loop` if the user registered a pytree node class with different aux data for `flatten` and `flatten_with_path`.
  - Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.