Adding nn builder structs, dtype generics, and removing device defaults. #433

Merged 13 commits on Feb 7, 2023
9 changes: 5 additions & 4 deletions examples/03-nn.rs
@@ -1,7 +1,7 @@
//! Intro to dfdx::nn

use dfdx::{
nn::{BuildOnDevice, Linear, Module, ModuleMut, ReLU, ResetParams},
nn::{builders::*, modules, BuildOnDevice, Module, ModuleMut, ResetParams},
shapes::{Const, Rank1, Rank2},
tensor::{AsArray, Cpu, SampleTensor, Tensor, ZerosTensor},
};
@@ -10,10 +10,11 @@ fn main() {
let dev: Cpu = Default::default();

// nn exposes many different neural network types, like the Linear layer!
// you can use Build::build to construct an initialized model
let mut m = Linear::<4, 2>::build_on_device(&dev);
// you can use BuildModule::build to construct an initialized model
type Model = Linear<4, 2>;
let mut m: modules::Linear<4, 2, f32, Cpu> = Model::build_on_device(&dev);

// Build::reset_params also allows you to re-randomize the weights
// ResetParams::reset_params also allows you to re-randomize the weights
m.reset_params();

// Modules act on tensors using either:
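For context, the updated example declares the architecture with a builder type, builds it on a device to get a concrete module, and then runs a forward pass. A minimal self-contained sketch of that flow, assuming the imports shown in this diff (the forward pass at the end is illustrative and not part of the visible hunk):

```rust
use dfdx::{
    nn::{builders::*, modules, BuildOnDevice, Module, ResetParams},
    shapes::Rank1,
    tensor::{Cpu, Tensor, ZerosTensor},
};

fn main() {
    let dev: Cpu = Default::default();

    // `Linear<4, 2>` is the dtype/device-agnostic builder type;
    // building it on a device yields the concrete module type.
    type Model = Linear<4, 2>;
    let mut m: modules::Linear<4, 2, f32, Cpu> = Model::build_on_device(&dev);

    // re-randomize the weights in place
    m.reset_params();

    // forward pass: the compiler infers the Rank1<2> output
    let x: Tensor<Rank1<4>, f32, _> = dev.zeros();
    let _y = m.forward(x);
}
```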
2 changes: 1 addition & 1 deletion examples/05-optim.rs
@@ -2,7 +2,7 @@

use dfdx::{
losses::mse_loss,
nn::{BuildOnDevice, Linear, ModuleMut, ReLU, Tanh},
nn::{builders::*, BuildOnDevice, ModuleMut},
optim::{Momentum, Optimizer, Sgd, SgdConfig},
shapes::Rank2,
tensor::{AsArray, Cpu, SampleTensor, Tensor},
39 changes: 12 additions & 27 deletions examples/07-custom-module.rs
@@ -2,8 +2,11 @@

use dfdx::{
gradients::Tape,
nn::{self, Module},
optim::{GradientUpdate, ParamUpdater, UnusedTensors},
nn::{
self,
modules::{Linear, ReLU},
BuildModule, Module,
},
shapes::{Rank1, Rank2},
tensor::{Cpu, HasErr, SampleTensor, Tensor},
};
@@ -12,9 +15,9 @@ use dfdx::{
/// This case is trivial and should be done with a tuple of linears and relus,
/// but it demonstrates how to build models with custom behavior
struct Mlp<const IN: usize, const INNER: usize, const OUT: usize> {
l1: nn::Linear<IN, INNER>,
l2: nn::Linear<INNER, OUT>,
relu: nn::ReLU,
l1: Linear<IN, INNER, f32, Cpu>,
l2: Linear<INNER, OUT, f32, Cpu>,
relu: ReLU,
}

// BuildModule lets you randomize a model's parameters
@@ -23,31 +26,13 @@ impl<const IN: usize, const INNER: usize, const OUT: usize> nn::BuildModule<Cpu,
{
fn try_build(device: &Cpu) -> Result<Self, <Cpu as HasErr>::Err> {
Ok(Self {
l1: nn::BuildModule::try_build(device)?,
l2: nn::BuildModule::try_build(device)?,
relu: nn::ReLU,
l1: BuildModule::try_build(device)?,
l2: BuildModule::try_build(device)?,
relu: ReLU,
})
}
}

// GradientUpdate lets you update a model's parameters using gradients
impl<const IN: usize, const INNER: usize, const OUT: usize> GradientUpdate<Cpu, f32>
for Mlp<IN, INNER, OUT>
{
fn update<U>(
&mut self,
updater: &mut U,
unused: &mut UnusedTensors,
) -> Result<(), <Cpu as HasErr>::Err>
where
U: ParamUpdater<Cpu, f32>,
{
self.l1.update(updater, unused)?;
self.l2.update(updater, unused)?;
Ok(())
}
}

// impl Module for single item
impl<const IN: usize, const INNER: usize, const OUT: usize> nn::Module<Tensor<Rank1<IN>, f32, Cpu>>
for Mlp<IN, INNER, OUT>
@@ -79,7 +64,7 @@ fn main() {
let dev: Cpu = Default::default();

// Construct model
let model: Mlp<10, 512, 20> = nn::BuildModule::build(&dev);
let model = Mlp::<10, 512, 20>::build(&dev);

// Forward pass with a single sample
let item: Tensor<Rank1<10>, f32, _> = dev.sample_normal();
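For reference, the remainder of this example (partly outside the visible hunks) constructs the custom `Mlp` and runs it forward. A rough sketch, assuming the `Mlp` definition and trait impls above are in scope and that the original example's batched `Module` impl (suggested by the `Rank2` import) is unchanged; the batch size of 32 is illustrative:

```rust
// assumes the imports and the Mlp definition from the diff above
fn main() {
    let dev: Cpu = Default::default();

    // construct the model on the Cpu device via the BuildModule impl above
    let model = Mlp::<10, 512, 20>::build(&dev);

    // forward pass with a single un-batched sample;
    // the compiler infers a Rank1<20> output
    let item: Tensor<Rank1<10>, f32, _> = dev.sample_normal();
    let _out = model.forward(item);

    // forward pass with a batch, relying on the batched Module impl for
    // Rank2 inputs that is truncated out of this view
    let batch: Tensor<Rank2<32, 10>, f32, _> = dev.sample_normal();
    let _out = model.forward(batch);
}
```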
3 changes: 2 additions & 1 deletion examples/11-multi-headed.rs
@@ -2,7 +2,8 @@
//! outputs using `SplitInto`.

use dfdx::{
nn::{BuildOnDevice, Linear, Module, SplitInto},
nn::builders::{Linear, SplitInto},
nn::{BuildOnDevice, Module},
shapes::Rank1,
tensor::{Cpu, Tensor, TensorFromArray},
};
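The multi-headed example uses the same builder pattern, with `SplitInto` fanning one input out to several heads. A minimal sketch assuming the `builders` API shown above; the layer sizes and input shape here are illustrative, not taken from the hidden portion of the diff:

```rust
use dfdx::{
    nn::builders::{Linear, SplitInto},
    nn::{BuildOnDevice, Module},
    shapes::Rank1,
    tensor::{Cpu, Tensor, ZerosTensor},
};

fn main() {
    let dev: Cpu = Default::default();

    // one Rank1<5> input is split into two heads with different output sizes
    type Model = SplitInto<(Linear<5, 3>, Linear<5, 1>)>;
    let m = Model::build_on_device(&dev);

    // forwarding a single input yields a tuple of outputs, one per head
    let x: Tensor<Rank1<5>, f32, _> = dev.zeros();
    let (_y1, _y2) = m.forward(x);
}
```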
2 changes: 1 addition & 1 deletion src/gradients.rs
@@ -16,7 +16,7 @@ use crate::unique_id::{HasUniqueId, UniqueId};
/// 4. Access mutable references to arrays
///
/// This structure is similar to a HashMap, where all the methods require a key
/// implementing [UniqueId], [HasShape], and [HasDtype].
/// implementing [UniqueId], [AllocGrad].
///
/// Under the hood, it actually is a HashMap, and stores values as Box<dyn Any>. The
/// important part of key's implementing [HasShape], and [HasDtype] is that the associated type
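The surrounding doc comment describes `Gradients` as essentially a HashMap keyed by each tensor's unique id, with values type-erased as `Box<dyn Any>` and downcast to the concrete gradient type on access. A standalone sketch of that storage pattern in plain std Rust (illustrative only, not the actual `Gradients` API):

```rust
use std::any::Any;
use std::collections::HashMap;

// illustrative stand-in for dfdx's UniqueId
type UniqueId = usize;

#[derive(Default)]
struct GradientStore {
    // values are type-erased; in dfdx, the key's shape and dtype determine
    // which concrete type to downcast to on retrieval
    grads: HashMap<UniqueId, Box<dyn Any>>,
}

impl GradientStore {
    fn insert<T: 'static>(&mut self, id: UniqueId, value: T) {
        self.grads.insert(id, Box::new(value));
    }

    fn get<T: 'static>(&self, id: UniqueId) -> Option<&T> {
        self.grads.get(&id)?.downcast_ref::<T>()
    }
}

fn main() {
    let mut store = GradientStore::default();
    store.insert(0, vec![0.0f32; 4]); // gradient buffer for tensor id 0
    assert_eq!(store.get::<Vec<f32>>(0).unwrap().len(), 4);
}
```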
10 changes: 5 additions & 5 deletions src/lib.rs
@@ -42,7 +42,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! # let mlp: Linear<5, 2> = BuildModule::build(&dev);
//! # let mlp = <Linear<5, 2>>::build_on_device(&dev);
//! let x: Tensor<Rank1<5>, f32, _> = dev.zeros();
//! let y = mlp.forward(x); // compiler infers that `y` must be `Tensor<Rank1<2>>`
//! ```
@@ -51,7 +51,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! # let model: Linear<10, 5> = BuildModule::build(&dev);
//! # let model = <Linear<10, 5>>::build_on_device(&dev);
//! # let y_true: Tensor<Rank1<5>, f32, _> = dev.sample_normal().softmax();
//! // tensors default to not having a tape
//! let x: Tensor<Rank1<10>, f32, Cpu, NoneTape> = dev.zeros();
@@ -68,7 +68,7 @@
//! ```rust
//! # use dfdx::{prelude::*, gradients::Gradients};
//! # let dev: Cpu = Default::default();
//! # let model: Linear<10, 5> = BuildModule::build(&dev);
//! # let model = <Linear<10, 5>>::build_on_device(&dev);
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! // compute cross entropy loss
@@ -81,7 +81,7 @@
//! ```rust
//! # use dfdx::{prelude::*, gradients::Gradients, optim::*};
//! # let dev: Cpu = Default::default();
//! # let mut model: Linear<10, 5> = BuildModule::build(&dev);
//! # let mut model = <Linear<10, 5>>::build_on_device(&dev);
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! # let loss = cross_entropy_with_logits_loss(y, y_true);
@@ -119,7 +119,7 @@ pub mod unique_id;
pub mod prelude {
pub use crate::gradients::{NoneTape, OwnedTape};
pub use crate::losses::*;
pub use crate::nn::*;
pub use crate::nn::{builders::*, *};
pub use crate::optim::prelude::*;
pub use crate::shapes::*;
pub use crate::tensor::*;
35 changes: 21 additions & 14 deletions src/nn/add_into.rs
@@ -1,6 +1,6 @@
use crate::{optim::*, shapes::Dtype, tensor_ops::Device};

use super::{BuildModule, Module, ModuleMut, ResetParams, ToDevice};
use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice};

/// Add inputs together into a single tensor. `T` should be a tuple
//// where every element of the tuple has the same output type
@@ -32,6 +32,10 @@ impl<T: GradientUpdate<D, E>, D: Device<E>, E: Dtype> GradientUpdate<D, E> for A
}
}

impl<T: BuildOnDevice<D, E>, D: Device<E>, E: Dtype> BuildOnDevice<D, E> for AddInto<T> {
type Built = AddInto<T::Built>;
}

impl<T: BuildModule<D, E>, D: Device<E>, E: Dtype> BuildModule<D, E> for AddInto<T> {
fn try_build(device: &D) -> Result<Self, <D>::Err> {
Ok(Self(BuildModule::try_build(device)?))
@@ -98,7 +102,7 @@ mod tests {
use super::*;
use crate::{
gradients::OwnedTape,
nn::{tests::SimpleUpdater, BuildOnDevice, Linear, ReLU},
nn::{builders::*, tests::SimpleUpdater},
shapes::*,
tensor::*,
tests::TestDevice,
@@ -112,7 +116,8 @@
#[test]
fn test_add_into_2() {
let dev: TestDevice = Default::default();
let m: AddInto<(Linear<2, 5, _>, Linear<3, 5, _>)> = BuildModule::build(&dev);
type Model = AddInto<(Linear<2, 5>, Linear<3, 5>)>;
let m = Model::build_on_device(&dev);
let _: Tensor<Rank1<5>, _, _, OwnedTape<_>> = m.forward((
dev.zeros::<Rank1<2>>().traced(),
dev.zeros::<Rank1<3>>().traced(),
@@ -126,8 +131,8 @@
#[test]
fn test_add_into_3() {
let dev: TestDevice = Default::default();
let m: AddInto<(Linear<2, 5, _>, Linear<3, 5, _>, Linear<4, 5, _>)> =
BuildModule::build(&dev);
type Model = AddInto<(Linear<2, 5>, Linear<3, 5>, Linear<4, 5>)>;
let m = Model::build_on_device(&dev);
let _: Tensor<Rank1<5>, _, _, OwnedTape<_>> = m.forward((
dev.zeros::<Rank1<2>>().traced(),
dev.zeros::<Rank1<3>>().traced(),
@@ -189,14 +194,15 @@
#[test]
fn test_add_into_6() {
let dev: TestDevice = Default::default();
let m: AddInto<(
Linear<2, 5, _>,
Linear<3, 5, _>,
Linear<4, 5, _>,
Linear<5, 5, _>,
Linear<6, 5, _>,
Linear<7, 5, _>,
)> = BuildModule::build(&dev);
type Model = AddInto<(
Linear<2, 5>,
Linear<3, 5>,
Linear<4, 5>,
Linear<5, 5>,
Linear<6, 5>,
Linear<7, 5>,
)>;
let m = Model::build_on_device(&dev);
let _: Tensor<Rank1<5>, _, _, OwnedTape<_>> = m.forward((
dev.zeros::<Rank1<2>>().traced(),
dev.zeros::<Rank1<3>>().traced(),
@@ -218,7 +224,8 @@
#[test]
fn test_missing_gradients() {
let dev: TestDevice = Default::default();
let mut model: AddInto<(Linear<5, 3, _>, Linear<5, 3, _>)> = BuildModule::build(&dev);
type Model = AddInto<(Linear<5, 3>, Linear<5, 3>)>;
let mut model = Model::build_on_device(&dev);
let mut g: SimpleUpdater = Default::default();

// no gradients present