
[FEATURE]: Implement support for Python Array API Standard. #1244

@imh

Description

JAX includes a NumPy-compatible jax.numpy module with several nice features (automatic differentiation, JIT compilation, vectorized mapping, GPU execution, JS export). The JAX developers have taken great pains to make switching usually as simple as replacing import numpy as np with import jax.numpy as np. The same is true, though less extensively, for the jax.scipy module.

I'd like to do some optimization for which it would be really convenient to automatically differentiate some of the great functionality you've implemented and export it to JS. Ideally, this would be as simple as switching the import numpy as np statements throughout the library:

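    # Placeholder: the configuration mechanism is still open for discussion.
    HOWEVER_WE_SET_THE_CONFIG = False

    if HOWEVER_WE_SET_THE_CONFIG:
        import jax.numpy as np
    else:
        import numpy as np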
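The issue title refers to the Python Array API Standard, which offers an alternative to a global import switch: each function asks its input arrays for their namespace through the standard's __array_namespace__ method and then calls only functions the standard defines. The sketch below assumes that approach; get_namespace and softplus are hypothetical names chosen for illustration, not existing functions of this library. The array-api-compat package provides a more complete array_namespace helper if a real implementation is wanted.

```python
from typing import Any


def get_namespace(*arrays: Any) -> Any:
    """Return the single Array API namespace shared by all inputs.

    Hypothetical helper: each array is asked for its namespace via the
    standard __array_namespace__ method, and mixing namespaces
    (e.g. a NumPy array with a JAX array) is rejected.
    """
    if not arrays:
        raise ValueError("at least one array is required")
    xp = arrays[0].__array_namespace__()
    for a in arrays[1:]:
        if a.__array_namespace__() is not xp:
            raise TypeError("inputs belong to different array namespaces")
    return xp


def softplus(x: Any) -> Any:
    """Hypothetical library function written once against the namespace."""
    xp = get_namespace(x)
    # Only exp and log, both defined by the Array API standard, are used,
    # so the same code can run on any compliant backend.
    return xp.log(1 + xp.exp(x))
```

With recent NumPy (2.x) releases, calling softplus on a NumPy array would resolve xp to numpy, and the same call on a JAX array would resolve it to JAX's namespace, without any configuration flag. Whether JAX tracers inside jax.grad or jax.jit expose __array_namespace__ depends on the JAX version and would need to be checked.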

Changing the type signatures leaves more design choices open, but the approach is basically the same.
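For the type signatures, one option that avoids naming either numpy.ndarray or jax.Array is a structural typing.Protocol keyed on the same __array_namespace__ method. ArrayAPIArray and scale are hypothetical names for illustration, and the protocol declares only the two methods this example needs; a real one would list more operators.

```python
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class ArrayAPIArray(Protocol):
    """Structural type: any object exposing __array_namespace__ and __mul__.

    Hypothetical type for illustration; the library would not need to
    import NumPy or JAX to use it in signatures.
    """

    def __array_namespace__(self, /, *, api_version: Optional[str] = None) -> Any: ...

    def __mul__(self, other: Any) -> Any: ...


def scale(x: ArrayAPIArray, factor: float) -> ArrayAPIArray:
    """Hypothetical function whose signature no longer names numpy.ndarray."""
    return x * factor
```

Static type checkers would then accept any object providing these methods. The runtime_checkable decorator also allows isinstance checks, although those only test that the methods exist, not their signatures.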

I'd be happy to implement it, but I don't want to open a PR you aren't interested in.

I expect that the added maintenance burden would be pretty minimal.
