
Batch to Numpy compatibility #90

@duburcqa

Description


The current implementation of `__len__` and `__getitem__` is odd because it is infinitely recursive: `Batch()[0]` does not raise an exception but instead returns the same object, `Batch()`. This breaks compatibility with standard NumPy functions such as `np.mean` and `np.max`, which should otherwise work. Changing this behavior would be easy, but many policy algorithms rely on it. Calling such a NumPy function, for example `np.asarray(Batch())`, simply hangs forever, and so does `[o for o in Batch(a=[1.0])]`, which is even worse. The likely cause is that indexing never raises `IndexError`, so any iteration that falls back on `__getitem__` never terminates.
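A minimal reproduction of the problem, using a hypothetical `RecursiveBatch` class that only mimics the behavior described above (it is not the real `Batch`). Because the class defines `__getitem__` but no `__iter__`, Python iterates by calling `__getitem__(0)`, `__getitem__(1)`, ... until an `IndexError` is raised, which never happens here. `itertools.islice` is used only to keep the demo from hanging.

```python
from itertools import islice


class RecursiveBatch:
    """Hypothetical stand-in that mimics the current behavior: indexing returns self."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __getitem__(self, index):
        # Never raises IndexError, whatever the index.
        return self


b = RecursiveBatch(a=[1.0])
print(b[0] is b)  # True: indexing hands back the very same object

# Without islice, list(b) or [o for o in b] would loop forever.
first_items = list(islice(b, 3))
print(len(first_items), all(item is b for item in first_items))  # 3 True
```

Note that the fallback iteration protocol does not consult `__len__` at all, so fixing `__len__` alone would not stop the loop; `__getitem__` itself has to raise `IndexError` past the end.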

I opened a PR implementing the desired behavior, but I need help updating the policy algorithms accordingly.

Here is what you get with the proposed implementation.

Input:

```python
np.mean(Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=np.array([5.0, 6.0])))
```

Output:

```
Batch(
    a: array([2., 3.]),
    b: 5.5,
)
```
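The output above is consistent with treating the first axis of every stored array as the batch axis: `a` is averaged row-wise to `[2., 3.]` and `b` to `5.5`. Below is a minimal sketch of those semantics under that assumption. `SketchBatch` and `batch_mean` are hypothetical names, not the actual `Batch` API, and the PR may implement this differently (for example through NumPy dispatch rather than a helper function). The key point is that `__len__` reports the batch size and `__getitem__` raises `IndexError` past the end, so iteration terminates.

```python
import numpy as np


class SketchBatch:
    """Hypothetical minimal Batch: every key stores an array whose axis 0 is the batch axis."""

    def __init__(self, **kwargs):
        self.data = {key: np.asarray(value) for key, value in kwargs.items()}

    def __len__(self):
        # Size of the batch axis; assumes every key shares the same first dimension.
        if not self.data:
            return 0
        return len(next(iter(self.data.values())))

    def __getitem__(self, index):
        # Integer indices only in this sketch; slices are not handled.
        n = len(self)
        if index < 0:
            index += n  # support negative indices like a list
        if not 0 <= index < n:
            # Raising IndexError is what lets Python's iteration stop.
            raise IndexError("batch index out of range")
        return SketchBatch(**{key: value[index] for key, value in self.data.items()})


def batch_mean(batch):
    """Hypothetical helper: average each key over the batch axis (axis 0)."""
    return {key: value.mean(axis=0) for key, value in batch.data.items()}


batch = SketchBatch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=np.array([5.0, 6.0]))
print(len(batch))                     # 2
print(len([item for item in batch]))  # 2: iteration stops at the IndexError
means = batch_mean(batch)
print(means["a"].tolist(), float(means["b"]))  # [2.0, 3.0] 5.5
```

With an empty `SketchBatch()`, `len` is 0 and `SketchBatch()[0]` raises `IndexError` instead of returning itself, which is exactly the behavior change the issue asks for.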

Labels

enhancement (Feature that is not a new algorithm or an algorithm enhancement)