29 changes: 13 additions & 16 deletions examples/11-conv-net.rs
@@ -2,34 +2,31 @@
//! layers on nightly rust.

#![cfg_attr(feature = "nightly", feature(generic_const_exprs))]
use dfdx::prelude::*;
use rand::thread_rng;

#[cfg(not(feature = "nightly"))]
fn main() {
panic!("`+nightly` required to run this example.")
}
type Model = (
(Conv2D<3, 4, 3>, ReLU),
(Conv2D<4, 8, 3>, ReLU),
(Conv2D<8, 16, 3>, ReLU),
Flatten2D,
Linear<7744, 10>,
);

#[cfg(feature = "nightly")]
fn main() {
use dfdx::prelude::*;
use rand::thread_rng;

type Model = (
(Conv2D<3, 4, 3>, ReLU),
(Conv2D<4, 8, 3>, ReLU),
(Conv2D<8, 16, 3>, ReLU),
FlattenImage,
Linear<7744, 10>,
);

let mut rng = thread_rng();
let mut m: Model = Default::default();
m.reset_params(&mut rng);

// single image forward
let x: Tensor3D<3, 28, 28> = TensorCreator::randn(&mut rng);

#[cfg(feature = "nightly")]
let _: Tensor1D<10> = m.forward(x);

// batched image forward
let x: Tensor4D<32, 3, 28, 28> = TensorCreator::randn(&mut rng);

#[cfg(feature = "nightly")]
let _: Tensor2D<32, 10> = m.forward(x);
}
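The `Linear<7744, 10>` input width follows directly from the shapes the three convolutions produce. A standalone sketch of that arithmetic (assuming each `Conv2D<_, _, 3>` runs with stride 1 and no padding, which is consistent with the 7744 figure in the model):

```rust
// Shape bookkeeping for the model above, assuming each Conv2D uses a 3x3
// kernel with stride 1 and no padding (consistent with Linear<7744, 10>).
const fn conv_out(size: usize, kernel: usize) -> usize {
    size - kernel + 1
}

fn main() {
    // 28 -> 26 -> 24 -> 22 across the three convolution layers.
    let spatial = conv_out(conv_out(conv_out(28, 3), 3), 3);
    assert_eq!(spatial, 22);
    // Flatten2D turns 16 channels of 22x22 features into one 7744-element vector.
    assert_eq!(16 * spatial * spatial, 7744);
}
```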
22 changes: 10 additions & 12 deletions src/nn/flatten.rs
@@ -8,25 +8,23 @@ use crate::{Assert, ConstTrue};
/// Specifically:
/// ```ignore
/// # use dfdx::prelude::*;
/// let _: Tensor1D<{3 * 5 * 7}> = FlattenImage.forward(Tensor3D::<3, 5, 7>::zeros());
/// let _: Tensor2D<8, {3 * 5 * 7}> = FlattenImage.forward(Tensor4D::<8, 3, 5, 7>::zeros());
/// let _: Tensor1D<{3 * 5 * 7}> = Flatten2D.forward(Tensor3D::<3, 5, 7>::zeros());
/// let _: Tensor2D<8, {3 * 5 * 7}> = Flatten2D.forward(Tensor4D::<8, 3, 5, 7>::zeros());
/// ```
#[derive(Default, Clone, Copy)]
pub struct FlattenImage;
pub struct Flatten2D;

impl ResetParams for FlattenImage {
/// Does nothing.
impl ResetParams for Flatten2D {
fn reset_params<R: rand::Rng>(&mut self, _: &mut R) {}
}

impl CanUpdateWithGradients for FlattenImage {
/// Does nothing.
impl CanUpdateWithGradients for Flatten2D {
fn update<G: GradientProvider>(&mut self, _: &mut G, _: &mut UnusedTensors) {}
}

#[cfg(feature = "nightly")]
impl<const M: usize, const N: usize, const O: usize, H: Tape> Module<Tensor3D<M, N, O, H>>
for FlattenImage
for Flatten2D
where
Assert<{ M * N * O == (M * N * O) }>: ConstTrue,
{
@@ -38,7 +36,7 @@ where

#[cfg(feature = "nightly")]
impl<const M: usize, const N: usize, const O: usize, const P: usize, H: Tape>
Module<Tensor4D<M, N, O, P, H>> for FlattenImage
Module<Tensor4D<M, N, O, P, H>> for Flatten2D
where
Assert<{ M * N * O * P == M * (N * O * P) }>: ConstTrue,
{
@@ -48,7 +46,7 @@
}
}

impl<T> ModuleMut<T> for FlattenImage
impl<T> ModuleMut<T> for Flatten2D
where
Self: Module<T>,
{
@@ -65,7 +63,7 @@ mod tests {

#[test]
fn test_flattens() {
let _: Tensor1D<{ 15 * 10 * 5 }> = FlattenImage.forward_mut(Tensor3D::<15, 10, 5>::zeros());
let _: Tensor2D<5, 24> = FlattenImage.forward_mut(Tensor4D::<5, 4, 3, 2>::zeros());
let _: Tensor1D<{ 15 * 10 * 5 }> = Flatten2D.forward_mut(Tensor3D::<15, 10, 5>::zeros());
let _: Tensor2D<5, 24> = Flatten2D.forward_mut(Tensor4D::<5, 4, 3, 2>::zeros());
}
}
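Tying the two files together: the renamed unit struct drops into a tuple model exactly as `FlattenImage` did. A minimal sketch under the same assumptions as the example above (dfdx's `nightly` feature enabled, input shape matching the conv stack's output):

```rust
// Minimal sketch; assumes the `nightly` feature so the Flatten2D impls used
// here are compiled in. Mirrors the tail end of examples/11-conv-net.rs.
use dfdx::prelude::*;
use rand::thread_rng;

// 16 channels of 22x22 features -> 10 logits.
type Head = (Flatten2D, Linear<7744, 10>);

fn demo() {
    let mut rng = thread_rng();
    let mut head: Head = Default::default();
    head.reset_params(&mut rng);

    let x: Tensor3D<16, 22, 22> = TensorCreator::randn(&mut rng);
    let _logits: Tensor1D<10> = head.forward(x);
}
```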
2 changes: 1 addition & 1 deletion src/nn/npz_impls.rs
@@ -295,7 +295,7 @@ empty_npz_impl!(Dropout);
empty_npz_impl!(AvgPoolGlobal);
empty_npz_impl!(MaxPoolGlobal);
empty_npz_impl!(MinPoolGlobal);
empty_npz_impl!(FlattenImage);
empty_npz_impl!(Flatten2D);

impl<const N: usize> SaveToNpz for DropoutOneIn<N> {}
impl<const N: usize> LoadFromNpz for DropoutOneIn<N> {}
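For context, `empty_npz_impl!` presumably expands to no-op serialization impls analogous to the hand-written ones directly below the macro calls, so the rename only needs to touch the invocation. A hypothetical expansion for the renamed type (the macro's actual definition is outside this diff):

```rust
// Hypothetical expansion of `empty_npz_impl!(Flatten2D)`, inferred from the
// hand-written SaveToNpz/LoadFromNpz impls nearby; the real macro may differ.
impl SaveToNpz for Flatten2D {}
impl LoadFromNpz for Flatten2D {}
```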