Description
JAX includes a NumPy-compatible `jax.numpy` module with a bunch of nice features (automatic differentiation, JIT compilation, vectorized mapping, GPU runtime, JS export). They've taken great pains to make sure switching is usually as simple as swapping `import numpy as np` for `import jax.numpy as np`. Likewise (but less extensively) for the `jax.scipy` module.
I'd like to do some optimization for which it would be really convenient to automatically differentiate some of the great stuff you've implemented and export it to JS. It should be as simple as changing `import numpy as np` throughout the library to something like:
if HOWEVER_WE_SET_THE_CONFIG:
    import jax.numpy as np
else:
    import numpy as np
Changing the type signatures probably leaves more design choices open, but the idea is basically the same.
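For the annotations, one possible approach (a sketch, not a settled design) is a single alias that accepts both array types when JAX is importable. The alias name `Array` and the `scale` function are hypothetical; `jax.Array` is the public array type in recent JAX releases, and `numpy.typing.ArrayLike` is another option for annotating inputs.

```python
from typing import Union

import numpy

try:
    import jax

    # Accept both NumPy arrays and JAX arrays in annotations.
    Array = Union[numpy.ndarray, jax.Array]
except (ImportError, AttributeError):
    # JAX missing (or too old to expose jax.Array): annotate with NumPy only.
    Array = numpy.ndarray


def scale(x: Array, factor: float) -> Array:
    """Hypothetical example: multiply every element of x by factor."""
    return x * factor
```

Annotations are not checked at runtime, so this only affects type checkers and documentation, not which backend actually runs.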
I'd be happy to implement it, but don't want to make a PR that you don't want.
I expect that the added maintenance burden would be pretty minimal.