4 changes: 4 additions & 0 deletions Cargo.toml
@@ -67,4 +67,8 @@ harness = false

[[bench]]
name = "sum"
harness = false

[[bench]]
name = "softmax"
harness = false
36 changes: 36 additions & 0 deletions benches/softmax.rs
@@ -0,0 +1,36 @@
use std::time::Instant;

use dfdx::prelude::*;

#[cfg(feature = "cuda")]
type Dev = Cuda;

#[cfg(not(feature = "cuda"))]
type Dev = Cpu;

type Dtype = f32;
type InputShape = Rank4<32, 64, 128, 256>;
type Ax = Axis<3>;

fn main() {
    println!("Benchmarking `softmax` {}", std::any::type_name::<Ax>());
    println!("Device {}", std::any::type_name::<Dev>());
    println!("Dtype {}", std::any::type_name::<Dtype>());
    println!("Input shape {}", std::any::type_name::<InputShape>());
    println!();

    let dev: Dev = Default::default();

    loop {
        let img: Tensor<InputShape, Dtype, _> = dev.sample_normal();

        let start = Instant::now();
        let y = img.traced().softmax::<Ax>();
        let fwd_dur = start.elapsed();

        let start = Instant::now();
        let _ = y.sum().backward();
        let bwd_dur = start.elapsed();
        println!("fwd={:?} bwd={:?}", fwd_dur, bwd_dur);
    }
}
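Note: because `harness = false`, this benchmark is a plain binary rather than a libtest harness; `cargo bench --bench softmax` (optionally with `--features cuda` to select the `Cuda` device) should run it directly, printing one `fwd=… bwd=…` line per iteration until interrupted.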
2 changes: 1 addition & 1 deletion src/shapes/broadcasts.rs
@@ -1,7 +1,7 @@
use super::*;

/// Marker for shapes that can be reduced to [Shape] `S` along [Axes] `Ax`.
pub trait ReduceShapeTo<S, Ax>: Sized {}
pub trait ReduceShapeTo<S, Ax>: HasAxes<Ax> + Sized {}

/// Marker for shapes that can be broadcasted to [Shape] `S` along [Axes] `Ax`.
pub trait BroadcastShapeTo<S, Ax>: Sized {}
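Adding `HasAxes<Ax>` as a supertrait means any code generic over `Src: ReduceShapeTo<Dst, Ax>` can ask how many source elements collapse into each destination element, which is exactly what the rewritten CPU kernels below rely on. A minimal sketch of the pattern (the helper name is illustrative, not part of the crate):

```rust
use crate::shapes::{Axes, HasAxes, ReduceShapeTo, Shape};

/// Hypothetical helper: with the new supertrait, the `HasAxes` bound comes for
/// free from `ReduceShapeTo`, so the number of elements folded into each output
/// can be computed from the source shape alone.
fn elems_per_output<Src: Shape, Dst: Shape, Ax: Axes>(src: &Src) -> usize
where
    Src: ReduceShapeTo<Dst, Ax>,
{
    // Product of the dimensions named by `Ax`, e.g. 256 for
    // `Rank4<32, 64, 128, 256>` reduced over `Axis<3>`.
    <Src as HasAxes<Ax>>::size(src)
}
```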
39 changes: 8 additions & 31 deletions src/tensor/cpu/iterate.rs
@@ -1,14 +1,14 @@
use super::device::StridedArray;
use crate::shapes::{BroadcastStridesTo, Shape};
use std::sync::Arc;
use crate::shapes::Shape;
use std::vec::Vec;

#[derive(Debug, Eq, PartialEq)]
pub(crate) struct NdIndex<S: Shape> {
indices: S::Concrete,
shape: S::Concrete,
strides: S::Concrete,
next: Option<usize>,
contiguous: Option<usize>,
pub(crate) indices: S::Concrete,
pub(crate) shape: S::Concrete,
pub(crate) strides: S::Concrete,
pub(crate) next: Option<usize>,
pub(crate) contiguous: Option<usize>,
}

impl<S: Shape> NdIndex<S> {
@@ -145,30 +145,6 @@ impl<S: Shape, E: Clone> StridedArray<S, E> {
}
}

impl<S: Shape, E: Clone> StridedArray<S, E> {
#[inline]
pub(crate) fn iter_as<Axes, Dst: Shape>(&self, dst: &Dst) -> StridedRefIter<Dst, E>
where
S: BroadcastStridesTo<Dst, Axes>,
{
StridedRefIter {
data: self.data.as_ref(),
index: NdIndex::new(*dst, self.shape.broadcast_strides(self.strides)),
}
}

#[inline]
pub(crate) fn iter_mut_as<Axes, Dst: Shape>(&mut self, dst: &Dst) -> StridedMutIter<Dst, E>
where
S: BroadcastStridesTo<Dst, Axes>,
{
StridedMutIter {
data: Arc::make_mut(&mut self.data),
index: NdIndex::new(*dst, self.shape.broadcast_strides(self.strides)),
}
}
}

pub(crate) trait LendingIterator {
type Item<'a>
where
Expand Down Expand Up @@ -215,6 +191,7 @@ impl<'q, S: Shape, E> LendingIterator for StridedMutIndexIter<'q, S, E> {
#[cfg(test)]
mod tests {
use crate::shapes::{Rank0, Rank1, Rank2, Rank3};
use std::sync::Arc;

use super::*;

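With `iter_as`/`iter_mut_as` removed (the reduction kernels below drive indexing through `index_for_reductions` instead), `NdIndex` is exposed field-by-field to the rest of the crate. Conceptually it walks the logical shape in row-major order and yields the linear buffer offset for each position; a rough, standalone sketch of that offset calculation for the 2-D case (names are illustrative, not the crate's implementation):

```rust
/// Produces the flat buffer offsets an NdIndex-style walk would visit for a
/// 2-D shape: offset = i * strides[0] + j * strides[1], trailing axis fastest.
/// A broadcast axis simply has stride 0, so the same buffer element repeats.
fn strided_offsets(shape: [usize; 2], strides: [usize; 2]) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(shape[0] * shape[1]);
    for i in 0..shape[0] {
        for j in 0..shape[1] {
            offsets.push(i * strides[0] + j * strides[1]);
        }
    }
    offsets
}

fn main() {
    // A [2, 3] view over a 3-element buffer, broadcast along axis 0 (stride 0):
    assert_eq!(strided_offsets([2, 3], [0, 1]), vec![0, 1, 2, 0, 1, 2]);
}
```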
70 changes: 45 additions & 25 deletions src/tensor_ops/max_to/cpu_kernel.rs
@@ -1,49 +1,69 @@
use crate::{
shapes::{Axes, Dtype, ReduceShapeTo, Shape},
tensor::cpu::{Cpu, LendingIterator, StridedArray},
shapes::{Axes, Dtype, HasAxes, ReduceShapeTo, Shape},
tensor::cpu::{Cpu, StridedArray},
tensor_ops::utilities::reduction_utils::index_for_reductions,
};

use num_traits::Float;

impl<F: Dtype + Float> super::MaxReduceKernel<F> for Cpu {
impl<E: Dtype + Float> super::MaxReduceKernel<E> for Cpu {
fn forward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
dst: Dst,
inp: &Self::Storage<Src, F>,
) -> Result<Self::Storage<Dst, F>, Self::Err>
inp: &Self::Storage<Src, E>,
) -> Result<Self::Storage<Dst, E>, Self::Err>
where
Src: ReduceShapeTo<Dst, Ax>,
{
let mut out: StridedArray<Dst, F> = StridedArray::try_new_with(dst, F::neg_infinity())?;
let mut out_iter = out.iter_mut_as(&inp.shape);
let mut inp_iter = inp.iter();
while let Some((out_i, inp_i)) = out_iter.next().zip(inp_iter.next()) {
*out_i = F::max(*out_i, *inp_i);
let mut out: StridedArray<Dst, E> = StridedArray::new(dst)?;
if Dst::NUM_DIMS == 0 {
debug_assert_eq!(out.data.len(), 1);
let mut tmp: E = E::neg_infinity();
for i in inp.buf_iter() {
tmp = i.max(tmp);
}
std::sync::Arc::get_mut(&mut out.data).unwrap()[0] = tmp;
} else {
let num_elems_reduced = <Src as HasAxes<Ax>>::size(&inp.shape);
let inp_buf = inp.data.as_ref();
let mut idx = index_for_reductions::<Src, Ax>(inp.shape, inp.strides);
for o in out.buf_iter_mut() {
let mut tmp: E = E::neg_infinity();
for _ in 0..num_elems_reduced {
tmp = tmp.max(inp_buf[idx.next().unwrap()]);
}
*o = tmp;
}
}
Ok(out)
}

fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
inp: &Self::Storage<Src, F>,
grad_inp: &mut Self::Storage<Src, F>,
out: &Self::Storage<Dst, F>,
grad_out: &Self::Storage<Dst, F>,
inp: &Self::Storage<Src, E>,
grad_inp: &mut Self::Storage<Src, E>,
out: &Self::Storage<Dst, E>,
grad_out: &Self::Storage<Dst, E>,
) -> Result<(), Self::Err>
where
Src: ReduceShapeTo<Dst, Ax>,
{
let mut inp_iter = inp.iter();
let mut grad_inp_iter = grad_inp.iter_mut();
let mut out_iter = out.iter_as(&inp.shape);
let mut grad_out_iter = grad_out.iter_as(&inp.shape);
for _ in 0..inp.shape.num_elements() {
let d = if out_iter.next().unwrap() == inp_iter.next().unwrap() {
F::one()
} else {
F::zero()
};
*grad_inp_iter.next().unwrap() += *grad_out_iter.next().unwrap() * d;
let num_elems_reduced = <Src as HasAxes<Ax>>::size(&grad_inp.shape);

let grad_inp_buf = std::sync::Arc::make_mut(&mut grad_inp.data);
let inp_buf = inp.data.as_ref();
let mut inp_idx = index_for_reductions::<Src, Ax>(grad_inp.shape, grad_inp.strides);

for (&o, &go) in out.buf_iter().zip(grad_out.buf_iter()) {
for _ in 0..num_elems_reduced {
let inp_i = inp_idx.next().unwrap();
let d = if o == inp_buf[inp_i] {
E::one()
} else {
E::zero()
};
grad_inp_buf[inp_i] += go * d;
}
}
Ok(())
}
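The rewritten kernel replaces the lock-step lending iterators with two nested loops: the outer loop walks the output buffer once, and the inner loop folds `num_elems_reduced` source elements into each output via offsets from `index_for_reductions`; the backward pass replays the same offsets and routes gradient only to positions equal to the reduced maximum. A plain-Rust sketch of the forward pattern, specialized to a contiguous 2-D input reduced over its last axis (not the crate's actual code):

```rust
/// Max over the trailing axis of a row-major `rows x cols` buffer:
/// one output element per row, `cols` inputs folded into it.
fn max_over_last_axis(inp: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![f32::NEG_INFINITY; rows];
    for (r, o) in out.iter_mut().enumerate() {
        let base = r * cols; // contiguous case: offsets are just base..base + cols
        for j in 0..cols {
            *o = o.max(inp[base + j]);
        }
    }
    out
}

fn main() {
    assert_eq!(max_over_last_axis(&[1.0, 5.0, 2.0, 4.0, 3.0, 0.0], 2, 3), vec![5.0, 4.0]);
}
```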
70 changes: 45 additions & 25 deletions src/tensor_ops/min_to/cpu_kernel.rs
@@ -1,49 +1,69 @@
use crate::{
shapes::{Axes, Dtype, ReduceShapeTo, Shape},
tensor::cpu::{Cpu, LendingIterator, StridedArray},
shapes::{Axes, Dtype, HasAxes, ReduceShapeTo, Shape},
tensor::cpu::{Cpu, StridedArray},
tensor_ops::utilities::reduction_utils::index_for_reductions,
};

use num_traits::Float;

impl<F: Float + Dtype> super::MinReduceKernel<F> for Cpu {
impl<E: Dtype + Float> super::MinReduceKernel<E> for Cpu {
fn forward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
dst: Dst,
inp: &Self::Storage<Src, F>,
) -> Result<Self::Storage<Dst, F>, Self::Err>
inp: &Self::Storage<Src, E>,
) -> Result<Self::Storage<Dst, E>, Self::Err>
where
Src: ReduceShapeTo<Dst, Ax>,
{
let mut out: StridedArray<Dst, F> = StridedArray::try_new_with(dst, F::infinity())?;
let mut out_iter = out.iter_mut_as(&inp.shape);
let mut inp_iter = inp.iter();
while let Some((out_i, inp_i)) = out_iter.next().zip(inp_iter.next()) {
*out_i = F::min(*out_i, *inp_i);
let mut out: StridedArray<Dst, E> = StridedArray::new(dst)?;
if Dst::NUM_DIMS == 0 {
debug_assert_eq!(out.data.len(), 1);
let mut tmp: E = E::infinity();
for i in inp.buf_iter() {
tmp = i.min(tmp);
}
std::sync::Arc::get_mut(&mut out.data).unwrap()[0] = tmp;
} else {
let num_elems_reduced = <Src as HasAxes<Ax>>::size(&inp.shape);
let inp_buf = inp.data.as_ref();
let mut idx = index_for_reductions::<Src, Ax>(inp.shape, inp.strides);
for o in out.buf_iter_mut() {
let mut tmp: E = E::infinity();
for _ in 0..num_elems_reduced {
tmp = tmp.min(inp_buf[idx.next().unwrap()]);
}
*o = tmp;
}
}
Ok(out)
}

fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
inp: &Self::Storage<Src, F>,
grad_inp: &mut Self::Storage<Src, F>,
out: &Self::Storage<Dst, F>,
grad_out: &Self::Storage<Dst, F>,
inp: &Self::Storage<Src, E>,
grad_inp: &mut Self::Storage<Src, E>,
out: &Self::Storage<Dst, E>,
grad_out: &Self::Storage<Dst, E>,
) -> Result<(), Self::Err>
where
Src: ReduceShapeTo<Dst, Ax>,
{
let mut inp_iter = inp.iter();
let mut grad_inp_itr = grad_inp.iter_mut();
let mut out_iter = out.iter_as(&inp.shape);
let mut grad_out_iter = grad_out.iter_as(&inp.shape);
for _ in 0..inp.shape.num_elements() {
let d = if out_iter.next().unwrap() == inp_iter.next().unwrap() {
F::one()
} else {
F::zero()
};
*grad_inp_itr.next().unwrap() += *grad_out_iter.next().unwrap() * d;
let num_elems_reduced = <Src as HasAxes<Ax>>::size(&grad_inp.shape);

let grad_inp_buf = std::sync::Arc::make_mut(&mut grad_inp.data);
let inp_buf = inp.data.as_ref();
let mut inp_idx = index_for_reductions::<Src, Ax>(grad_inp.shape, grad_inp.strides);

for (&o, &go) in out.buf_iter().zip(grad_out.buf_iter()) {
for _ in 0..num_elems_reduced {
let inp_i = inp_idx.next().unwrap();
let d = if o == inp_buf[inp_i] {
E::one()
} else {
E::zero()
};
grad_inp_buf[inp_i] += go * d;
}
}
Ok(())
}
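The `min_to` kernel is the exact mirror of `max_to` above: the fold starts from `E::infinity()` instead of `E::neg_infinity()` and uses `min` instead of `max`, and the backward pass likewise sends gradient only to positions that equal the reduced minimum.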
2 changes: 1 addition & 1 deletion src/tensor_ops/sub/cpu_kernel.rs
@@ -20,6 +20,6 @@ impl<F: num_traits::Float> BinaryDerivative<F> for super::BinarySubKernelOp {
}
#[inline(always)]
fn dfdy(&self, _: &F, _: &F) -> F {
F::one().neg()
-F::one()
}
}
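For reference, with `f(x, y) = x - y` the partial derivative with respect to `y` is the constant `-1`, so `dfdy` ignores both arguments; `-F::one()` states that more directly than `F::one().neg()`.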
48 changes: 38 additions & 10 deletions src/tensor_ops/sum_to/cpu_kernel.rs
@@ -1,6 +1,7 @@
use crate::{
shapes::{Axes, Dtype, ReduceShapeTo, Shape},
tensor::cpu::{Cpu, LendingIterator, StridedArray},
shapes::{Axes, Dtype, HasAxes, ReduceShapeTo, Shape},
tensor::cpu::{Cpu, StridedArray},
tensor_ops::utilities::reduction_utils::index_for_reductions,
};

impl<E: Dtype> super::SumKernel<E> for Cpu {
@@ -13,10 +14,25 @@ impl<E: Dtype> super::SumKernel<E> for Cpu {
Src: ReduceShapeTo<Dst, Ax>,
{
let mut out: StridedArray<Dst, E> = StridedArray::new(dst)?;
let mut out_iter = out.iter_mut_as(&inp.shape);
let mut inp_iter = inp.iter();
while let Some((o, i)) = out_iter.next().zip(inp_iter.next()) {
o.add_assign(*i);
if Dst::NUM_DIMS == 0 {
debug_assert_eq!(out.data.len(), 1);
let scale = E::from_usize(inp.shape.num_elements() / inp.data.len()).unwrap();
let mut tmp: E = Default::default();
for v in inp.buf_iter() {
tmp += *v;
}
std::sync::Arc::get_mut(&mut out.data).unwrap()[0] = tmp * scale;
} else {
let num_elems_reduced = <Src as HasAxes<Ax>>::size(&inp.shape);
let inp_buf = inp.data.as_ref();
let mut idx = index_for_reductions::<Src, Ax>(inp.shape, inp.strides);
for o in out.buf_iter_mut() {
let mut tmp: E = Default::default();
for _ in 0..num_elems_reduced {
tmp += inp_buf[idx.next().unwrap()];
}
*o = tmp;
}
}
Ok(out)
}
@@ -28,10 +44,22 @@ impl<E: Dtype> super::SumKernel<E> for Cpu {
where
Src: ReduceShapeTo<Dst, Ax>,
{
let mut out_iter = grad_out.iter_as(&grad_inp.shape);
let mut inp_iter = grad_inp.iter_mut();
while let Some((i, o)) = inp_iter.next().zip(out_iter.next()) {
i.add_assign(*o);
if Dst::NUM_DIMS == 0 {
debug_assert_eq!(grad_out.data.len(), 1);
let v = grad_out.data[0];
let scale = E::from_usize(grad_inp.shape.num_elements() / grad_inp.data.len()).unwrap();
for i in grad_inp.buf_iter_mut() {
*i += v * scale;
}
} else {
let num_elems_reduced = <Src as HasAxes<Ax>>::size(&grad_inp.shape);
let inp_buf = std::sync::Arc::make_mut(&mut grad_inp.data);
let mut idx = index_for_reductions::<Src, Ax>(grad_inp.shape, grad_inp.strides);
for &o in grad_out.buf_iter() {
for _ in 0..num_elems_reduced {
inp_buf[idx.next().unwrap()] += o;
}
}
}
Ok(())
}
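The `scale` factor in the 0-d branches handles inputs whose buffer is smaller than their logical shape, i.e. broadcasted tensors: the buffer sum is multiplied by how many times each physical element appears logically, and the backward pass spreads `grad_out * scale` back over the buffer. A worked example matching the new broadcast test below (plain arithmetic, not the kernel code):

```rust
fn main() {
    // Rank1<3> of ones broadcast to Rank3<4, 3, 2>: 3 physical elements,
    // 4 * 3 * 2 = 24 logical elements.
    let logical = 4 * 3 * 2; // inp.shape.num_elements()
    let physical = 3; // inp.data.len()
    let scale = (logical / physical) as f64; // 8
    let buffer_sum: f64 = 1.0 + 1.0 + 1.0; // sum over the physical buffer
    assert_eq!(buffer_sum * scale, 24.0); // forward: the value c.array() reports
    // backward: each physical element receives grad_out * scale = 1 * 8,
    // which is why the test expects [8.0; 3].
}
```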
11 changes: 11 additions & 0 deletions src/tensor_ops/sum_to/mod.rs
@@ -159,4 +159,15 @@ mod tests {
let c = b.sum::<Rank2<4, 3>, _>();
assert_eq!(c.array(), [[2.0, 4.0, 6.0]; 4]);
}

#[test]
fn test_sum_reduce_to_0d_from_broadcasted() {
let dev: TestDevice = Default::default();
let a: Tensor<Rank1<3>, TestDtype, _> = dev.ones();
let b = a.trace().broadcast::<Rank3<4, 3, 2>, _>();
let c = b.sum();
assert_eq!(c.array(), 24.0);
let g = c.backward();
assert_eq!(g.get(&a).array(), [8.0; 3]);
}
}