
nkoppel (Contributor) commented Feb 21, 2023

Adds a TensorContainer trait so that implementers of VisitTensors can specify their own argument types. This is done through the Container associated type, which represents a collection of immutable and mutable references to the same type. Immutable and mutable references are represented by the marker types &'static () and &'static mut (), respectively, but these could be replaced with other types, including unconstructable enums like these:

enum Ref {}
enum RefMut {}

Related to #469 and #460.
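A minimal sketch of the marker-type idea, assuming a generic associated type (here called WithType, a hypothetical name) that maps each marker to a concrete reference type. The trait body, the tuple and Option impls, and the sgd_step helper are illustrations only, not the PR's actual definitions; the marker names TensorRef and TensorMut follow the ones adopted later in this thread. GATs need Rust 1.65 or newer.

```rust
// Uninhabited marker types: no value of these can ever be constructed,
// so they exist only at the type level.
pub enum TensorRef {}
pub enum TensorMut {}

/// Hypothetical container trait: maps a marker (or a collection of markers)
/// to a concrete argument type built from references to `T`.
pub trait TensorContainer {
    type WithType<'a, T: 'a>;
}

impl TensorContainer for TensorRef {
    type WithType<'a, T: 'a> = &'a T;
}

impl TensorContainer for TensorMut {
    type WithType<'a, T: 'a> = &'a mut T;
}

// Collections of markers map element-wise.
impl<A: TensorContainer, B: TensorContainer> TensorContainer for (A, B) {
    type WithType<'a, T: 'a> = (A::WithType<'a, T>, B::WithType<'a, T>);
}

impl<C: TensorContainer> TensorContainer for Option<C> {
    type WithType<'a, T: 'a> = Option<C::WithType<'a, T>>;
}

/// Hypothetical visitor step that reads one tensor and writes another.
/// Its argument type resolves to `(&Vec<f32>, &mut Vec<f32>)`.
pub fn sgd_step<'a>(
    pair: <(TensorRef, TensorMut) as TensorContainer>::WithType<'a, Vec<f32>>,
    lr: f32,
) {
    let (grad, param) = pair;
    for (p, g) in param.iter_mut().zip(grad.iter()) {
        *p -= lr * g;
    }
}
```

Because the marker enums are uninhabited, they carry no runtime cost and cannot be passed around as values by mistake; they only select which reference type a visitor receives.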

coreylowman (Owner) left a comment


This is awesome!!! Going to merge for now, then switch to the unconstructable enums (and do some reorganization in the branch).

path: self.path,
};
Field::iter_tensors(&mut walker)?;
std::mem::drop(walker);
coreylowman (Owner):


Is this necessary?

nkoppel (Contributor, Author):


Yes. If you remove it, Rust considers self.path to still be mutably borrowed on the next line.
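One way this situation can arise, sketched with hypothetical names (Walker, Outer, visit_field) and assuming the real walker has a Drop impl or a field with drop glue: in that case its borrow of the path buffer lasts until the walker is dropped at the end of the scope, not merely until its last use, so an explicit std::mem::drop is needed before the path can be used again.

```rust
/// Hypothetical stand-in for the walker: it mutably borrows a path buffer.
pub struct Walker<'a> {
    path: &'a mut String,
}

// Assumption: the walker has a Drop impl. Dropping it counts as a use of
// the borrow, so the borrow of `path` lives until the walker is dropped.
impl Drop for Walker<'_> {
    fn drop(&mut self) {}
}

/// Hypothetical per-field visit that extends the path with a field name.
fn visit_field(walker: &mut Walker<'_>, name: &str) {
    walker.path.push_str(name);
}

pub struct Outer {
    pub path: String,
}

impl Outer {
    /// Visits one field and returns the length of the full path it saw,
    /// then restores `self.path` to its previous contents.
    pub fn visit(&mut self, name: &str) -> usize {
        let old_len = self.path.len();
        let mut walker = Walker { path: &mut self.path };
        visit_field(&mut walker, name);
        // Without this explicit drop, `walker` would only be dropped at the
        // end of the function, so the uses of `self.path` below would fail
        // to compile with a "cannot borrow as immutable" error.
        std::mem::drop(walker);
        let full_len = self.path.len();
        self.path.truncate(old_len);
        full_len
    }
}
```

If the walker had no drop glue at all, non-lexical lifetimes would end the borrow after its last use and the explicit drop would be unnecessary.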

impl<W: Write + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensors<E, D>
for zip::ZipWriter<W>
{
type Container = &'static ();
coreylowman (Owner):


Yeah, let's go with your suggestion in the PR description about using the unconstructable enums instead of these, purely for readability. I know I'd forget in a couple of months what this means, haha.

type Container = TensorRef;

makes so much sense!

}
}
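A sketch of how an associated Container type lets each visitor declare the access it needs, using marker enums as suggested. TensorVisitor, WithType, Saver, and Zeroer are hypothetical names for illustration; Saver only loosely mirrors a read-only visitor such as the ZipWriter implementation, and the real trait signatures in the PR may differ.

```rust
use std::collections::BTreeMap;

pub enum TensorRef {}
pub enum TensorMut {}

/// Hypothetical marker-to-reference mapping (requires Rust 1.65+ for GATs).
pub trait TensorContainer {
    type WithType<'a, T: 'a>;
}
impl TensorContainer for TensorRef {
    type WithType<'a, T: 'a> = &'a T;
}
impl TensorContainer for TensorMut {
    type WithType<'a, T: 'a> = &'a mut T;
}

/// Hypothetical visitor trait: `Container` states whether the visitor
/// needs shared or exclusive access to each tensor it visits.
pub trait TensorVisitor {
    type Container: TensorContainer;
    fn visit<'a>(
        &mut self,
        name: &str,
        tensor: <Self::Container as TensorContainer>::WithType<'a, Vec<f32>>,
    );
}

/// Read-only visitor: copies each tensor into a map, like saving to disk.
pub struct Saver {
    pub saved: BTreeMap<String, Vec<f32>>,
}

impl TensorVisitor for Saver {
    type Container = TensorRef;
    fn visit<'a>(
        &mut self,
        name: &str,
        tensor: <Self::Container as TensorContainer>::WithType<'a, Vec<f32>>,
    ) {
        // The argument normalizes to `&Vec<f32>` here.
        self.saved.insert(name.to_string(), tensor.clone());
    }
}

/// Mutating visitor: zeroes every tensor it visits.
pub struct Zeroer;

impl TensorVisitor for Zeroer {
    type Container = TensorMut;
    fn visit<'a>(
        &mut self,
        _name: &str,
        tensor: <Self::Container as TensorContainer>::WithType<'a, Vec<f32>>,
    ) {
        // The argument normalizes to `&mut Vec<f32>` here.
        tensor.iter_mut().for_each(|x| *x = 0.0);
    }
}
```

Writing the projected type in each impl (rather than the concrete reference type) keeps the method's lifetime parameters identical to the trait declaration.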

impl<T: TensorContainer> TensorContainer for Option<T> {
coreylowman (Owner):


I could see this being used for optional fields (e.g., bias, if we go the Option route). Nice!

coreylowman merged commit dd7977e into coreylowman:tensor-collection on Feb 22, 2023
coreylowman added a commit that referenced this pull request Feb 22, 2023
* Temp commit

* Using visitors for npz

* Moving missing grads test to optimizers

* Format

* Rename ModuleWalker to TensorVisitor

* Fix missing update of ModuleWalker

* Fixing old docs

* Moves name out of TensorOptions

* Rename TensorOptions helper methods

* move visitors into nn

* Fixing example

* Add TensorContainer trait to allow more argument types for TensorVisitors in #469 (#472)

* Add TensorContainer trait; deduplicate code in visitors/base.rs

* Implement TensorContainer for tuples, Option, and Vec

* run cargo fmt

* Add TensorMut and TensorRef tensor containers

* Renaming visitors -> tensor_collection

* Reorg tensor collection

* Renaming

* Moving name to first arg

* Fixing clippy warnings

* Update src/nn/tensor_collection/visitor.rs

Co-authored-by: nkoppel <nathankoppel0@gmail.com>

* Adding #[non_exhaustive] to TensorOptions

---------

Co-authored-by: nkoppel <nathankoppel0@gmail.com>