
Removes ModuleBuilder, Adds BuildModule & BuildOnDevice #405


Merged: 11 commits merged on Jan 26, 2023


coreylowman (Owner) commented Jan 26, 2023

Resolves #388

See the discussion in that issue and in #394.

This removes ModuleBuilder; notably, you can no longer call dev.build_module(). Instead, users should use either BuildModule, if they are only using the Cpu device or want to specify the device in the type, or the more general BuildOnDevice, which should be called as Model::build_on_device(&dev).
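To make the BuildModule route concrete, here is a minimal, self-contained sketch. It does not use dfdx: Cpu, Cuda, Device, Linear, and BuildModule below are simplified, hypothetical stand-ins (no tensors, no dtype parameter), modeled only on the names that appear in the error output further down. The assumption is a layer whose device type parameter defaults to Cpu, so Cpu-only code never names the device, while any other device must be spelled out in the type.

```rust
// Simplified, hypothetical stand-ins for illustration; this is NOT dfdx's real API.
trait Device: Clone {
    fn name(&self) -> &'static str;
}

#[derive(Clone, Debug, Default)]
struct Cpu;
#[derive(Clone, Debug, Default)]
struct Cuda;

impl Device for Cpu {
    fn name(&self) -> &'static str {
        "Cpu"
    }
}
impl Device for Cuda {
    fn name(&self) -> &'static str {
        "Cuda"
    }
}

// Toy linear layer; the device is a type parameter that defaults to Cpu.
struct Linear<const I: usize, const O: usize, D = Cpu> {
    weight: Vec<f32>, // I * O zero-initialized weights
    device: D,
}

impl<const I: usize, const O: usize, D: Device> Linear<I, O, D> {
    fn num_params(&self) -> usize {
        self.weight.len()
    }
    fn device_name(&self) -> &'static str {
        self.device.name()
    }
}

// Build a module of type `Self` on a device of type `D`.
trait BuildModule<D: Device>: Sized {
    fn build(device: &D) -> Self;
}

// `Linear<I, O, D>` can only be built on the device named in its own type.
impl<const I: usize, const O: usize, D: Device> BuildModule<D> for Linear<I, O, D> {
    fn build(device: &D) -> Self {
        Linear { weight: vec![0.0; I * O], device: device.clone() }
    }
}

fn main() {
    // Cpu only: the default device parameter means Cpu never has to be named.
    let cpu_model: Linear<5, 2> = BuildModule::build(&Cpu);
    // Any other device has to be specified in the type itself.
    let cuda_model: Linear<5, 2, Cuda> = BuildModule::build(&Cuda);
    println!("{} params on {}", cpu_model.num_params(), cpu_model.device_name());
    println!("{} params on {}", cuda_model.num_params(), cuda_model.device_name());
}
```

In this sketch, binding `let m: Linear<5, 2> = BuildModule::build(&Cuda);` does not compile, because only BuildModule<Cpu> is implemented for Linear<5, 2, Cpu>; that is the same shape as the E0277 error shown further down.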

One downside of this change is that it makes creating modules a bit more complicated, since there are now two distinct ways to do it. The suggested pathway is:

  1. Specifying your model as a type alias (e.g. type Model = ...)
  2. Using BuildOnDevice to build it (e.g. let mut model = Model::build_on_device(&dev))

This makes it trivial to switch devices in the future.
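The two-step pathway can be sketched as follows, again with simplified, hypothetical stand-ins rather than dfdx itself. The shape of BuildOnDevice is taken from the signature in the error output below (fn build_on_device(device: &D) -> Self::Output); the Output: BuildModule<D> bound and the default method body are our own assumptions. The model alias is written once, and only the device value changes.

```rust
// Simplified, hypothetical stand-ins for illustration; this is NOT dfdx's real API.
#[derive(Clone, Debug, Default)]
struct Cpu;
#[derive(Clone, Debug, Default)]
struct Cuda;

// Toy layer whose device type parameter defaults to Cpu.
struct Linear<const I: usize, const O: usize, D = Cpu> {
    weight: Vec<f32>,
    device: D,
}

trait BuildModule<D>: Sized {
    fn build(device: &D) -> Self;
}

impl<const I: usize, const O: usize, D: Clone> BuildModule<D> for Linear<I, O, D> {
    fn build(device: &D) -> Self {
        Linear { weight: vec![0.0; I * O], device: device.clone() }
    }
}

// Maps a device-agnostic description (`Self`) to a module built on `D`.
trait BuildOnDevice<D> {
    type Output: BuildModule<D>;
    fn build_on_device(device: &D) -> Self::Output {
        <Self::Output as BuildModule<D>>::build(device)
    }
}

// `Linear<I, O>` (the Cpu default) can be turned into a layer on any device `D`.
impl<const I: usize, const O: usize, D: Clone> BuildOnDevice<D> for Linear<I, O> {
    type Output = Linear<I, O, D>;
}

// Step 1: describe the model once as a type alias.
type Model = Linear<5, 2>;

fn main() {
    // Step 2: build it with build_on_device; switching devices only changes the argument.
    let on_cpu = Model::build_on_device(&Cpu);
    let on_cuda = Model::build_on_device(&Cuda);
    println!("{} weights on {:?}", on_cpu.weight.len(), on_cpu.device);
    println!("{} weights on {:?}", on_cuda.weight.len(), on_cuda.device);
}
```

The design point the sketch illustrates is that the alias describes the architecture, while the associated Output type carries the device, so the same alias can yield a Cpu or a Cuda layer.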

Error message for BuildOnDevice with the wrong device

error[E0308]: mismatched types
  --> examples/tmp.rs:6:43
   |
6  |     let m: Model = Model::build_on_device(&dev);
   |                    ---------------------- ^^^^ expected struct `dfdx::tensor::Cpu`, found struct `dfdx::tensor::Cuda`
   |                    |
   |                    arguments to this function are incorrect
   |
   = note: expected reference `&dfdx::tensor::Cpu`
              found reference `&dfdx::tensor::Cuda`
note: associated function defined here
  --> /home/ubuntu/dfdx/src/nn/module.rs:49:8
   |
49 |     fn build_on_device(device: &D) -> Self::Output {
   |        ^^^^^^^^^^^^^^^
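One plausible reading of this error, reproduced with hypothetical stand-ins (not dfdx's real definitions): if Model names the Cpu-defaulted layer and BuildOnDevice maps it to Linear<I, O, D>, then annotating the binding as m: Model pins the output to the Cpu layer, so D must be Cpu and a &Cuda argument is rejected. Leaving the binding unannotated, or naming the device-specific type, lets D be inferred from the argument.

```rust
// Hypothetical stand-ins for illustration only; NOT dfdx's real definitions.
#[derive(Clone, Debug)]
struct Cpu;
#[derive(Clone, Debug)]
struct Cuda;

// Toy layer whose device type parameter defaults to Cpu.
struct Linear<const I: usize, const O: usize, D = Cpu> {
    device: D,
}

trait BuildOnDevice<D> {
    type Output;
    fn build_on_device(device: &D) -> Self::Output;
}

impl<const I: usize, const O: usize, D: Clone> BuildOnDevice<D> for Linear<I, O> {
    type Output = Linear<I, O, D>;
    fn build_on_device(device: &D) -> Self::Output {
        Linear { device: device.clone() }
    }
}

type Model = Linear<5, 2>; // i.e. Linear<5, 2, Cpu>

fn main() {
    let dev = Cuda;
    // Does not compile: `m: Model` forces Output = Linear<5, 2, Cpu>, so D = Cpu
    // and `&dev` (a &Cuda) is a mismatched argument.
    // let m: Model = Model::build_on_device(&dev);

    // Compiles: D is inferred from the argument.
    let m = Model::build_on_device(&dev);
    // Also compiles: the annotation names the Cuda layer explicitly.
    let n: Linear<5, 2, Cuda> = Model::build_on_device(&dev);
    println!("{:?} {:?}", m.device, n.device);
}
```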

Error message for BuildModule with the wrong device

error[E0277]: the trait bound `dfdx::nn::Linear<5, 2>: dfdx::nn::BuildModule<dfdx::tensor::Cuda, f32>` is not satisfied
 --> examples/tmp.rs:5:39
  |
5 |     let m: Model = BuildModule::build(&dev);
  |                    ------------------ ^^^^ the trait `dfdx::nn::BuildModule<dfdx::tensor::Cuda, f32>` is not implemented for `dfdx::nn::Linear<5, 2>`
  |                    |
  |                    required by a bound introduced by this call
  |
  = help: the trait `dfdx::nn::BuildModule<D, f32>` is implemented for `dfdx::nn::Linear<I, O, D>`

For more information about this error, try `rustc --explain E0277`.

Issue closed by this pull request: nn api hard to use with different devices