Contributor

@nkoppel nkoppel commented Feb 16, 2023

Adds framework to run functions on all sets of corresponding tensors in groups of immutable and mutable references to modules of a single type, as discussed in #435. Closes #435, and will make #425 easier to implement.

Tasks:

  • Implement VisitTensorGroups for all Modules in nn
  • Implement GradientUpdate (disable for some fields in batchnorm2d)
  • Implement ResetParams (linear and conv2d need special treatment)
  • Implement SaveToNpz
  • Implement LoadFromNpz
  • Document internals in visit_tensors.rs
  • Implement VisitTensorGroups for Mlp in 07-custom-module

@nkoppel nkoppel changed the title Add VisitTensors traits and simplify nn internals. [WIP] Add VisitTensors traits and simplify nn internals. Feb 16, 2023
@nkoppel nkoppel marked this pull request as draft February 16, 2023 17:58
Owner

@coreylowman coreylowman left a comment

Do you have a sense of whether this will actually help the internals? I see that we've already removed quite a bit of code just with this draft, and we can probably remove more.

If we could figure out how to use this to implement EMA (an exponential moving average of the parameters), I think that'd be a key advantage, but I haven't been able to figure out how to do "zipped" visiting.

Contributor Author

nkoppel commented Feb 17, 2023

I think that this PR will grant us a lot of flexibility in implementing new modules and new functionality for modules. It will also make it much easier for users who implement custom modules to get access to a lot of features that would otherwise require a lot of boilerplate.

Also, I don't think that implementing VisitTensorGroups is very complicated, because it pretty much amounts to specifying how to access each field and defining each field's name. TensorVisitor is also pretty simple to implement, as it amounts to getting the tensors you need from tensors and passing them to a function. For example, implementing CountParams for every nn module takes only 28 lines of fairly simple code.
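
As an illustration of the point above, here is a minimal, self-contained sketch of a parameter-counting visitor. It is not the code from this PR: `Tensor`, `Linear`, and the simplified `TensorVisitor` and `VisitTensors` traits below are hypothetical stand-ins with no dtype, device, or const-generic group parameters. They only show that a module impl lists its fields with names, and that a visitor is a small function over each tensor.

```rust
use std::convert::Infallible;

/// Hypothetical stand-in for a tensor: just a flat buffer of values.
pub struct Tensor {
    pub data: Vec<f32>,
}

/// Simplified visitor (assumption): called once per tensor with its dotted name.
pub trait TensorVisitor {
    type Err;
    fn visit(&mut self, name: &str, tensor: &Tensor) -> Result<(), Self::Err>;
}

/// Simplified module trait (assumption): how to reach each field, and its name.
pub trait VisitTensors {
    fn visit_tensors<V: TensorVisitor>(&self, prefix: &str, visitor: &mut V) -> Result<(), V::Err>;
}

pub struct Linear {
    pub weight: Tensor,
    pub bias: Tensor,
}

impl Linear {
    /// Hypothetical constructor used only for this example.
    pub fn zeros(n_in: usize, n_out: usize) -> Self {
        Linear {
            weight: Tensor { data: vec![0.0; n_in * n_out] },
            bias: Tensor { data: vec![0.0; n_out] },
        }
    }
}

impl VisitTensors for Linear {
    fn visit_tensors<V: TensorVisitor>(&self, prefix: &str, visitor: &mut V) -> Result<(), V::Err> {
        visitor.visit(&format!("{}weight", prefix), &self.weight)?;
        visitor.visit(&format!("{}bias", prefix), &self.bias)
    }
}

/// Tuples forward to their elements and extend the name prefix with the index.
impl<A: VisitTensors, B: VisitTensors> VisitTensors for (A, B) {
    fn visit_tensors<V: TensorVisitor>(&self, prefix: &str, visitor: &mut V) -> Result<(), V::Err> {
        self.0.visit_tensors(&format!("{}0.", prefix), &mut *visitor)?;
        self.1.visit_tensors(&format!("{}1.", prefix), visitor)
    }
}

/// The visitor itself: adds up element counts and records the names it saw.
#[derive(Default)]
pub struct CountParams {
    pub count: usize,
    pub names: Vec<String>,
}

impl TensorVisitor for CountParams {
    type Err = Infallible;
    fn visit(&mut self, name: &str, tensor: &Tensor) -> Result<(), Self::Err> {
        self.count += tensor.data.len();
        self.names.push(name.to_string());
        Ok(())
    }
}

pub fn count_params<M: VisitTensors>(module: &M) -> usize {
    let mut counter = CountParams::default();
    module.visit_tensors("", &mut counter).unwrap(); // the error type is Infallible
    counter.count
}
```

With these definitions, `count_params(&(Linear::zeros(3, 4), Linear::zeros(4, 2)))` walks four tensors named `0.weight`, `0.bias`, `1.weight`, and `1.bias`.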

Comment on lines 37 to 42
fn visit_groups<F: TensorVisitor<N, M, E, D>>(
    mut self_refs: ModuleGroup<N, M, Self>,
    func: &mut F,
) -> Result<(), F::Err> {
    self_refs.map(|s| &s.0, |s| &mut s.0, "0.").visit(func)
}
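
To make the shape of the snippet above easier to follow, here is a minimal, self-contained sketch of the idea, assuming simplified stand-ins rather than the actual types in this PR. Only the `map(get_ref, get_mut, name)` call shape is taken from the reviewed code. The fields of `ModuleGroup`, the `Tensor` and `Linear` types, the closure-based `visit`, and the `visit_linear` helper are all hypothetical, and error handling plus the dtype and device generics are omitted.

```rust
/// Hypothetical stand-in for a tensor.
pub struct Tensor {
    pub data: Vec<f32>,
}

pub struct Linear {
    pub weight: Tensor,
    pub bias: Tensor,
}

/// A group of N shared and M mutable references to modules of one type,
/// plus the dotted name of the position reached so far (assumed layout).
pub struct ModuleGroup<'a, const N: usize, const M: usize, T> {
    pub refs: [&'a T; N],
    pub muts: [&'a mut T; M],
    pub name: String,
}

impl<'a, const N: usize, const M: usize, T> ModuleGroup<'a, N, M, T> {
    /// Project every reference in the group onto the same field.
    pub fn map<'b, U: 'b>(
        &'b mut self,
        get_ref: impl Fn(&T) -> &U,
        get_mut: impl Fn(&mut T) -> &mut U,
        name: &str,
    ) -> ModuleGroup<'b, N, M, U> {
        let refs = self.refs.map(|r| get_ref(r));
        let mut iter = self.muts.iter_mut();
        ModuleGroup {
            refs,
            // Reborrow each mutable reference in order and project it to the field.
            muts: std::array::from_fn(|_| {
                get_mut(&mut **iter.next().expect("one entry per mutable reference"))
            }),
            name: format!("{}{}", self.name, name),
        }
    }
}

impl<'a, const N: usize, const M: usize> ModuleGroup<'a, N, M, Tensor> {
    /// Leaf step: hand the corresponding tensors to the visiting function.
    pub fn visit<F>(&mut self, func: &mut F)
    where
        F: FnMut(&str, &[&Tensor; N], &mut [&mut Tensor; M]),
    {
        func(self.name.as_str(), &self.refs, &mut self.muts);
    }
}

/// What a module impl boils down to: one `map` call per field.
pub fn visit_linear<const N: usize, const M: usize, F>(
    mut group: ModuleGroup<'_, N, M, Linear>,
    func: &mut F,
) where
    F: FnMut(&str, &[&Tensor; N], &mut [&mut Tensor; M]),
{
    group.map(|l| &l.weight, |l| &mut l.weight, "weight").visit(&mut *func);
    group.map(|l| &l.bias, |l| &mut l.bias, "bias").visit(&mut *func);
}
```

With `N = 1` and `M = 1`, the visiting function receives one source tensor and one destination tensor per field, which is enough to copy one module into another.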
Owner

Thoughts on including the visitor in the ModuleGroup object and then calling it inside .map?

If we can move the E/D generics to the call method, this shouldn't be an issue.

Something like:

Suggested change
- fn visit_groups<F: TensorVisitor<N, M, E, D>>(
-     mut self_refs: ModuleGroup<N, M, Self>,
-     func: &mut F,
- ) -> Result<(), F::Err> {
-     self_refs.map(|s| &s.0, |s| &mut s.0, "0.").visit(func)
- }
+ fn visit_groups<F: TensorVisitor<N, M>>(
+     &mut self_refs: ModuleGroup<N, M, Self, F>,
+ ) -> Result<(), F::Err> {
+     self_refs.map(|s| &s.0, |s| &mut s.0, "0.")
+ }

We could even go a step further and make ModuleGroup itself a TensorVisitor (wrapped around another TensorVisitor):

Suggested change
- fn visit_groups<F: TensorVisitor<N, M, E, D>>(
-     mut self_refs: ModuleGroup<N, M, Self>,
-     func: &mut F,
- ) -> Result<(), F::Err> {
-     self_refs.map(|s| &s.0, |s| &mut s.0, "0.").visit(func)
- }
+ fn visit_groups<V: TensorVisitor<N, M>>(
+     &mut visitor: V,
+ ) -> Result<(), V::Err> {
+     visitor.map(|s| &s.0, |s| &mut s.0, "0.")
+ }

Contributor Author

I've implemented your first suggestion and renamed a few things so that the new role of ModuleGroups makes more sense. I haven't implemented your second suggestion because I don't think I fully understand what you're getting at. What would the call implementation of ModuleGroup look like? How would we treat TensorVisitors that aren't wrapped in a ModuleGroup?

}
}

pub trait VisitTensors<E: Dtype, D: DeviceStorage>: VisitTensorGroups<1, 0, E, D> + Debug {
Owner

Are 1, 0 and 0, 1 for VisitTensorsRef and VisitTensorsMut the only two cases we would ever support?

I think we might be able to get rid of the VisitTensorGroups trait if we can separately impl some trait (not sure which one) for &T and &mut T

Contributor Author

@nkoppel nkoppel Feb 17, 2023

We will eventually support 1, 1 so we can add and multiply modules for EMA.
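
To illustrate what the 1, 1 case buys, here is a rough sketch of an EMA update written as a "zipped" visit over one shared and one mutable module of the same type. This is an assumption-laden illustration, not the design in this PR or in the follow-up PR: `Tensor`, `Linear`, the `ZipVisit` trait, and `ema_update` are hypothetical names, and the update rule shown is the usual `avg = decay * avg + (1 - decay) * current`.

```rust
/// Hypothetical stand-in for a tensor.
pub struct Tensor {
    pub data: Vec<f32>,
}

pub struct Linear {
    pub weight: Tensor,
    pub bias: Tensor,
}

/// Visit corresponding tensors of two modules of the same type:
/// `dst` mutably and `src` immutably (the "1, 1" case, simplified).
pub trait ZipVisit {
    fn zip_visit<F>(dst: &mut Self, src: &Self, prefix: &str, func: &mut F)
    where
        F: FnMut(&str, &mut Tensor, &Tensor);
}

impl ZipVisit for Linear {
    fn zip_visit<F>(dst: &mut Self, src: &Self, prefix: &str, func: &mut F)
    where
        F: FnMut(&str, &mut Tensor, &Tensor),
    {
        func(format!("{}weight", prefix).as_str(), &mut dst.weight, &src.weight);
        func(format!("{}bias", prefix).as_str(), &mut dst.bias, &src.bias);
    }
}

/// Tuples pair up their elements position by position.
impl<A: ZipVisit, B: ZipVisit> ZipVisit for (A, B) {
    fn zip_visit<F>(dst: &mut Self, src: &Self, prefix: &str, func: &mut F)
    where
        F: FnMut(&str, &mut Tensor, &Tensor),
    {
        A::zip_visit(&mut dst.0, &src.0, format!("{}0.", prefix).as_str(), &mut *func);
        B::zip_visit(&mut dst.1, &src.1, format!("{}1.", prefix).as_str(), &mut *func);
    }
}

/// In-place EMA: avg = decay * avg + (1 - decay) * current, for every tensor pair.
pub fn ema_update<M: ZipVisit>(ema: &mut M, model: &M, decay: f32) {
    let mut step = |_name: &str, avg: &mut Tensor, cur: &Tensor| {
        for (a, c) in avg.data.iter_mut().zip(cur.data.iter()) {
            *a = decay * *a + (1.0 - decay) * *c;
        }
    };
    M::zip_visit(ema, model, "", &mut step);
}
```

Because the module impl only describes how to pair up fields, the same `ZipVisit` impls could in principle back other two-module operations, such as adding or scaling parameters.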

@coreylowman
Owner

@nkoppel I figured out how to do it without the arrays, and will open another PR soon.

@coreylowman
Owner

Closing since the other PR was merged, thanks for the great work on this @nkoppel, the get_refs/get_muts is super clean/clever! 🚀

Development

Successfully merging this pull request may close these issues.

Parameter Count
