
Enable partial stacking at Batch level #100


Merged — 4 commits merged into thu-ml:master from batch_partial_stack on Jun 27, 2020

Conversation

@duburcqa (Collaborator) commented Jun 26, 2020

Allow stacking of inconsistent batches:

In [1]: from tianshou.data import Batch, to_numpy ; import numpy as np ; import torch                                                                                                                       
In [2]: Batch(np.array([{'is_success': True, 'reward': {'done': 0.0001, 'direction': -0.11480981}}, 
   ...:                 {'is_success': False, 'reward': {'done': 0.0001}}], dtype=object))                                                                                                                  
Out[2]: 
Batch(
    is_success: array([ True, False]),
    reward: Batch(
                done: array([0.0001, 0.0001]),
                direction: array([-0.11480981,  0.        ]),
            ),
)

Before this change, the result was unexpected: stacking did not fail, but it produced different results depending on the order of the Batch instances.

As you may notice, I also replaced 'nan' with 0.0 in float arrays, since in practice 0.0 is usually the meaningful and convenient default.

It also comes with a slight performance improvement: #96

(benchmark plot: output_V2)
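For illustration, the zero-filling behaviour shown above can be mimicked with a small pure-Python sketch. This is not the tianshou implementation; stack_partial is a hypothetical helper written only to make the semantics explicit:

```python
# Pure-Python sketch (hypothetical helper, not tianshou's implementation)
# of stacking partially matching records: keys shared by all records are
# stacked as-is, while missing numeric entries are filled with 0.0.
def stack_partial(records):
    # collect keys in first-seen order across all records
    keys = []
    for r in records:
        for k in r:
            if k not in keys:
                keys.append(k)
    out = {}
    for k in keys:
        values = [r.get(k) for r in records]
        if any(isinstance(v, dict) for v in values):
            # recurse into nested records, treating a missing dict as empty
            out[k] = stack_partial(
                [v if isinstance(v, dict) else {} for v in values])
        else:
            out[k] = [0.0 if v is None else v for v in values]
    return out

stacked = stack_partial([
    {'is_success': True, 'reward': {'done': 0.0001, 'direction': -0.11480981}},
    {'is_success': False, 'reward': {'done': 0.0001}},
])
# stacked['reward']['direction'] == [-0.11480981, 0.0]
```

The zero fill only matters for keys missing from some records; shared keys stack exactly as before.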

@duburcqa force-pushed the batch_partial_stack branch 2 times, most recently from 59541f5 to 22e877e on June 26, 2020 13:09
@duburcqa changed the title from "Enable stacking of partially matching Batch instances" to "Enable partial stacking at Batch level" on Jun 26, 2020
@duburcqa force-pushed the batch_partial_stack branch 2 times, most recently from def1889 to 70bd7e0 on June 26, 2020 13:25
@duburcqa force-pushed the batch_partial_stack branch from 70bd7e0 to 13fc202 on June 26, 2020 13:26
@duburcqa (Collaborator, Author) commented Jun 26, 2020

@Trinkle23897 Just so you know, I'm an engineer and, for the first time, thanks to this PR (and all the previous ones), a real-world application is finally running without raising an exception 😆 Perhaps it is time to release 0.2.4 😛

@Trinkle23897 (Collaborator)

Sure, much appreciated~
How about updating the documentation?

@duburcqa (Collaborator, Author) commented Jun 26, 2020

Yes, I'll do it now. Could you please tell me what to update, where, and what exactly you have in mind?

@Trinkle23897 (Collaborator)

@duburcqa (Collaborator, Author)

Yes, I know this, but do you mean only updating this page: https://tianshou.readthedocs.io/en/latest/api/tianshou.data.html ?

@Trinkle23897 (Collaborator) commented Jun 26, 2020

Exactly. I mean that you can add the new feature in https://tianshou.readthedocs.io/en/latest/api/tianshou.data.html#tianshou.data.Batch, alongside:

Batch has other methods, including __getitem__(), __len__(), append(), and split():

and likewise for the buffer.

@Trinkle23897 (Collaborator) commented Jun 26, 2020

You can make the documentation show __xxx__ methods by modifying this line:

autodoc_default_options = {'special-members': '__call__, __getitem__, __len__'}

For example, add __idiv__ to 'special-members'.
Also, the Read the Docs page is only rebuilt when the master branch is updated, so it is better to preview the documentation locally.
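For reference, the suggested change above would look something like this in the Sphinx configuration (the docs/conf.py location is an assumption based on the usual Sphinx layout):

```python
# docs/conf.py (sketch): extend Sphinx autodoc's special-members option
# so that additional dunder methods appear in the generated API docs.
autodoc_default_options = {
    'special-members': '__call__, __getitem__, __len__, __idiv__',
}
```

A local preview can then be built with `make html` in the docs directory before pushing.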

@duburcqa force-pushed the batch_partial_stack branch from b8b2d02 to aa7775b on June 26, 2020 15:20
@duburcqa force-pushed the batch_partial_stack branch 2 times, most recently from 28c20df to 724fb22 on June 26, 2020 15:35
@duburcqa force-pushed the batch_partial_stack branch from 724fb22 to 8c60d74 on June 26, 2020 15:42
@duburcqa (Collaborator, Author) commented Jun 26, 2020

Added the missing documentation. It turns out that I didn't modify the behaviour of Buffer in any way (the changes there were purely internal), so I have not updated its documentation.

@Trinkle23897 merged commit a951a32 into thu-ml:master on Jun 27, 2020
@duburcqa deleted the batch_partial_stack branch on June 27, 2020 09:45
@Trinkle23897 (Collaborator) commented Jun 28, 2020

Could it support slicing such as batch[:, i]?
I ask because the advanced RNN hidden state (#19) needs this; otherwise I have to retrieve the values recursively.

@duburcqa (Collaborator, Author) commented Jun 28, 2020

In practice it shouldn't be very hard, but conceptually Batch is missing the notion of multidimensional samples: only len is available, and it corresponds to the first axis of the stored data. The first step would be to add a shape property, similar to numpy. After that, it shouldn't be much trouble. Maybe you could open an issue.

The catch is that list cannot be supported, since computing the shape of a nested list requires traversing it entirely, which is very expensive and gets worse with the nesting depth.

@duburcqa (Collaborator, Author) commented Jun 28, 2020

Adding this property to the Batch class would do the trick (note that an isinstance check with an early return of [] would be needed to handle native Python scalars):

    @property
    def shape(self) -> List[int]:
        """Return the shape of the stored data, similar to numpy."""
        if len(self.__dict__) == 0:
            return []
        data_shape = []
        for v in self.__dict__.values():
            try:
                data_shape.append(list(v.shape))
            except AttributeError:
                raise TypeError("No support for 'shape' method with "
                                f"type {type(v)} in class Batch.")
        return min(data_shape) if len(data_shape) > 1 else data_shape[0]

Then, _valid_bounds(len(self), index) must be replaced by np.all([_valid_bounds(s, i) for s, i in zip(self.shape, index)]), but ONLY in the case where index is actually a list of lists of indices. No need to change anything else, since the rest of the method is already compatible with multidimensional indexing (except for list, but list is not supported by shape, so an exception will be raised before reaching this point).
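To illustrate the kind of multidimensional indexing being discussed, here is a minimal sketch; MiniBatch is a hypothetical stand-in for Batch, not tianshou code. A tuple index such as [:, 1] is simply forwarded to every stored numpy array:

```python
import numpy as np

# Hypothetical minimal stand-in for Batch: forwards any index
# (including tuple indices such as [:, 1]) to every stored numpy array.
class MiniBatch:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __getitem__(self, index):
        result = MiniBatch()
        for k, v in self.__dict__.items():
            result.__dict__[k] = v[index]  # numpy handles tuple indices
        return result

b = MiniBatch(obs=np.arange(6).reshape(2, 3), act=np.zeros((2, 3)))
col = b[:, 1]  # selects the second column of every stored array
```

Since numpy already supports tuple indexing natively, the per-key forwarding is the only new machinery required; bounds validation against shape is the part discussed above.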

NB: You can ping me on the issue/PR if you want a review.

@duburcqa (Collaborator, Author) commented Jun 28, 2020

By the way, it may be good to overload direct assignment to make sure that every list is converted to a numpy array. Indeed, lists are a recurring source of issues for slicing, stacking and concatenation, and they force many isinstance checks (if ... else ...) for no real advantage.
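A minimal sketch of such an overload, assuming __setattr__ is the hook; AutoConvert is illustrative and not part of tianshou:

```python
import numpy as np

# Illustrative sketch: convert lists to numpy arrays at assignment time,
# so downstream slicing/stacking code never has to special-case list.
class AutoConvert:
    def __setattr__(self, key, value):
        if isinstance(value, list):
            value = np.asarray(value)
        super().__setattr__(key, value)

b = AutoConvert()
b.rew = [1.0, 2.0, 3.0]
# b.rew is now a numpy array, so b.rew[1:] etc. work uniformly
```

Converting eagerly at assignment trades a one-time copy for uniform array semantics everywhere else.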

This PR was linked to issues on Jun 29, 2020
BFAnas pushed a commit to BFAnas/tianshou that referenced this pull request May 5, 2024
* Enable stacking of partially matching Batch instances.

* Fix list support for getitem.

* Fix Batch 'size' method.

* Update Batch documentation.
Successfully merging this pull request may close these issues: "Batch & Buffer profiling" and "Batch to Numpy compatibility".