14 changes: 13 additions & 1 deletion src/nn/add_into.rs
@@ -1,6 +1,6 @@
use crate::{optim::*, shapes::Dtype, tensor_ops::Device};

-use super::{Module, ModuleMut, ResetParams};
+use super::{Module, ModuleMut, OnDevice, ResetParams, ToDevice};

/// Add inputs together into a single tensor. `T` should be a tuple
/// where every element of the tuple has the same output type
@@ -40,6 +40,14 @@ impl<T: ResetParams<D, E>, D: Device<E>, E: Dtype> ResetParams<D, E> for AddInto
}
}

impl<T: ToDevice<D>, D> ToDevice<D> for AddInto<T> {
type Output = AddInto<OnDevice<T, D>>;

fn to_device(&self, device: &D) -> Self::Output {
AddInto(self.0.to_device(device))
}
}

macro_rules! sum {
($H:tt) => { $H };
($H:tt, $($T:tt),+) => { $H + sum!($($T),+) };
@@ -94,6 +102,10 @@ mod tests {
unique_id::HasUniqueId,
};

type TestAddIntoCpu = AddInto<(Linear<2, 5>, Linear<3, 5>)>;
#[allow(unused)]
type TestAddInto<D> = OnDevice<TestAddIntoCpu, D>;

#[test]
fn test_add_into_2() {
let dev: TestDevice = Default::default();
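For reference, this impl composes with the tuple impl added in `src/nn/impl_module_for_tuples.rs` below, so an entire `AddInto` model moves in one call. A minimal usage sketch, assuming the `cuda` feature and the `Cpu`/`Cuda` + `build_module` setup used in `linear.rs`'s `test_linear_ondevice` (`demo_add_into_to_device` is a hypothetical helper, not part of the PR):

fn demo_add_into_to_device() {
    // Build on the CPU, then transfer; the cuda feature provides `Cuda`.
    let cpu: Cpu = Default::default();
    let cuda: Cuda = Default::default();
    let m: AddInto<(Linear<2, 5>, Linear<3, 5>)> = cpu.build_module();
    // `Output` rewrites the tuple element-wise, so `m_cuda` has type
    // AddInto<(Linear<2, 5, Cuda>, Linear<3, 5, Cuda>)>.
    let m_cuda = m.to_device(&cuda);
    let _ = m_cuda;
}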
17 changes: 16 additions & 1 deletion src/nn/batchnorm2d.rs
@@ -1,6 +1,6 @@
use crate::{gradients::*, optim::*, shapes::*, tensor::*, tensor_ops::*};

-use super::{Module, ModuleMut, ResetParams};
+use super::{Module, ModuleMut, ResetParams, ToDevice};

/// Batch normalization for images as described in
/// [Batch Normalization: Accelerating Deep Network Training
@@ -189,6 +189,21 @@ impl<const C: usize, D: Device<f32>> GradientUpdate<D, f32> for BatchNorm2D<C, D
}
}

impl<const C: usize, D1: Device<f32>, D2: Device<f32>> ToDevice<D2> for BatchNorm2D<C, D1> {
type Output = BatchNorm2D<C, D2>;

fn to_device(&self, device: &D2) -> Self::Output {
BatchNorm2D {
scale: self.scale.to_device(device),
bias: self.bias.to_device(device),
running_mean: self.running_mean.to_device(device),
running_var: self.running_var.to_device(device),
epsilon: self.epsilon,
momentum: self.momentum,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
18 changes: 17 additions & 1 deletion src/nn/conv.rs
@@ -6,7 +6,7 @@ use crate::{
tensor_ops::{BroadcastTo, Device, TryConv2DTo},
};

-use super::{Module, ModuleMut, ResetParams};
+use super::{Module, ModuleMut, ResetParams, ToDevice};

/// **Requires Nightly** Performs 2d convolutions on 3d and 4d images.
///
@@ -70,6 +70,22 @@ where
}
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, D1, D2>
ToDevice<D2> for Conv2D<I, O, K, S, P, D1>
where
D1: Device<f32>,
D2: Device<f32>,
{
type Output = Conv2D<I, O, K, S, P, D2>;

fn to_device(&self, device: &D2) -> Self::Output {
Conv2D {
weight: self.weight.to_device(device),
bias: self.bias.to_device(device),
}
}
}

impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, D, Img>
Module<Img> for Conv2D<C, O, K, S, P, D>
where
13 changes: 12 additions & 1 deletion src/nn/generalized_residual.rs
@@ -1,6 +1,6 @@
use crate::{optim::*, shapes::*, tensor::*, tensor_ops::*};

-use super::{Module, ModuleMut, ResetParams};
+use super::{Module, ModuleMut, OnDevice, ResetParams, ToDevice};

/// A residual connection `R` around `F`: `F(x) + R(x)`,
/// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
@@ -53,6 +53,17 @@ impl<D: Device<E>, E: Dtype, F: ResetParams<D, E>, R: ResetParams<D, E>> ResetPa
}
}

impl<D, F: ToDevice<D>, R: ToDevice<D>> ToDevice<D> for GeneralizedResidual<F, R> {
type Output = GeneralizedResidual<OnDevice<F, D>, OnDevice<R, D>>;

fn to_device(&self, device: &D) -> Self::Output {
GeneralizedResidual {
f: self.f.to_device(device),
r: self.r.to_device(device),
}
}
}

impl<T: SplitTape, F: Module<T>, R: Module<T, Output = F::Output>> Module<T>
for GeneralizedResidual<F, R>
where
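Note the `Output` type here: `OnDevice<F, D>` and `OnDevice<R, D>` rewrite each branch for the target device, while zero-sized branches map to themselves through the blanket impl added in `src/nn/module.rs` below. A type-level sketch of how the alias should resolve, assuming `ReLU` is a `ZeroSizedModule` and that `Linear`'s device parameter defaults to `Cpu` (both hold elsewhere in dfdx, but are assumptions here):

// Sketch only; `Cuda` requires the cuda feature.
type CpuBlock = GeneralizedResidual<Linear<2, 2>, ReLU>;
// ReLU carries no tensors, so OnDevice<ReLU, Cuda> == ReLU, and therefore
// OnDevice<CpuBlock, Cuda> == GeneralizedResidual<Linear<2, 2, Cuda>, ReLU>.
type CudaBlock = OnDevice<CpuBlock, Cuda>;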
10 changes: 9 additions & 1 deletion src/nn/impl_module_for_tuples.rs
@@ -1,6 +1,6 @@
use crate::{optim::*, shapes::*, tensor_ops::*};

-use super::module::{Module, ModuleMut, ResetParams};
+use super::module::{Module, ModuleMut, OnDevice, ResetParams, ToDevice};

macro_rules! tuple_impls {
([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),+]) => {
@@ -27,6 +27,14 @@ macro_rules! tuple_impls {
}
}

impl<$($name: ToDevice<D>,)+ D> ToDevice<D> for ($($name,)+) {
type Output = ($(OnDevice<$name, D>,)+);

fn to_device(&self, device: &D) -> Self::Output {
($(self.$idx.to_device(device)),+)
}
}

/*This macro expands like this for a 4-tuple:

impl<
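The comment below this hunk walks through the macro's expansion for a 4-tuple; for the new `ToDevice` arm, the 2-tuple case would expand roughly as follows (a sketch, not literal `cargo expand` output):

impl<A: ToDevice<D>, B: ToDevice<D>, D> ToDevice<D> for (A, B) {
    type Output = (OnDevice<A, D>, OnDevice<B, D>);

    fn to_device(&self, device: &D) -> Self::Output {
        (self.0.to_device(device), self.1.to_device(device))
    }
}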
14 changes: 13 additions & 1 deletion src/nn/layer_norm.rs
@@ -1,6 +1,6 @@
use crate::{gradients::Tape, optim::*, shapes::*, tensor::*, tensor_ops::*};

-use super::{Module, ModuleMut, ResetParams};
+use super::{Module, ModuleMut, ResetParams, ToDevice};

/// Implements layer normalization as described in [Layer Normalization](https://arxiv.org/abs/1607.06450).
///
@@ -54,6 +54,18 @@ impl<const M: usize, D: Device<f32>> GradientUpdate<D, f32> for LayerNorm1D<M, D
}
}

impl<const M: usize, D1: Device<f32>, D2: Device<f32>> ToDevice<D2> for LayerNorm1D<M, D1> {
type Output = LayerNorm1D<M, D2>;

fn to_device(&self, device: &D2) -> Self::Output {
LayerNorm1D {
gamma: self.gamma.to_device(device),
beta: self.beta.to_device(device),
epsilon: self.epsilon,
}
}
}

impl<const M: usize, D: Device<f32>, T: Tape<D>> Module<Tensor<Rank1<M>, f32, D, T>>
for LayerNorm1D<M, D>
{
28 changes: 27 additions & 1 deletion src/nn/linear.rs
@@ -1,6 +1,6 @@
use crate::{gradients::Tape, optim::*, shapes::*, tensor::*, tensor_ops::*};

-use super::module::{Module, ModuleMut, ResetParams};
+use super::module::{Module, ModuleMut, ResetParams, ToDevice};

/// A linear transformation of the form `weight * x + bias`, where `weight` is a matrix, `x` is a vector or matrix,
/// and `bias` is a vector.
@@ -118,6 +118,20 @@ impl<'a, B: Dim, S: Dim, const M: usize, D: Device<f32>, T: Tape<D>>
}
}

#[rustfmt::skip]
impl<const I: usize, const O: usize, D1: Device<f32>, D2: Device<f32>>
ToDevice<D2> for Linear<I, O, D1>
{
type Output = Linear<I, O, D2>;

fn to_device(&self, device: &D2) -> Self::Output {
Linear {
weight: self.weight.to_device(device),
bias: self.bias.to_device(device),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
@@ -133,6 +147,18 @@ mod tests {
];
const B: [f32; 2] = [0.3765365, -0.290717];

#[cfg(feature = "cuda")]
#[test]
fn test_linear_ondevice() {
use super::super::module::OnDevice;

let cpu: Cpu = Default::default();
let cuda: Cuda = Default::default();
let _: Linear<1, 1, _> = cpu.build_module();
let _: OnDevice<Linear<1, 1>, Cuda> = cuda.build_module();
let _: OnDevice<(Linear<1, 2>, Linear<2, 1>), Cuda> = cuda.build_module();
}

#[test]
fn test_linear_initialize() {
let dev: TestDevice = Default::default();
14 changes: 13 additions & 1 deletion src/nn/module.rs
@@ -1,5 +1,9 @@
use crate::{optim::GradientUpdate, shapes::Dtype, tensor_ops::Device};

#[cfg(feature = "cuda")]
pub use crate::tensor::OnCuda;
pub use crate::tensor::{OnCpu, OnDevice, ToDevice};

/// Immutable forward of `Input` that produces [Module::Output].
/// See [ModuleMut] for mutable forward.
pub trait Module<Input> {
@@ -59,7 +63,7 @@ impl<D: Device<E>, E: Dtype> ModuleBuilder<E> for D {}
/// blanket impls for [ResetParams], [GradientUpdate], and [ModuleMut]
pub trait ZeroSizedModule: Default {}

-impl<T: ZeroSizedModule, D: Device<E>, E: Dtype> ResetParams<D, E> for T {
+impl<T: ZeroSizedModule + Clone, D: Device<E>, E: Dtype> ResetParams<D, E> for T {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
@@ -90,3 +94,11 @@
self.forward(input)
}
}

impl<T: ZeroSizedModule + Clone, D> ToDevice<D> for T {
type Output = T;

fn to_device(&self, _device: &D) -> Self {
self.clone()
}
}
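`ToDevice`, `OnDevice`, `OnCpu`, and `OnCuda` are re-exported from `crate::tensor`, whose definitions this diff does not show. Reconstructed from how every impl in this PR uses them, they presumably look like the sketch below; the real bounds and doc comments may differ:

// Presumed definitions in crate::tensor (not shown in this diff).
pub trait ToDevice<D> {
    type Output;
    fn to_device(&self, device: &D) -> Self::Output;
}

/// The type of `M` once moved to device `D`.
pub type OnDevice<M, D> = <M as ToDevice<D>>::Output;
pub type OnCpu<M> = OnDevice<M, Cpu>;
#[cfg(feature = "cuda")]
pub type OnCuda<M> = OnDevice<M, Cuda>;

In particular, `AddInto`'s `type Output = AddInto<OnDevice<T, D>>` only compiles for a generic `T: ToDevice<D>` if `OnDevice<T, D>` is exactly `<T as ToDevice<D>>::Output`, which is what the alias above encodes.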
16 changes: 15 additions & 1 deletion src/nn/repeated.rs
@@ -1,6 +1,6 @@
use crate::{optim::*, shapes::Dtype, tensor_ops::Device};

-use super::{Module, ModuleMut, ResetParams};
+use super::{Module, ModuleMut, OnDevice, ResetParams, ToDevice};

/// Repeats `T` `N` times. This requires that `T`'s input is the same as its output.
///
@@ -60,6 +60,20 @@ impl<D: Device<E>, E: Dtype, T: GradientUpdate<D, E>, const N: usize> GradientUp
}
}

impl<T: ToDevice<D>, const N: usize, D> ToDevice<D> for Repeated<T, N> {
type Output = Repeated<OnDevice<T, D>, N>;

fn to_device(&self, device: &D) -> Self::Output {
Repeated {
modules: self
.modules
.iter()
.map(|module| module.to_device(device))
.collect(),
}
}
}

impl<Input, T: Module<Input, Output = Input>, const N: usize> Module<Input> for Repeated<T, N> {
type Output = T::Output;
fn forward(&self, mut x: Input) -> Self::Output {
10 changes: 9 additions & 1 deletion src/nn/residual.rs
@@ -1,6 +1,6 @@
use crate::{optim::*, shapes::*, tensor::SplitTape, tensor_ops::Device};

-use super::{Module, ModuleMut, ResetParams};
+use super::{Module, ModuleMut, OnDevice, ResetParams, ToDevice};

/// A residual connection around `F`: `F(x) + x`,
/// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
@@ -38,6 +38,14 @@ impl<D: Device<E>, E: Dtype, F: ResetParams<D, E>> ResetParams<D, E> for Residua
}
}

impl<F: ToDevice<D>, D> ToDevice<D> for Residual<F> {
type Output = Residual<OnDevice<F, D>>;

fn to_device(&self, device: &D) -> Self::Output {
Residual(self.0.to_device(device))
}
}

impl<T: SplitTape + std::ops::Add<T, Output = T>, F: Module<T, Output = T>> Module<T>
for Residual<F>
{
10 changes: 9 additions & 1 deletion src/nn/split_into.rs
@@ -1,6 +1,6 @@
use crate::{optim::*, shapes::Dtype, tensor::*, tensor_ops::Device};

-use super::{Module, ModuleMut, ResetParams};
+use super::{Module, ModuleMut, OnDevice, ResetParams, ToDevice};

/// Splits input into multiple heads. `T` should be a tuple,
/// where every element of the tuple accepts the same input type.
@@ -39,6 +39,14 @@ impl<T: ResetParams<D, E>, D: Device<E>, E: Dtype> ResetParams<D, E> for SplitIn
}
}

impl<T: ToDevice<D>, D> ToDevice<D> for SplitInto<T> {
type Output = SplitInto<OnDevice<T, D>>;

fn to_device(&self, device: &D) -> Self::Output {
SplitInto(self.0.to_device(device))
}
}

macro_rules! tuple_impls {
([$($heads:ident),+] $tail:ident) => {
impl<
35 changes: 34 additions & 1 deletion src/nn/transformer/decoder.rs
@@ -1,5 +1,5 @@
use crate::{
-nn::{LayerNorm1D, Linear, Module, ModuleMut, ReLU, Repeated, ResetParams, Residual},
+nn::{LayerNorm1D, Linear, Module, ModuleMut, ReLU, Repeated, ResetParams, Residual, ToDevice},
optim::{GradientUpdate, ParamUpdater, UnusedTensors},
tensor::{Cpu, PutTape, SplitTape},
tensor_ops::Device,
@@ -47,6 +47,22 @@ impl<const M: usize, const H: usize, const F: usize, const L: usize, D: Device<f
}
}

impl<
const M: usize,
const H: usize,
const F: usize,
const L: usize,
D1: Device<f32>,
D2: Device<f32>,
> ToDevice<D2> for TransformerDecoder<M, H, F, L, D1>
{
type Output = TransformerDecoder<M, H, F, L, D2>;

fn to_device(&self, device: &D2) -> Self::Output {
TransformerDecoder(self.0.to_device(device))
}
}

impl<const M: usize, const H: usize, const F: usize, const L: usize, D, Tgt, Mem: Clone>
Module<(Tgt, Mem)> for TransformerDecoder<M, H, F, L, D>
where
@@ -147,6 +163,23 @@ impl<const M: usize, const H: usize, const F: usize, D: Device<f32>> GradientUpd
}
}

impl<const M: usize, const H: usize, const F: usize, D1: Device<f32>, D2: Device<f32>> ToDevice<D2>
for TransformerDecoderBlock<M, H, F, D1>
{
type Output = TransformerDecoderBlock<M, H, F, D2>;

fn to_device(&self, device: &D2) -> Self::Output {
TransformerDecoderBlock {
self_attn: self.self_attn.to_device(device),
norm1: self.norm1.to_device(device),
mh_attn: self.mh_attn.to_device(device),
norm2: self.norm2.to_device(device),
ff: self.ff.to_device(device),
norm3: self.norm3.to_device(device),
}
}
}

impl<const M: usize, const H: usize, const F: usize, D: Device<f32>, Tgt, Mem> Module<(Tgt, Mem)>
for TransformerDecoderBlock<M, H, F, D>
where
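End to end, these impls let a composite model be assembled on one device and transferred wholesale. A usage sketch, assuming the `cuda` feature, that the decoder's trailing device parameter defaults to `Cpu` (as `Linear`'s does), and that its `ResetParams` impl makes it buildable via `build_module`; the hyperparameters are arbitrary:

fn demo_decoder_to_device() {
    // MODEL=16, HEADS=4 (must divide MODEL), FF=32, LAYERS=2.
    let cpu: Cpu = Default::default();
    let cuda: Cuda = Default::default();
    let dec: TransformerDecoder<16, 4, 32, 2> = cpu.build_module();
    // Recursively moves every attention layer, feed-forward, and layer norm.
    let dec_cuda: TransformerDecoder<16, 4, 32, 2, Cuda> = dec.to_device(&cuda);
    let _ = dec_cuda;
}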