JAX v0.7.0

@MichaelHudgins released this on 22 Jul 20:33
  • New features:

    • Added jax.P, an alias for jax.sharding.PartitionSpec.
    • Added jax.tree.reduce_associative.
  • Breaking changes:

    • JAX is migrating from GSPMD to Shardy by default. See the
      migration guide
      for more information.
    • JAX autodiff is switching to direct linearization by default (instead of
      implementing linearization via JVP and partial eval).
      See the migration guide
      for more information.
    • jax.stages.OutInfo has been replaced with jax.ShapeDtypeStruct.
    • jax.jit now requires fun to be passed by position and additional
      arguments to be passed by keyword. Starting in v0.7.x, doing otherwise
      raises an error; in v0.6.x it raised a DeprecationWarning.
    • The minimum Python version is now 3.11. 3.11 will remain the minimum
      supported version until July 2026.
    • Layout API renames:
      • Layout, .layout, .input_layouts, and .output_layouts have been
        renamed to Format, .format, .input_formats, and .output_formats.
      • DeviceLocalLayout and .device_local_layout have been renamed to Layout
        and .layout.
    • The jax.experimental.shard module has been deleted, and all of its APIs
      have moved to the jax.sharding endpoint. Use jax.sharding.reshard,
      jax.sharding.auto_axes, and jax.sharding.explicit_axes instead of their
      former experimental endpoints.
    • lax.infeed and lax.outfeed were removed after being deprecated in
      JAX 0.6. The transfer_to_infeed and transfer_from_outfeed methods were
      also removed from Device objects.
    • The jax.extend.core.primitives.pjit_p primitive has been renamed to
      jit_p, and its name attribute has changed from "pjit" to "jit".
      This affects the string representations of jaxprs. The same primitive is no
      longer exported from the jax.experimental.pjit module.
    • The (undocumented) function jax.extend.backend.add_clear_backends_callback
      has been removed. Users should use jax.extend.backend.register_backend_cache
      instead.
  • Deprecations:

    • jax.dlpack.SUPPORTED_DTYPES is deprecated; please use the new
      jax.dlpack.is_supported_dtype function.
    • jax.scipy.special.sph_harm has been deprecated following a similar
      deprecation in SciPy; use jax.scipy.special.sph_harm_y instead.
    • From jax.interpreters.xla, the previously deprecated symbols
      abstractify and pytype_aval_mappings have been removed.
    • jax.interpreters.xla.canonicalize_dtype is deprecated. For
      canonicalizing dtypes, prefer jax.dtypes.canonicalize_dtype.
      For checking whether an object is a valid jax input, prefer
      jax.core.valid_jaxtype.
    • From jax.core, the previously deprecated symbols AxisName,
      ConcretizationTypeError, axis_frame, call_p, closed_call_p,
      get_type, trace_state_clean, typematch, and typecheck have been
      removed.
    • From jax.lib.xla_client, the previously deprecated symbols
      DeviceAssignment, get_topology_for_devices, and mlir_api_version
      have been removed.
    • jax.extend.ffi was removed after being deprecated in v0.5.0.
      Use jax.ffi instead.
    • jax.lib.xla_bridge.get_compile_options is deprecated, and replaced by
      jax.extend.backend.get_compile_options.