Description
The current implementation of `__len__` and `__getitem__` is problematic because indexing is infinitely recursive: `Batch()[0]` does not raise an exception but returns the same object, `Batch()`. This breaks compatibility with standard numpy functions such as `np.mean` and `np.max`, which should otherwise work. Changing this behavior would be easy, but many policy algorithms rely on it. Calling such a numpy function, for example `np.asarray(Batch())`, simply hangs forever, and so does `[o for o in Batch(a=[1.0])]`, which is even worse.
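The hang during iteration is consistent with Python's fallback iteration protocol: when a class defines `__getitem__` but not `__iter__`, `iter()` calls `__getitem__` with 0, 1, 2, ... until an `IndexError` is raised, and it does not consult `__len__`. The sketch below illustrates this with two hypothetical stand-in classes (`EndlessContainer` and `BoundedContainer` are not part of the library and do not reproduce the real `Batch` code); it only assumes that indexing never raises, as described above.

```python
from itertools import islice


class EndlessContainer:
    """Hypothetical stand-in: __getitem__ never raises IndexError."""

    def __len__(self):
        return 0

    def __getitem__(self, index):
        # Mirrors the behavior described above: indexing returns the object itself.
        return self


class BoundedContainer:
    """Hypothetical stand-in: out-of-range indexing raises IndexError."""

    def __init__(self, values):
        self.values = list(values)

    def __len__(self):
        return len(self.values)

    def __getitem__(self, index):
        # List indexing raises IndexError past the end, which ends iteration.
        return self.values[index]


endless = EndlessContainer()
# Without islice this iteration would never finish, so cap it at 5 items.
first_items = list(islice(endless, 5))
bounded_items = [o for o in BoundedContainer([1.0, 2.0])]
```

In this sketch, raising `IndexError` for out-of-range integer indices is what lets a plain `for` loop or list comprehension terminate.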
I opened a PR implementing the desired behavior, but I need help updating the policy algorithms accordingly.
Here is what you get using the proposed implementation.
Input:
np.mean(Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=np.array([5.0, 6.0])))
Output:
Batch(
    a: array([2., 3.]),
    b: 5.5,
)