2 changes: 1 addition & 1 deletion dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
safetensors = { workspace = true, optional = true }
memmap2 = { workspace = true, optional = true }
half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
1 change: 1 addition & 0 deletions dfdx-core/src/data/collate.rs
@@ -55,6 +55,7 @@ impl<A, B> Collate for Vec<(A, B)> {
impl<'a, A, B> Collate for Vec<&'a (A, B)> {
type Collated = (Vec<&'a A>, Vec<&'a B>);
fn collated(self) -> Self::Collated {
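// Not an identity map, despite appearances: destructuring `&'a (A, B)` with
// `|(a, b)| (a, b)` rebinds through the reference, turning each `&(A, B)`
// into `(&A, &B)`. Clippy's `map_identity` lint misreads this, hence the
// allow on the next line.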
#[allow(clippy::map_identity)]
self.into_iter().map(|(a, b)| (a, b)).unzip()
}
}
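For reviewers unfamiliar with this corner of match ergonomics, a standalone sketch (illustrative names, not crate code) of the same pattern:

```rust
// The closure below looks like an identity but converts `&(A, B)` into
// `(&A, &B)` by destructuring through the reference.
fn main() {
    let owned: Vec<(i32, char)> = vec![(1, 'a'), (2, 'b')];
    let pairs: Vec<&(i32, char)> = owned.iter().collect();
    let (nums, chars): (Vec<&i32>, Vec<&char>) =
        pairs.into_iter().map(|(a, b)| (a, b)).unzip();
    assert_eq!(nums, [&1, &2]);
    assert_eq!(chars, [&'a', &'b']);
}
```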
38 changes: 0 additions & 38 deletions dfdx-core/src/lib.rs
@@ -128,44 +128,6 @@ pub mod prelude {
pub use crate::tensor_ops::*;
}

/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn flush_denormals_to_zero() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}
}

/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn keep_denormals() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}
}

#[cfg(test)]
pub(crate) mod tests {
pub use num_traits::{Float, NumCast, Zero};
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/gradients.rs
@@ -153,7 +153,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
#[inline]
pub(crate) fn many_and_ref<L: Shape, R: Shape>(
&mut self,
ls: &Vec<impl Tensorlike<L, E, D>>,
ls: &[impl Tensorlike<L, E, D>],
r: &impl Tensorlike<R, E, D>,
) -> (Vec<&mut D::Vec>, &D::Vec) {
for i in 0..ls.len() {
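A hedged aside on the signature change (a general Rust idiom, not dfdx-specific code): `&[T]` accepts strictly more callers than `&Vec<T>`, since `&Vec<T>` deref-coerces to a slice while arrays and sub-slices never coerce to `&Vec<T>`:

```rust
// Taking `&[T]` instead of `&Vec<T>`: callers may pass a borrowed Vec, a
// borrowed array, or a sub-slice; the callee loses nothing.
fn total_len(items: &[String]) -> usize {
    items.iter().map(|s| s.len()).sum()
}

fn main() {
    let v = vec!["ab".to_string(), "cde".to_string()];
    assert_eq!(total_len(&v), 5);      // &Vec<String> coerces to &[String]
    assert_eq!(total_len(&v[..1]), 2); // sub-slices work too
}
```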
8 changes: 4 additions & 4 deletions dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
@@ -26,11 +26,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
let buf = std::sync::Arc::get_mut(&mut c.data).unwrap();
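// Interleave: alternately copy `a_n` contiguous elements from `a` and `b_n`
// from `b` until the output buffer is full.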
while i < n {
for _ in 0..a_n {
buf[i] = a.data[a_idx.next().unwrap()];
(*buf)[i] = a.data[a_idx.next().unwrap()];
i += 1;
}
for _ in 0..b_n {
buf[i] = b.data[b_idx.next().unwrap()];
(*buf)[i] = b.data[b_idx.next().unwrap()];
i += 1;
}
}
@@ -59,11 +59,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
let n = grad_out.len();
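// Backward mirrors the forward interleave: alternating runs of `grad_out`
// are accumulated back into `grad_a` and `grad_b`.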
while i < n {
for _ in 0..a_n {
grad_a[a_idx.next().unwrap()] += grad_out[i];
(*grad_a)[a_idx.next().unwrap()] += grad_out[i];
i += 1;
}
for _ in 0..b_n {
grad_b[b_idx.next().unwrap()] += grad_out[i];
(*grad_b)[b_idx.next().unwrap()] += grad_out[i];
i += 1;
}
}
8 changes: 4 additions & 4 deletions dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
@@ -19,7 +19,7 @@ mod webgpu_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
/// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_along(Axis::<0>);
/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_tensor_along(Axis::<0>);
/// ```
///
/// Along Axis 1:
@@ -28,7 +28,7 @@ mod webgpu_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
/// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_along(Axis::<1>);
/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_tensor_along(Axis::<1>);
/// ```
///
/// # [usize] dims
@@ -38,7 +38,7 @@
/// # let dev: Cpu = Default::default();
/// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
/// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));
/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_along(Axis::<0>).realize();
/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_tensor_along(Axis::<0>).realize();
/// ```
///
/// Along Axis 1:
@@ -47,7 +47,7 @@
/// # let dev: Cpu = Default::default();
/// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
/// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_tensor_along(Axis::<1>).realize();
/// ```
pub trait TryConcatTensorAlong<Ax>: Sized {
type Output;
4 changes: 4 additions & 0 deletions dfdx-core/src/tensor_ops/mod.rs
@@ -200,6 +200,8 @@ mod sigmoid;
mod sin;
mod slice;
mod softmax;
mod split_shape_along;
mod split_tensor_along;
mod sqrt;
mod square;
mod stack;
@@ -267,6 +269,8 @@ pub use sigmoid::sigmoid;
pub use sin::sin;
pub use slice::slice;
pub use softmax::softmax;
pub use split_shape_along::TrySplitShapeAlong;
pub use split_tensor_along::TrySplitTensorAlong;
pub use sqrt::sqrt;
pub use square::square;
pub use stack::{AddDim, TryStack};
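With these re-exports (and the prelude's `pub use crate::tensor_ops::*;` shown above), downstream code reaches the new traits without extra paths; a minimal sketch reusing the shape-split example documented below:

```rust
use dfdx_core::prelude::*; // brings TrySplitShapeAlong into scope

fn main() {
    let (a, b) = (7, Const::<3>).split_shape_along(Axis::<0>, 3, 4);
    assert_eq!(a, (3, Const::<3>));
    assert_eq!(b, (4, Const::<3>));
}
```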
158 changes: 158 additions & 0 deletions dfdx-core/src/tensor_ops/split_shape_along/mod.rs
@@ -0,0 +1,158 @@
use crate::{shapes::*, tensor::*};

/// Split a shape in two along a given axis.
///
/// # [Const] dims **requires nightly**
///
/// Along Axis 0:
/// ```ignore
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let (a, b): (Rank2<3, 3>, Rank2<4, 3>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<0>, Const::<3>, Const::<4>);
/// ```
///
/// Along Axis 1:
/// ```ignore
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let (a, b): (Rank2<7, 2>, Rank2<7, 1>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<1>, Const::<2>, Const::<1>);
/// ```
///
/// # [usize] dims
/// Along Axis 0:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let (a, b) = (7, Const::<3>).split_shape_along(Axis::<0>, 3, 4);
/// assert_eq!(a, (3, Const::<3>));
/// assert_eq!(b, (4, Const::<3>));
/// ```
///
/// Along Axis 1:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let (a, b) = (Const::<7>, 3).split_shape_along(Axis::<1>, 2, 1);
/// assert_eq!(a, (Const::<7>, 2));
/// assert_eq!(b, (Const::<7>, 1));
/// ```
pub trait TrySplitShapeAlong<Ax, A: Dim, B: Dim>: Shape {
type Output;

/// Splits self along the given axis.
fn split_shape_along(self, ax: Ax, a: A, b: B) -> Self::Output {
self.try_split_shape_along(ax, a, b).unwrap()
}
/// Fallibly splits self along the given axis.
fn try_split_shape_along(self, ax: Ax, a: A, b: B) -> Result<Self::Output, Error>;
}

macro_rules! impl_split {
($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
impl<A: Dim, B: Dim, AB:Dim, $($Head: Dim, )* $($Tail: Dim, )*> TrySplitShapeAlong<Axis<$Ax>, A, B>
for
(
$($Head, )*
AB,
$($Tail, )*
)
where
($($Head, )* A, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
($($Head, )* B, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
{
type Output =
(
($($Head, )* A, $($Tail, )*),
($($Head, )* B, $($Tail, )*),
);

fn try_split_shape_along(self, _: Axis<$Ax>, a: A, b: B) -> Result<Self::Output, Error> {
let dims = self.concrete();
let mut lhs_dims = dims;
let mut rhs_dims = dims;
lhs_dims[$Ax] = a.size();
rhs_dims[$Ax] = b.size();
assert_eq!(dims[$Ax], lhs_dims[$Ax] + rhs_dims[$Ax]);

Ok((
<($($Head, )* A, $($Tail, )*)>::from_concrete(&lhs_dims).unwrap(),
<($($Head, )* B, $($Tail, )*)>::from_concrete(&rhs_dims).unwrap(),
))
}
}
};
}

impl_split!(0, 1, [], []);
impl_split!(0, 2, [], [D1]);
impl_split!(0, 3, [], [D1, D2]);
impl_split!(0, 4, [], [D1, D2, D3]);
impl_split!(0, 5, [], [D1, D2, D3, D4]);
impl_split!(0, 6, [], [D1, D2, D3, D4, D5]);

impl_split!(1, 2, [D0], []);
impl_split!(1, 3, [D0], [D2]);
impl_split!(1, 4, [D0], [D2, D3]);
impl_split!(1, 5, [D0], [D2, D3, D4]);
impl_split!(1, 6, [D0], [D2, D3, D4, D5]);

impl_split!(2, 3, [D0, D1], []);
impl_split!(2, 4, [D0, D1], [D3]);
impl_split!(2, 5, [D0, D1], [D3, D4]);
impl_split!(2, 6, [D0, D1], [D3, D4, D5]);

impl_split!(3, 4, [D0, D1, D2], []);
impl_split!(3, 5, [D0, D1, D2], [D4]);
impl_split!(3, 6, [D0, D1, D2], [D4, D5]);

impl_split!(4, 5, [D0, D1, D2, D3], []);
impl_split!(4, 6, [D0, D1, D2, D3], [D5]);

impl_split!(5, 6, [D0, D1, D2, D3, D4], []);
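// For orientation: each `impl_split!` invocation implements the trait for one
// (axis, rank) pair. E.g. `impl_split!(1, 3, [D0], [D2])` implements
// `TrySplitShapeAlong<Axis<1>, A, B>` for `(D0, AB, D2)` with
// `Output = ((D0, A, D2), (D0, B, D2))`: the concrete dims are copied, index 1
// is overwritten with `a.size()` and `b.size()`, and the two halves are
// rebuilt with `from_concrete`.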

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_split_shape() {
let a: (usize, Const<5>) = (5, Const);
let b: (usize, Const<5>) = (3, Const);
assert_eq!(
(8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
(a, b)
);

let a: (Const<5>, Const<5>) = (Const, Const);
let b: (usize, Const<5>) = (3, Const);
assert_eq!(
(8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
(a, b)
);

let a: (usize, Const<5>) = (5, Const);
let b: (Const<3>, Const<5>) = (Const, Const);
assert_eq!(
(8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
(a, b)
);

#[cfg(feature = "nightly")]
{
let a: (Const<5>, Const<5>) = (Const, Const);
let b: (Const<3>, Const<5>) = (Const, Const);
assert_eq!(
(Const::<8>, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
(a, b)
);
}
}

#[test]
#[should_panic = "left: 8\n right: 7"]
fn test_split_shape_fails() {
let a: (usize, Const<5>) = (4, Const);
let b: (usize, Const<5>) = (3, Const);
(8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0);
}
}