[WIP] Add VisitTensors traits and simplify nn internals. #460
What's your sense of whether this will actually help internals? I see that we've removed quite a bit of code just with this draft, and can probably remove more.
If we could figure out how to use this to implement EMA, I think that'd be a key advantage, but I haven't been able to figure out how to do "zipped" visiting.
I think this PR will give us a lot of flexibility in implementing new modules and new functionality for modules. It will also make it much easier for users who implement custom modules to access features that would otherwise require a lot of boilerplate. Also, I don't think implementing VisitTensorGroups is very complicated: it pretty much amounts to specifying how to access each field and defining each field's name. TensorVisitor is also pretty simple to implement, as it amounts to getting the tensors you need from
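To make the division of labour described above concrete, here is a minimal sketch. It borrows the names TensorVisitor and VisitTensors from this PR but is not the PR's API: the N, M, E, D generics are dropped, `Tensor` is a plain `Vec<f32>`, and `Linear`, `CountParams`, and `count_params` are hypothetical stand-ins. It only shows the idea that a module lists its fields and their names, while a visitor decides what to do with each tensor.

```rust
// Toy stand-in for a tensor; NOT the library's tensor type.
type Tensor = Vec<f32>;

/// A visitor receives each tensor together with its dotted path.
trait TensorVisitor {
    fn visit(&mut self, name: &str, tensor: &Tensor);
}

/// A module only states how to reach each field and what the field is called.
trait VisitTensors {
    fn visit_tensors<V: TensorVisitor>(&self, prefix: &str, visitor: &mut V);
}

/// Hypothetical module with two tensor fields.
struct Linear {
    weight: Tensor,
    bias: Tensor,
}

impl VisitTensors for Linear {
    fn visit_tensors<V: TensorVisitor>(&self, prefix: &str, visitor: &mut V) {
        visitor.visit(&format!("{}weight", prefix), &self.weight);
        visitor.visit(&format!("{}bias", prefix), &self.bias);
    }
}

/// Pairs forward to their elements, extending the path with the element index.
impl<A: VisitTensors, B: VisitTensors> VisitTensors for (A, B) {
    fn visit_tensors<V: TensorVisitor>(&self, prefix: &str, visitor: &mut V) {
        self.0.visit_tensors(&format!("{}0.", prefix), visitor);
        self.1.visit_tensors(&format!("{}1.", prefix), visitor);
    }
}

/// One visitor gives every module a feature without per-module boilerplate:
/// here, counting parameters and recording their names.
#[derive(Default)]
struct CountParams {
    total: usize,
    names: Vec<String>,
}

impl TensorVisitor for CountParams {
    fn visit(&mut self, name: &str, tensor: &Tensor) {
        self.total += tensor.len();
        self.names.push(name.to_string());
    }
}

fn count_params<M: VisitTensors>(module: &M) -> CountParams {
    let mut counter = CountParams::default();
    module.visit_tensors("", &mut counter);
    counter
}
```

Calling `count_params` on a pair of `Linear` values walks both modules in field order, so a custom module that implements the one trait gets counting (or any other visitor) without further code.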
src/nn/add_into.rs
Outdated
fn visit_groups<F: TensorVisitor<N, M, E, D>>(
    mut self_refs: ModuleGroup<N, M, Self>,
    func: &mut F,
) -> Result<(), F::Err> {
    self_refs.map(|s| &s.0, |s| &mut s.0, "0.").visit(func)
}
Thoughts on including the visitor in the ModuleGroup object and then calling it inside .map?
If we can move the E/D generics to the call method, this shouldn't be an issue.
Something like:
- fn visit_groups<F: TensorVisitor<N, M, E, D>>(
-     mut self_refs: ModuleGroup<N, M, Self>,
-     func: &mut F,
- ) -> Result<(), F::Err> {
-     self_refs.map(|s| &s.0, |s| &mut s.0, "0.").visit(func)
- }
+ fn visit_groups<F: TensorVisitor<N, M>>(
+     &mut self_refs: ModuleGroup<N, M, Self, F>,
+ ) -> Result<(), F::Err> {
+     self_refs.map(|s| &s.0, |s| &mut s.0, "0.")
+ }
We could even go a step further and make ModuleGroup itself a TensorVisitor (wrapped around another TensorVisitor):
- fn visit_groups<F: TensorVisitor<N, M, E, D>>(
-     mut self_refs: ModuleGroup<N, M, Self>,
-     func: &mut F,
- ) -> Result<(), F::Err> {
-     self_refs.map(|s| &s.0, |s| &mut s.0, "0.").visit(func)
- }
+ fn visit_groups<V: TensorVisitor<N, M>>(
+     &mut visitor: V,
+ ) -> Result<(), F::Err> {
+     visitor.map(|s| &s.0, |s| &mut s.0, "0.")
+ }
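One possible reading of the "visitor wrapped around another visitor" idea, sketched with toy types. Everything here is an assumption for illustration: `Scoped`, `Bias`, `Names`, `names_of`, and this simplified `AddInto` are hypothetical, `Tensor` is a `Vec<f32>`, and the real traits carry more generics. The wrapper is itself a visitor; it prepends a path segment and forwards to the visitor it wraps, so a module only ever receives a single visitor argument.

```rust
// Toy stand-in for a tensor; NOT the library's tensor type.
type Tensor = Vec<f32>;

trait TensorVisitor {
    fn visit(&mut self, name: &str, tensor: &Tensor);
}

/// Hypothetical wrapper: a visitor that prefixes names and delegates
/// to the visitor it borrows.
struct Scoped<'a, V: TensorVisitor> {
    inner: &'a mut V,
    prefix: &'static str,
}

impl<'a, V: TensorVisitor> TensorVisitor for Scoped<'a, V> {
    fn visit(&mut self, name: &str, tensor: &Tensor) {
        self.inner.visit(&format!("{}{}", self.prefix, name), tensor);
    }
}

trait VisitTensors {
    fn visit_tensors<V: TensorVisitor>(&self, visitor: &mut V);
}

/// Hypothetical leaf module holding a single tensor.
struct Bias(Tensor);

impl VisitTensors for Bias {
    fn visit_tensors<V: TensorVisitor>(&self, visitor: &mut V) {
        visitor.visit("bias", &self.0);
    }
}

/// Simplified single-field wrapper module, in the spirit of src/nn/add_into.rs.
struct AddInto<T>(T);

impl<T: VisitTensors> VisitTensors for AddInto<T> {
    fn visit_tensors<V: TensorVisitor>(&self, visitor: &mut V) {
        // Descend into field `0` by wrapping the incoming visitor.
        let mut scoped = Scoped { inner: visitor, prefix: "0." };
        self.0.visit_tensors(&mut scoped);
    }
}

/// Visitor that records the full path of every tensor it sees.
struct Names(Vec<String>);

impl TensorVisitor for Names {
    fn visit(&mut self, name: &str, _tensor: &Tensor) {
        self.0.push(name.to_string());
    }
}

fn names_of<M: VisitTensors>(module: &M) -> Vec<String> {
    let mut names = Names(Vec::new());
    module.visit_tensors(&mut names);
    names.0
}
```

In this sketch a visitor that is not wrapped behaves exactly like a wrapped one with an empty prefix, which is one way to treat top-level visitors uniformly.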
I've implemented your first suggestion and have renamed some things to make the new functionality of ModuleGroups make more sense. I haven't implemented your second suggestion because I don't really think I understand what you're getting at. What would the call implementation of ModuleGroup look like? How would we treat TensorVisitors not wrapped in a ModuleGroup?
src/nn/visit_tensors.rs
Outdated
    }
}

pub trait VisitTensors<E: Dtype, D: DeviceStorage>: VisitTensorGroups<1, 0, E, D> + Debug {
Are 1, 0 and 0, 1 (for VisitTensorsRef and VisitTensorsMut) the only two cases we would ever support?
I think we might be able to get rid of the VisitTensorGroups trait if we can separately impl some trait (not sure which one) for &T and &mut T.
We will eventually support 1, 1 so we can add and multiply modules for EMA.
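As a rough illustration of what a 1, 1 group could enable, the sketch below visits one immutable and one mutable module of the same type in lockstep and applies an exponential moving average, shadow = decay * shadow + (1 - decay) * param. It assumes that the two numbers count immutable and mutable references respectively. `ZipVisitor`, `ZipVisit`, `Linear`, `Ema`, and `ema_update` are hypothetical names, `Tensor` is a `Vec<f32>`, and none of this is the PR's actual implementation.

```rust
// Toy stand-in for a tensor; NOT the library's tensor type.
type Tensor = Vec<f32>;

/// Hypothetical visitor over pairs of corresponding tensors:
/// one read-only, its counterpart mutable.
trait ZipVisitor {
    fn visit(&mut self, name: &str, src: &Tensor, dst: &mut Tensor);
}

/// A module states how to reach the same field in both references.
trait ZipVisit {
    fn zip_visit<V: ZipVisitor>(src: &Self, dst: &mut Self, prefix: &str, visitor: &mut V);
}

/// Hypothetical module with two tensor fields.
struct Linear {
    weight: Tensor,
    bias: Tensor,
}

impl ZipVisit for Linear {
    fn zip_visit<V: ZipVisitor>(src: &Self, dst: &mut Self, prefix: &str, visitor: &mut V) {
        visitor.visit(&format!("{}weight", prefix), &src.weight, &mut dst.weight);
        visitor.visit(&format!("{}bias", prefix), &src.bias, &mut dst.bias);
    }
}

/// EMA as a visitor: dst = decay * dst + (1 - decay) * src, element-wise.
struct Ema {
    decay: f32,
}

impl ZipVisitor for Ema {
    fn visit(&mut self, _name: &str, src: &Tensor, dst: &mut Tensor) {
        for (d, s) in dst.iter_mut().zip(src.iter()) {
            *d = self.decay * *d + (1.0 - self.decay) * *s;
        }
    }
}

/// Update `shadow` in place from `model`, tensor by tensor.
fn ema_update<M: ZipVisit>(model: &M, shadow: &mut M, decay: f32) {
    M::zip_visit(model, shadow, "", &mut Ema { decay });
}
```

Because both references have the same module type, corresponding fields always line up, which is what makes the "zipped" traversal safe without runtime shape or name matching between the two modules.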
@nkoppel I figured out how to do it without the arrays; will open another PR soon.
Closing since the other PR was merged. Thanks for the great work on this @nkoppel, the get_refs/get_muts is super clean/clever! 🚀
Adds a framework for running functions on all sets of corresponding tensors in groups of immutable and mutable references to modules of a single type, as discussed in #435. Closes #435 and will make #425 easier to implement.
Tasks:
nn