2 changes: 1 addition & 1 deletion src/lib.rs
@@ -119,7 +119,7 @@ pub mod unique_id;
pub mod prelude {
pub use crate::gradients::{NoneTape, OwnedTape};
pub use crate::losses::*;
pub use crate::nn::{builders::*, *};
pub use crate::nn::{builders::*, visit_tensors::*, *};
pub use crate::optim::prelude::*;
pub use crate::shapes::*;
pub use crate::tensor::*;
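
As a quick orientation (not part of the diff itself): with the prelude line added above, downstream code should be able to reach the new visitor traits through the usual glob import. A minimal sketch, assuming the `visit_tensors` items stay publicly re-exported:

```rust
use dfdx::prelude::*; // after this PR, presumably also pulls in nn::visit_tensors::*

// Compile-time check only: confirms that `VisitTensors` (used throughout the
// changes below) is nameable via the prelude. `_visitor_in_scope` is a
// hypothetical helper, not part of the PR.
fn _visitor_in_scope<E: Dtype, D: DeviceStorage, T: VisitTensors<E, D>>() {}
```
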
27 changes: 13 additions & 14 deletions src/nn/add_into.rs
@@ -1,6 +1,9 @@
use crate::{optim::*, shapes::Dtype, tensor_ops::Device};
use crate::{shapes::Dtype, tensor_ops::Device};

use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice};
use super::{
BuildModule, BuildOnDevice, DeviceStorage, Module, ModuleMut, TensorFunction, TensorVisitor,
ToDevice, VisitTensors,
};

/// Add inputs together into a single tensor. `T` should be a tuple
/// where every element of the tuple has the same output type
@@ -23,12 +26,13 @@ use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice
#[derive(Debug, Default, Clone)]
pub struct AddInto<T>(pub T);

impl<T: GradientUpdate<D, E>, D: Device<E>, E: Dtype> GradientUpdate<D, E> for AddInto<T> {
fn update<U>(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), <D>::Err>
where
U: ParamUpdater<D, E>,
{
self.0.update(updater, unused)
impl<T: VisitTensors<E, D> + std::fmt::Debug, E: Dtype, D: DeviceStorage> VisitTensors<E, D>
for AddInto<T>
{
fn visit_groups<const N: usize, const M: usize, F: TensorFunction<N, M, E, D>>(
mut visitor: TensorVisitor<N, M, Self, F>,
) -> Result<(), F::Err> {
visitor.visit_field(|s| &s.0, |s| &mut s.0, "0.")
}
}

@@ -42,12 +46,6 @@ impl<T: BuildModule<D, E>, D: Device<E>, E: Dtype> BuildModule<D, E> for AddInto
}
}

impl<T: ResetParams<D, E>, D: Device<E>, E: Dtype> ResetParams<D, E> for AddInto<T> {
fn try_reset_params(&mut self) -> Result<(), <D>::Err> {
self.0.try_reset_params()
}
}

impl<T: ToDevice<D>, D> ToDevice<D> for AddInto<T> {
type Output = AddInto<T::Output>;
fn to_device(&self, device: &D) -> Self::Output {
@@ -103,6 +101,7 @@ mod tests {
use crate::{
gradients::OwnedTape,
nn::{builders::*, tests::SimpleUpdater, DeviceBuildExt},
optim::GradientUpdate,
shapes::*,
tensor::*,
tests::{TestDevice, TestDtype},
57 changes: 34 additions & 23 deletions src/nn/batchnorm2d.rs
@@ -1,6 +1,9 @@
use crate::{gradients::*, optim::*, shapes::*, tensor::*, tensor_ops::*};
use crate::{gradients::*, shapes::*, tensor::*, tensor_ops::*};

use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice};
use super::{
BuildModule, BuildOnDevice, Module, ModuleMut, TensorFunction, TensorFunctionOption,
TensorVisitor, ToDevice, VisitTensors,
};

pub mod builder {
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
@@ -69,6 +72,35 @@ pub struct BatchNorm2D<const C: usize, E: Dtype, D: DeviceStorage> {
pub momentum: E,
}

impl<const C: usize, E: Dtype, D: DeviceStorage> VisitTensors<E, D> for BatchNorm2D<C, E, D> {
fn visit_groups<const N: usize, const M: usize, F: TensorFunction<N, M, E, D>>(
mut visitor: TensorVisitor<N, M, Self, F>,
) -> Result<(), F::Err> {
use TensorFunctionOption::*;

visitor.visit_field_with_options(
|s| &s.scale,
|s| &mut s.scale,
"scale",
&[ResetParamsOnes],
)?;
visitor.visit_field(|s| &s.bias, |s| &mut s.bias, "bias")?;

visitor.visit_field_with_options(
|s| &s.running_mean,
|s| &mut s.running_mean,
"running_mean",
&[DisableGradientUpdate],
)?;
visitor.visit_field_with_options(
|s| &s.running_var,
|s| &mut s.running_var,
"running_var",
&[DisableGradientUpdate, ResetParamsOnes],
)
}
}

impl<const C: usize, E: Dtype, D: Device<E>> BatchNorm2D<C, E, D> {
/// generic forward for inference
fn infer_fwd<S: Shape, Ax: Axes>(&self, x: Tensor<S, E, D>) -> Tensor<S, E, D>
@@ -183,16 +215,6 @@ impl<const C: usize, E: Dtype, D: Device<E>> BuildModule<D, E> for BatchNorm2D<C
}
}

impl<const C: usize, E: Dtype, D: Device<E>> ResetParams<D, E> for BatchNorm2D<C, E, D> {
fn try_reset_params(&mut self) -> Result<(), D::Err> {
self.scale.try_fill_with_ones()?;
self.bias.try_fill_with_zeros()?;
self.running_mean.try_fill_with_zeros()?;
self.running_var.try_fill_with_ones()?;
Ok(())
}
}

impl<const C: usize, E: Dtype, D1: Device<E>, D2: Device<E>> ToDevice<D2>
for BatchNorm2D<C, E, D1>
{
@@ -209,17 +231,6 @@ impl<const C: usize, E: Dtype, D1: Device<E>, D2: Device<E>> ToDevice<D2>
}
}

impl<const C: usize, E: Dtype, D: Device<E>> GradientUpdate<D, E> for BatchNorm2D<C, E, D> {
fn update<U>(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), <D>::Err>
where
U: ParamUpdater<D, E>,
{
self.scale.update(updater, unused)?;
self.bias.update(updater, unused)?;
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::builder::BatchNorm2D;
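
To make the new per-field options concrete, here is a hedged sketch of a hypothetical custom layer combining them the way BatchNorm2D does above: `ResetParamsOnes` takes over what the deleted `ResetParams` impl did with `try_fill_with_ones`, and `DisableGradientUpdate` takes over the deleted `GradientUpdate` impl's exclusion of the running statistics. The struct and field names are illustrative only; the trait, visitor, and option names come from this diff.

```rust
use dfdx::prelude::*;

// Hypothetical layer (not in the PR): a learnable per-channel scale plus a
// frozen running buffer, mirroring the BatchNorm2D fields above.
#[derive(Debug, Clone)]
struct ScaleWithStats<const C: usize, E: Dtype, D: DeviceStorage> {
    scale: Tensor<Rank1<C>, E, D>,
    running_count: Tensor<Rank1<C>, E, D>,
}

impl<const C: usize, E: Dtype, D: DeviceStorage> VisitTensors<E, D> for ScaleWithStats<C, E, D> {
    fn visit_groups<const N: usize, const M: usize, F: TensorFunction<N, M, E, D>>(
        mut visitor: TensorVisitor<N, M, Self, F>,
    ) -> Result<(), F::Err> {
        use TensorFunctionOption::*;

        // `scale` resets to ones rather than the default fill, like BatchNorm2D's scale.
        visitor.visit_field_with_options(|s| &s.scale, |s| &mut s.scale, "scale", &[ResetParamsOnes])?;
        // `running_count` is still visited (so it gets a name and takes part in the
        // generic walk) but is skipped by gradient updates, like running_mean/running_var.
        visitor.visit_field_with_options(
            |s| &s.running_count,
            |s| &mut s.running_count,
            "running_count",
            &[DisableGradientUpdate],
        )
    }
}
```
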
55 changes: 24 additions & 31 deletions src/nn/conv.rs
@@ -1,9 +1,12 @@
use num_traits::Float;
use rand_distr::uniform::SampleUniform;

use crate::{gradients::Tape, optim::*, shapes::*, tensor::*, tensor_ops::*};
use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*};

use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice};
use super::{
BuildModule, BuildOnDevice, Module, ModuleMut, TensorFunction, TensorFunctionOption,
TensorVisitor, ToDevice, VisitTensors,
};

pub mod builder {
#[derive(Debug)]
@@ -53,19 +56,25 @@ pub struct Conv2D<
pub bias: Tensor<Rank1<OUT_CHAN>, E, D>,
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
GradientUpdate<D, E> for Conv2D<I, O, K, S, P, E, D>
where
E: Dtype,
D: Device<E>,
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
E: Dtype,
D: DeviceStorage,
> VisitTensors<E, D> for Conv2D<I, O, K, S, P, E, D>
{
fn update<U>(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), <D>::Err>
where
U: ParamUpdater<D, E>,
{
self.weight.update(updater, unused)?;
self.bias.update(updater, unused)?;
Ok(())
fn visit_groups<const N: usize, const M: usize, F: TensorFunction<N, M, E, D>>(
mut visitor: TensorVisitor<N, M, Self, F>,
) -> Result<(), F::Err> {
let k = (I * K * K) as f64;
let bound = 1. / k.sqrt();
let options = [TensorFunctionOption::ResetParamsUniform(-bound, bound)];

visitor.visit_field_with_options(|s| &s.weight, |s| &mut s.weight, "weight", &options)?;
visitor.visit_field_with_options(|s| &s.bias, |s| &mut s.bias, "bias", &options)
}
}

@@ -85,23 +94,6 @@ where
}
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
ResetParams<D, E> for Conv2D<I, O, K, S, P, E, D>
where
E: Dtype + Float + SampleUniform,
D: Device<E>,
{
fn try_reset_params(&mut self) -> Result<(), <D>::Err> {
let k = E::from_usize(I * K * K).unwrap();
let bound = E::ONE / k.sqrt();
self.weight
.try_fill_with_distr(rand_distr::Uniform::new(-bound, bound))?;
self.bias
.try_fill_with_distr(rand_distr::Uniform::new(-bound, bound))?;
Ok(())
}
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D1, D2>
ToDevice<D2> for Conv2D<I, O, K, S, P, E, D1>
where
@@ -175,6 +167,7 @@ impl<'a, B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape
mod tests {
use crate::{
nn::DeviceBuildExt,
optim::{Optimizer, Sgd},
tensor::{AsArray, SampleTensor, ZerosTensor},
tests::*,
};
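
The Conv2D change above (and the Embedding change below) also folds the old `ResetParams` initialization into an option: the uniform bound is computed once as `1/sqrt(fan_in)` in `f64` and handed to `TensorFunctionOption::ResetParamsUniform`, instead of filling each tensor with `try_fill_with_distr`. A hedged sketch of the same recipe for a hypothetical fully-connected layer (names are illustrative; only the trait and option names come from this diff):

```rust
use dfdx::prelude::*;

// Hypothetical dense layer, not part of the PR.
#[derive(Debug, Clone)]
struct Dense<const I: usize, const O: usize, E: Dtype, D: DeviceStorage> {
    weight: Tensor<Rank2<O, I>, E, D>,
    bias: Tensor<Rank1<O>, E, D>,
}

impl<const I: usize, const O: usize, E: Dtype, D: DeviceStorage> VisitTensors<E, D>
    for Dense<I, O, E, D>
{
    fn visit_groups<const N: usize, const M: usize, F: TensorFunction<N, M, E, D>>(
        mut visitor: TensorVisitor<N, M, Self, F>,
    ) -> Result<(), F::Err> {
        // Same pattern as Conv2D above: U(-1/sqrt(fan_in), 1/sqrt(fan_in)), with the
        // bound computed in f64 rather than in E as the deleted ResetParams impls did.
        let bound = 1.0 / (I as f64).sqrt();
        let options = [TensorFunctionOption::ResetParamsUniform(-bound, bound)];

        visitor.visit_field_with_options(|s| &s.weight, |s| &mut s.weight, "weight", &options)?;
        visitor.visit_field_with_options(|s| &s.bias, |s| &mut s.bias, "bias", &options)
    }
}
```
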
44 changes: 19 additions & 25 deletions src/nn/embedding.rs
@@ -1,9 +1,12 @@
use num_traits::Float;
use rand_distr::uniform::SampleUniform;

use crate::{gradients::Tape, optim::*, shapes::*, tensor::*, tensor_ops::*};
use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*};

use super::module::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice};
use super::{
module::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice},
TensorFunction, TensorFunctionOption, TensorVisitor, VisitTensors,
};

pub mod builder {
#[derive(Debug)]
@@ -52,6 +55,19 @@ pub struct Embedding<const VOCAB: usize, const DIM: usize, E: Dtype, D: DeviceSt
pub weight: Tensor<Rank2<VOCAB, DIM>, E, D>,
}

impl<const V: usize, const I: usize, E: Dtype, D: DeviceStorage> VisitTensors<E, D>
for Embedding<V, I, E, D>
{
fn visit_groups<const N: usize, const M: usize, F: TensorFunction<N, M, E, D>>(
mut visitor: TensorVisitor<N, M, Self, F>,
) -> Result<(), F::Err> {
let bound = 1. / (V as f64).sqrt();
let options = [TensorFunctionOption::ResetParamsUniform(-bound, bound)];

visitor.visit_field_with_options(|s| &s.weight, |s| &mut s.weight, "weight", &options)
}
}

impl<const V: usize, const M: usize, const S: usize, E: Dtype, D: Device<E>, T: Tape<D>>
Module<Tensor<Rank1<S>, usize, D, T>> for Embedding<V, M, E, D>
{
@@ -90,18 +106,6 @@ where
}
}

impl<const VOCAB: usize, const DIM: usize, E: Dtype, D: Device<E>> GradientUpdate<D, E>
for Embedding<VOCAB, DIM, E, D>
{
fn update<U>(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), D::Err>
where
U: ParamUpdater<D, E>,
{
self.weight.update(updater, unused)?;
Ok(())
}
}

impl<const V: usize, const M: usize, E: Dtype + Float + SampleUniform, D: Device<E>>
BuildModule<D, E> for Embedding<V, M, E, D>
{
@@ -113,17 +117,6 @@ impl<const V: usize, const M: usize, E: Dtype + Float + SampleUniform, D: Device
}
}

impl<const VOCAB: usize, const DIM: usize, E: Dtype + Float + SampleUniform, D: Device<E>>
ResetParams<D, E> for Embedding<VOCAB, DIM, E, D>
{
fn try_reset_params(&mut self) -> Result<(), D::Err> {
let bound = E::ONE / E::from_usize(VOCAB).unwrap().sqrt();
let distr = rand_distr::Uniform::new(-bound, bound);
self.weight.try_fill_with_distr(distr)?;
Ok(())
}
}

impl<const VOCAB: usize, const DIM: usize, E: Dtype, D1: Device<E>, D2: Device<E>> ToDevice<D2>
for Embedding<VOCAB, DIM, E, D1>
{
@@ -140,6 +133,7 @@ mod tests {
use super::*;
use crate::{
nn::{tests::SimpleUpdater, DeviceBuildExt},
optim::GradientUpdate,
tests::*,
unique_id::HasUniqueId,
};
37 changes: 16 additions & 21 deletions src/nn/generalized_residual.rs
@@ -1,6 +1,9 @@
use crate::{optim::*, shapes::*, tensor::*, tensor_ops::*};
use crate::{shapes::*, tensor::*, tensor_ops::*};

use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice};
use super::{
BuildModule, BuildOnDevice, Module, ModuleMut, TensorFunction, TensorVisitor, ToDevice,
VisitTensors,
};

/// A residual connection `R` around `F`: `F(x) + R(x)`,
/// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
@@ -25,16 +28,18 @@ pub struct GeneralizedResidual<F, R> {
pub r: R,
}

impl<D: Device<E>, E: Dtype, F: GradientUpdate<D, E>, R: GradientUpdate<D, E>> GradientUpdate<D, E>
for GeneralizedResidual<F, R>
impl<
F: VisitTensors<E, D> + std::fmt::Debug,
R: VisitTensors<E, D> + std::fmt::Debug,
E: Dtype,
D: DeviceStorage,
> VisitTensors<E, D> for GeneralizedResidual<F, R>
{
fn update<U>(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), <D>::Err>
where
U: ParamUpdater<D, E>,
{
self.f.update(updater, unused)?;
self.r.update(updater, unused)?;
Ok(())
fn visit_groups<const N: usize, const M: usize, Func: TensorFunction<N, M, E, D>>(
mut visitor: TensorVisitor<N, M, Self, Func>,
) -> Result<(), Func::Err> {
visitor.visit_field(|s| &s.f, |s| &mut s.f, "f.")?;
visitor.visit_field(|s| &s.r, |s| &mut s.r, "r.")
}
}

@@ -55,16 +60,6 @@ impl<D: Device<E>, E: Dtype, F: BuildModule<D, E>, R: BuildModule<D, E>> BuildMo
}
}

impl<D: Device<E>, E: Dtype, F: ResetParams<D, E>, R: ResetParams<D, E>> ResetParams<D, E>
for GeneralizedResidual<F, R>
{
fn try_reset_params(&mut self) -> Result<(), <D>::Err> {
self.f.try_reset_params()?;
self.r.try_reset_params()?;
Ok(())
}
}

impl<D, F: ToDevice<D>, R: ToDevice<D>> ToDevice<D> for GeneralizedResidual<F, R> {
type Output = GeneralizedResidual<F::Output, R::Output>;
fn to_device(&self, device: &D) -> Self::Output {