Removes ModuleBuilder, Adds BuildModule & BuildOnDevice #405

Merged
merged 11 commits on Jan 26, 2023
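
In short, call sites move from the removed `ModuleBuilder` API (`dev.build_module()`) to `BuildOnDevice::build_on_device(&dev)` on the model type, with `BuildModule::build`/`try_build` as the lower-level constructors. A minimal before/after sketch of the call-site change (the `Mlp` alias and layer sizes are illustrative, not taken from this diff):

```rust
use dfdx::prelude::*;

// Illustrative model alias; any tuple of modules works the same way.
type Mlp = (Linear<10, 5>, ReLU, Linear<5, 2>);

fn main() {
    let dev: Cpu = Default::default();

    // Before this PR (removed ModuleBuilder API):
    // let mlp: Mlp = dev.build_module();

    // After this PR: build the module type on a device.
    let mlp = Mlp::build_on_device(&dev);

    let x: Tensor<Rank1<10>> = dev.zeros();
    let y = mlp.forward(x); // y: Tensor<Rank1<2>>
    println!("{:?}", y);
}
```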
8 changes: 5 additions & 3 deletions README.md
@@ -63,7 +63,7 @@ type Mlp = (

fn main() {
let dev: Cpu = Default::default();
let mlp: Mlp = dev.build_module();
let mlp = Mlp::build_on_device(&dev);
let x: Tensor<Rank1<10>> = dev.zeros();
let y /*: Tensor<Rank1<2>>*/ = mlp.forward(x);
println!("{:?}", y);
@@ -122,11 +122,13 @@ for sequentially executing modules.

```rust
// no idea why you would do this, but you could!
let model: (ReLU, Sigmoid, Tanh) = dev.build_module();
type Model = (ReLU, Sigmoid, Tanh);
let model = Model::build_on_device(&dev);
```

```rust
let model: (Linear<10, 5>, Tanh) = dev.build_module();
type Model = (Linear<10, 5>, Tanh);
let model = Model::build_on_device(&dev);
```

How implementing Module for a 2-tuple looks:
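
The implementation itself is collapsed in this view; as a rough sketch (trait bounds simplified from the actual dfdx definitions), the 2-tuple impl just chains the two forward calls:

```rust
// Sketch only: the first module's output type becomes the second module's input.
impl<Input, A, B> Module<Input> for (A, B)
where
    A: Module<Input>,
    B: Module<A::Output>,
{
    type Output = B::Output;

    fn forward(&self, x: Input) -> Self::Output {
        self.1.forward(self.0.forward(x))
    }
}
```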
7 changes: 4 additions & 3 deletions examples/03-nn.rs
@@ -1,7 +1,7 @@
//! Intro to dfdx::nn

use dfdx::{
nn::{Linear, Module, ModuleBuilder, ModuleMut, ReLU, ResetParams},
nn::{BuildOnDevice, Linear, Module, ModuleMut, ReLU, ResetParams},
shapes::{Const, Rank1, Rank2},
tensor::{AsArray, Cpu, SampleTensor, Tensor, ZerosTensor},
};
@@ -11,7 +11,7 @@ fn main() {

// nn exposes many different neural network types, like the Linear layer!
// you can use BuildOnDevice::build_on_device to construct an initialized model
let mut m: Linear<4, 2> = dev.build_module();
let mut m = Linear::<4, 2>::build_on_device(&dev);

// ResetParams::reset_params also allows you to re-randomize the weights
m.reset_params();
@@ -34,7 +34,8 @@ fn main() {
let _: Tensor<(usize, Const<2>)> = m.forward(dev.zeros_like(&(batch_size, Const)));

// you can also combine multiple modules with tuples
let mlp: (Linear<4, 2>, ReLU, Linear<2, 1>) = dev.build_module();
type Mlp = (Linear<4, 2>, ReLU, Linear<2, 1>);
let mlp = Mlp::build_on_device(&dev);

// and of course forward passes the input through each module sequentially:
let x = dev.sample_normal::<Rank1<4>>();
4 changes: 2 additions & 2 deletions examples/05-optim.rs
@@ -2,7 +2,7 @@

use dfdx::{
losses::mse_loss,
nn::{Linear, ModuleBuilder, ModuleMut, ReLU, Tanh},
nn::{BuildOnDevice, Linear, ModuleMut, ReLU, Tanh},
optim::{Momentum, Optimizer, Sgd, SgdConfig},
shapes::Rank2,
tensor::{AsArray, Cpu, SampleTensor},
@@ -29,7 +29,7 @@ fn main() {
});

// let's initialize our model and some dummy data
let mut mlp: Mlp = dev.build_module();
let mut mlp = Mlp::build_on_device(&dev);
let x = dev.sample_normal::<Rank2<3, 5>>();
let y = dev.sample_normal::<Rank2<3, 2>>();

12 changes: 6 additions & 6 deletions examples/06-mnist.rs
@@ -74,10 +74,10 @@ impl MnistDataset {

// our network structure
type Mlp = (
(Linear<784, 512, Dev>, ReLU),
(Linear<512, 128, Dev>, ReLU),
(Linear<128, 32, Dev>, ReLU),
Linear<32, 10, Dev>,
(Linear<784, 512>, ReLU),
(Linear<512, 128>, ReLU),
(Linear<128, 32>, ReLU),
Linear<32, 10>,
);

// training batch size
@@ -98,8 +98,8 @@ fn main() {
let mut rng = StdRng::seed_from_u64(0);

// initialize model and optimizer
let mut model: Mlp = dev.build_module();
let mut opt: Adam<Mlp> = Default::default();
let mut model = Mlp::build_on_device(&dev);
let mut opt = Adam::default();

// initialize dataset
let dataset = MnistDataset::train(&mnist_path);
15 changes: 5 additions & 10 deletions examples/07-custom-module.rs
@@ -2,7 +2,7 @@

use dfdx::{
gradients::Tape,
nn::{self, Module, ModuleBuilder},
nn::{self, Module},
optim::{GradientUpdate, ParamUpdater, UnusedTensors},
shapes::{Rank1, Rank2},
tensor::{Cpu, HasErr, SampleTensor, Tensor},
@@ -18,21 +18,16 @@ struct Mlp<const IN: usize, const INNER: usize, const OUT: usize> {
}

// BuildModule lets you construct a model with randomly initialized parameters
impl<const IN: usize, const INNER: usize, const OUT: usize> nn::ResetParams<Cpu, f32>
impl<const IN: usize, const INNER: usize, const OUT: usize> nn::BuildModule<Cpu, f32>
for Mlp<IN, INNER, OUT>
{
fn try_build(device: &Cpu) -> Result<Self, <Cpu as HasErr>::Err> {
Ok(Self {
l1: nn::ResetParams::try_build(device)?,
l2: nn::ResetParams::try_build(device)?,
l1: nn::BuildModule::try_build(device)?,
l2: nn::BuildModule::try_build(device)?,
relu: nn::ReLU,
})
}
fn try_reset_params(&mut self) -> Result<(), <Cpu as HasErr>::Err> {
self.l1.try_reset_params()?;
self.l2.try_reset_params()?;
Ok(())
}
}

// GradientUpdate lets you update a model's parameters using gradients
@@ -84,7 +84,7 @@ fn main() {
let dev: Cpu = Default::default();

// Construct model
let model: Mlp<10, 512, 20> = dev.build_module();
let model: Mlp<10, 512, 20> = nn::BuildModule::build(&dev);

// Forward pass with a single sample
let item: Tensor<Rank1<10>> = dev.sample_normal();
5 changes: 3 additions & 2 deletions examples/11-multi-headed.rs
@@ -2,7 +2,7 @@
//! outputs using `SplitInto`.

use dfdx::{
nn::{Linear, Module, ModuleBuilder, SplitInto},
nn::{BuildOnDevice, Linear, Module, SplitInto},
shapes::Rank1,
tensor::{Cpu, Tensor, TensorFromArray},
};
@@ -13,7 +13,8 @@ fn main() {
// SplitInto accepts a tuple of modules. Each one of the items in the
// tuple must accept the same type of input.
// Note that here, both of the linears have the same size input (1)
let m: SplitInto<(Linear<1, 3>, Linear<1, 5>)> = dev.build_module();
type Model = SplitInto<(Linear<1, 3>, Linear<1, 5>)>;
let m = Model::build_on_device(&dev);

// when we forward data through, we get a tuple back!
let _: (Tensor<Rank1<3>>, Tensor<Rank1<5>>) = m.forward(dev.tensor([1.0]));
2 changes: 1 addition & 1 deletion examples/nightly-conv-net.rs
@@ -15,7 +15,7 @@ fn main() {
);

let dev: Cpu = Default::default();
let m: Model = dev.build_module();
let m = Model::build_on_device(&dev);

// single image forward
let x: Tensor<Rank3<3, 28, 28>> = dev.sample_normal();
2 changes: 1 addition & 1 deletion examples/nightly-resnet18.rs
@@ -42,7 +42,7 @@ fn main() {

let dev: Cpu = Default::default();
let x = dev.sample_normal::<Rank3<3, 224, 224>>();
let m: Resnet18<1000> = dev.build_module();
let m = Resnet18::<1000>::build_on_device(&dev);
for _ in 0.. {
let start = Instant::now();
let _ = m.forward(x.clone());
3 changes: 2 additions & 1 deletion examples/nightly-transformer.rs
@@ -6,7 +6,8 @@ fn main() {
use dfdx::prelude::*;

let dev: Cpu = Default::default();
let t: Transformer<16, 4, 3, 3, 8> = dev.build_module();
type Model = Transformer<16, 4, 3, 3, 8>;
let t = Model::build_on_device(&dev);

let src: Tensor<Rank3<4, 12, 16>> = dev.sample_normal();
let tgt: Tensor<Rank3<4, 6, 16>> = dev.sample_normal();
4 changes: 2 additions & 2 deletions examples/rl-dqn.rs
@@ -32,8 +32,8 @@ fn main() {
let next_state = dev.sample_normal::<Rank2<BATCH, STATE>>();

// initialize model
let mut q_net: QNetwork = dev.build_module();
let target_q_net: QNetwork = q_net.clone();
let mut q_net = QNetwork::build_on_device(&dev);
let target_q_net = q_net.clone();

let mut sgd = Sgd::new(SgdConfig {
lr: 1e-1,
4 changes: 2 additions & 2 deletions examples/rl-ppo.rs
@@ -28,8 +28,8 @@ fn main() {
let advantage = dev.sample_normal::<Rank1<BATCH>>();

// initialize model - all weights are 0s
let mut pi_net: PolicyNetwork = dev.build_module();
let target_pi_net: PolicyNetwork = pi_net.clone();
let mut pi_net = PolicyNetwork::build_on_device(&dev);
let target_pi_net = pi_net.clone();

let mut sgd = Sgd::new(SgdConfig {
lr: 1e-1,
13 changes: 7 additions & 6 deletions src/lib.rs
@@ -30,18 +30,19 @@
//! );
//! ```
//!
//! 3. Instantiate models with [crate::nn::ResetParams] and [crate::nn::ModuleBuilder]
//! 3. Instantiate models with [crate::nn::BuildOnDevice]
//! ```rust
//! # use dfdx::prelude::*;
//! let dev: Cpu = Default::default();
//! let mlp: (Linear<5, 2>, ReLU) = dev.build_module();
//! type Model = (Linear<5, 2>, ReLU);
//! let mlp = Model::build_on_device(&dev);
//! ```
//!
//! 4. Pass data through networks with [crate::nn::Module]
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! let mlp: Linear<5, 2> = dev.build_module();
//! # let mlp: Linear<5, 2> = BuildModule::build(&dev);
//! let x: Tensor<Rank1<5>> = dev.zeros();
//! let y = mlp.forward(x); // compiler infers that `y` must be `Tensor<Rank1<2>>`
//! ```
@@ -50,7 +51,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! # let model: Linear<10, 5> = dev.build_module();
//! # let model: Linear<10, 5> = BuildModule::build(&dev);
//! # let y_true: Tensor<Rank1<5>> = dev.sample_normal().softmax();
//! // tensors default to not having a tape
//! let x: Tensor<Rank1<10>, f32, Cpu, NoneTape> = dev.zeros();
@@ -67,7 +68,7 @@
//! ```rust
//! # use dfdx::{prelude::*, gradients::Gradients};
//! # let dev: Cpu = Default::default();
//! # let model: Linear<10, 5> = dev.build_module();
//! # let model: Linear<10, 5> = BuildModule::build(&dev);
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! // compute cross entropy loss
@@ -80,7 +81,7 @@
//! ```rust
//! # use dfdx::{prelude::*, gradients::Gradients, optim::*};
//! # let dev: Cpu = Default::default();
//! # let mut model: Linear<10, 5> = dev.build_module();
//! # let mut model: Linear<10, 5> = BuildModule::build(&dev);
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! # let loss = cross_entropy_with_logits_loss(y, y_true);
14 changes: 13 additions & 1 deletion src/nn/activations.rs
@@ -1,6 +1,6 @@
use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*};

use super::module::{Module, NonMutableModule, ZeroSizedModule};
use super::module::{BuildModule, Module, NonMutableModule, ZeroSizedModule};

macro_rules! activation_impls {
($struct_name:ident, $func_name:ident, #[$docstring:meta]) => {
@@ -11,6 +11,12 @@ macro_rules! activation_impls {
impl ZeroSizedModule for $struct_name {}
impl NonMutableModule for $struct_name {}

impl<D: Device<E>, E: Dtype> BuildModule<D, E> for $struct_name {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<D>> Module<Tensor<S, E, D, T>>
for $struct_name
{
@@ -41,6 +47,12 @@ pub struct Softmax;
impl ZeroSizedModule for Softmax {}
impl NonMutableModule for Softmax {}

impl<D: Device<E>, E: Dtype> BuildModule<D, E> for Softmax {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

impl<Ax: Axes, S: Shape<LastAxis = Ax> + ReduceShape<Ax>, E: Dtype, D: Device<E>, T: Tape<D>>
Module<Tensor<S, E, D, T>> for Softmax
{