Add TensorContainer trait to allow more argument types for TensorVisitors in #469 #472
This is awesome!!! Going to merge for now and change to the unconstructable enums (& do some reorganization in the branch)
    path: self.path,
};
Field::iter_tensors(&mut walker)?;
std::mem::drop(walker);
Is this necessary?
Yes. If you remove it, Rust will consider self.path to still be mutably borrowed (through walker) on the next line.
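Here is a minimal sketch of why the explicit drop can matter. It uses a hypothetical Walker type, not the PR's actual walker, that holds a mutable borrow of a path stack. The empty Drop impl is an assumption. It stands in for whatever drop glue the real walker's type carries.

When a value's type may use a borrow during drop, the borrow checker keeps that borrow alive until the value is dropped at the end of scope. Calling std::mem::drop ends it early.

```rust
/// Hypothetical stand-in for the walker above: it holds a mutable
/// borrow of the caller's path stack (like `path: self.path`).
pub struct Walker<'a> {
    path: &'a mut Vec<String>,
}

// Assumption: the real walker's type has drop glue that may touch the
// borrow. An explicit (empty) Drop impl reproduces that here, so the
// implicit drop at end of scope counts as a use of `path`.
impl Drop for Walker<'_> {
    fn drop(&mut self) {}
}

impl Walker<'_> {
    /// Records that `name` was visited by pushing it onto the path.
    fn visit(&mut self, name: &str) {
        self.path.push(name.to_string());
    }
}

/// Visits one field, returns the dotted path that was visited, then
/// restores `path` to its previous state.
pub fn visit_field(path: &mut Vec<String>, field: &str) -> String {
    let mut walker = Walker { path: &mut *path };
    walker.visit(field);
    // End the mutable borrow held by `walker` right here. Without this
    // line, the borrow checker rejects the uses of `path` below, because
    // `walker` would only be dropped at the end of the function.
    std::mem::drop(walker);
    let visited = path.join(".");
    path.pop();
    visited
}
```

For example, with path set to ["linear"], visit_field(&mut path, "weight") returns "linear.weight" and leaves path as ["linear"].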
impl<W: Write + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensors<E, D>
    for zip::ZipWriter<W>
{
    type Container = &'static ();
Yeah, let's go with your suggestion in the PR description about using the unconstructable enums instead of these, purely for readability. I know I'd forget in a couple of months what this means, haha. type Container = TensorRef; makes so much sense!
    }
}

impl<T: TensorContainer> TensorContainer for Option<T> {
I could see this being used for optional fields (e.g. bias if we go the Option route), nice!
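Here is a simplified sketch of how an Option impl composes with the base containers. It uses &'static () and &'static mut () as the tags the PR describes. The trait name mirrors the PR. The WithType associated type and the bump_bias helper are hypothetical, not dfdx's actual API.

The idea is that Option<&'static mut ()> maps a value type M to Option<&mut M>. That fits an optional field such as a bias.

```rust
/// Hypothetical, simplified version of the TensorContainer idea: a
/// type-level tag describing how a visitor receives a value of type `M`.
pub trait TensorContainer {
    type WithType<'a, M: 'a>;
}

// Tag for shared access, as in the PR description: `&'static ()`.
impl TensorContainer for &'static () {
    type WithType<'a, M: 'a> = &'a M;
}

// Tag for mutable access: `&'static mut ()`.
impl TensorContainer for &'static mut () {
    type WithType<'a, M: 'a> = &'a mut M;
}

// Optional container: `None` means "this field is absent".
impl<T: TensorContainer> TensorContainer for Option<T> {
    type WithType<'a, M: 'a> = Option<T::WithType<'a, M>>;
}

/// Hypothetical visitor step for an optional bias. The argument type
/// normalizes to `Option<&'a mut Vec<f32>>`. Adds `v` to every element
/// if the bias exists and returns how many elements were updated.
pub fn bump_bias<'a>(
    bias: <Option<&'static mut ()> as TensorContainer>::WithType<'a, Vec<f32>>,
    v: f32,
) -> usize {
    match bias {
        Some(b) => {
            for x in b.iter_mut() {
                *x += v;
            }
            b.len()
        }
        None => 0,
    }
}
```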
* Temp commit
* Using visitors for npz
* Moving missing grads test to optimizers
* Format
* Rename ModuleWalker to TensorVisitor
* Fix missing update of ModuleWalker
* Fixing old docs
* Moves name out of TensorOptions
* Rename TensorOptions helper methods
* move visitors into nn
* Fixing example
* Add TensorContainer trait to allow more argument types for TensorVisitors in #469 (#472)
* Add TensorContainer trait; deduplicate code in visitors/base.rs
* Implement TensorContainer for tuples, Option, and Vec
* run cargo fmt
* Add TensorMut and TensorRef tensor containers
* Renaming visitors -> tensor_collection
* Reorg tensor collection
* Renaming
* Moving name to first arg
* Fixing clippy warnings
* Update src/nn/tensor_collection/visitor.rs

  Co-authored-by: nkoppel <nathankoppel0@gmail.com>
* Adding #[non_exhaustive] to TensorOptions

---------

Co-authored-by: nkoppel <nathankoppel0@gmail.com>
Adds TensorContainer to allow implementers of VisitTensors to specify their own argument types. This is done with the Container associated type, which represents some collection of immutable and mutable references to the same type. References to a particular type are represented with &'static () and &'static mut (), but these could be substituted for other types, including unconstructable enums, like these:

Related to #469 and #460.
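Here is a sketch of the unconstructable-enum alternative, in which enums with no variants serve purely as type-level tags. The names TensorRef and TensorMut come from the review thread. Everything else is hypothetical and simplified, not the PR's actual code: WithType, from_mut, VisitTensor, Sum, Scale and walk.

Each visitor names its argument type through its Container associated type. A generic walk function then builds that argument from a mutable borrow.

```rust
// Enums with no variants can never be instantiated, so they exist only
// at the type level. Names taken from the review thread.
pub enum TensorRef {}
pub enum TensorMut {}

/// Hypothetical, simplified container trait.
pub trait TensorContainer {
    /// How a visitor receives a value of type `M`.
    type WithType<'a, M: 'a>;
    /// Builds that argument from a mutable borrow.
    fn from_mut<'a, M: 'a>(m: &'a mut M) -> Self::WithType<'a, M>;
}

impl TensorContainer for TensorRef {
    type WithType<'a, M: 'a> = &'a M;
    fn from_mut<'a, M: 'a>(m: &'a mut M) -> &'a M {
        m // downgrade the mutable borrow to a shared one
    }
}

impl TensorContainer for TensorMut {
    type WithType<'a, M: 'a> = &'a mut M;
    fn from_mut<'a, M: 'a>(m: &'a mut M) -> &'a mut M {
        m
    }
}

/// Hypothetical visitor: each implementer picks its argument type by
/// naming a marker enum instead of `&'static ()` / `&'static mut ()`.
pub trait VisitTensor {
    type Container: TensorContainer;
    fn visit<'a>(&mut self, t: <Self::Container as TensorContainer>::WithType<'a, Vec<f32>>);
}

/// Read-only visitor: sums every element it sees.
pub struct Sum(pub f32);
impl VisitTensor for Sum {
    type Container = TensorRef;
    fn visit<'a>(&mut self, t: &'a Vec<f32>) {
        self.0 += t.iter().sum::<f32>();
    }
}

/// Mutating visitor: scales every element in place.
pub struct Scale(pub f32);
impl VisitTensor for Scale {
    type Container = TensorMut;
    fn visit<'a>(&mut self, t: &'a mut Vec<f32>) {
        for x in t.iter_mut() {
            *x *= self.0;
        }
    }
}

/// Hands one tensor to any visitor in the form that visitor asked for.
pub fn walk<V: VisitTensor>(v: &mut V, tensor: &mut Vec<f32>) {
    v.visit(<V::Container as TensorContainer>::from_mut(tensor));
}
```

With this setup, walk(&mut Sum(0.0), &mut t) only reads t, while walk(&mut Scale(2.0), &mut t) modifies it in place. The choice is made entirely by each visitor's Container type, which reads more clearly than &'static () versus &'static mut ().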