+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/nn/npz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ mod tests {
test_save_load::<Rank1<5>, TestDtype, TestDevice, (T, T)>(&dev);
}

#[cfg(feature = "nightly")]
#[test]
fn test_save_load_mha() {
let dev: TestDevice = Default::default();
Expand All @@ -309,7 +308,6 @@ mod tests {
assert_eq!(y1.array(), y2.array());
}

#[cfg(feature = "nightly")]
#[test]
fn test_save_load_transformer() {
let dev: TestDevice = Default::default();
Expand Down
11 changes: 10 additions & 1 deletion src/nn/reshape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@ use crate::{

use super::*;

/// Reshapes input tensors to a shape known *at compile time*.
///
/// The target shape is the type parameter `S` (a [`ConstShape`]), so it is
/// fixed by the type of the module rather than chosen at runtime.
///
/// In the example below, a `Rank4<5, 4, 3, 2>` tensor (120 elements) is
/// reshaped into a `Rank2<5, 24>` tensor (also 120 elements).
///
/// Example usage:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let model: Reshape<Rank2<5, 24>> = Default::default();
/// let x: Tensor<Rank4<5, 4, 3, 2>, f32, _> = dev.sample_normal();
/// let _: Tensor<Rank2<5, 24>, f32, _> = model.forward(x);
/// ```
#[derive(Default, Clone, Copy)]
pub struct Reshape<S: ConstShape>(S);

Expand Down
2 changes: 0 additions & 2 deletions src/nn/safetensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ mod tests {
test_save_load::<Rank1<5>, TestDtype, TestDevice, (T, T)>(&dev);
}

#[cfg(feature = "nightly")]
#[test]
fn test_save_load_mha() {
let dev: TestDevice = Default::default();
Expand All @@ -310,7 +309,6 @@ mod tests {
assert_eq!(y1.array(), y2.array());
}

#[cfg(feature = "nightly")]
#[test]
fn test_save_load_transformer() {
let dev: TestDevice = Default::default();
Expand Down
4 changes: 2 additions & 2 deletions src/tensor_ops/concat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod cuda_kernel;
///
/// **Pytorch equivalent** `torch.concat`.
///
/// Stacking with const dims **requires nightly**:
/// Concat with const dims **requires nightly**:
/// ```ignore
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
Expand All @@ -17,7 +17,7 @@ mod cuda_kernel;
/// let _: Tensor<Rank2<6, 4>, f32, _> = a.concat(b);
/// ```
///
/// Stacking with usize dims:
/// Concat with usize dims:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
Expand Down
42 changes: 42 additions & 0 deletions src/tensor_ops/conv2d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,51 @@ pub(super) trait Conv2DKernel<E: Dtype>: DeviceStorage {
) -> Result<(), Self::Err>;
}

/// Apply the 2d convolution to a tensor.
///
/// [Const] dims **require nightly**:
/// ```ignore
/// #![feature(generic_const_exprs)]
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let x: Tensor<Rank4<2, 3, 32, 32>, f32, _> = dev.sample_normal();
/// let w: Tensor<Rank4<6, 3, 3, 3>, f32, _> = dev.sample_normal();
/// let y = (x, w).conv2d(
/// Const::<1>, // stride
/// Const::<0>, // padding
/// Const::<1>, // dilation
/// Const::<1>, // groups
/// );
/// ```
///
/// [usize] dims can be used on stable:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let x: Tensor<_, f32, _> = dev.sample_normal_like(&(
/// 2, // batch size
/// 3, // input channels
/// 32, // height
/// 32, // width
/// ));
/// let w: Tensor<_, f32, _> = dev.sample_normal_like(&(
/// 6, // output channels
/// 3, // input channels
/// 3, // kernel size
/// 3, // kernel size
/// ));
/// let y = (x, w).conv2d(
/// 1, // stride
/// 0, // padding
/// 1, // dilation
/// 1, // groups
/// );
/// ```
pub trait TryConv2D<Stride, Padding, Dilation, Groups>: Sized {
type Convolved;
type Error: std::fmt::Debug;

/// Applies a 2D convolution to the input tensor.
fn conv2d(
self,
stride: Stride,
Expand All @@ -66,6 +107,7 @@ pub trait TryConv2D<Stride, Padding, Dilation, Groups>: Sized {
self.try_conv2d(stride, padding, dilation, groups).unwrap()
}

/// Fallibly applies a 2D convolution to the input tensor.
fn try_conv2d(
self,
stride: Stride,
Expand Down
7 changes: 3 additions & 4 deletions src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ mod sum_to;
mod tanh;
mod to_dtype;
mod tri;
mod upscale2d;
mod var_to;

pub use abs::abs;
Expand Down Expand Up @@ -258,9 +259,11 @@ pub use sum_to::SumTo;
pub use tanh::tanh;
pub use to_dtype::to_dtype;
pub use tri::{lower_tri, upper_tri};
pub use upscale2d::{Bilinear, GenericUpscale2D, NearestNeighbor, TryUpscale2D, UpscaleMethod};
pub use var_to::VarTo;

pub(crate) use to_dtype::ToDtypeKernel;
pub(crate) use upscale2d::Upscale2DKernel;

#[cfg(feature = "nightly")]
mod conv2d;
Expand All @@ -272,10 +275,6 @@ mod convtrans2d;
#[cfg(feature = "nightly")]
pub use convtrans2d::{ConvTransAlgebra, TryConvTrans2D, TryConvTrans2DTo};

mod upscale2d;
pub(crate) use upscale2d::Upscale2DKernel;
pub use upscale2d::{Bilinear, GenericUpscale2D, NearestNeighbor, TryUpscale2D, UpscaleMethod};

#[cfg(feature = "nightly")]
mod pool2d;
#[cfg(feature = "nightly")]
Expand Down
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载