+
Skip to content

[Breaking] Redesign broadcasts/reductions to enable multi axis reductions #190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/arrays.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ impl<T: CountElements, const M: usize> CountElements for [T; M] {
}
}

/// A single axis known at compile time.
///
/// This is a field-less marker type: the axis number `I` exists only as a
/// const generic parameter. `I` is signed, so negative values can be written.
pub struct Axis<const I: isize>;

/// Two axes known at compile time, expressed as a tuple of two [`Axis`] markers.
pub type Axes2<const I: isize, const J: isize> = (Axis<I>, Axis<J>);

/// Three axes known at compile time, expressed as a tuple of three [`Axis`] markers.
pub type Axes3<const I: isize, const J: isize, const K: isize> = (Axis<I>, Axis<J>, Axis<K>);

/// Four axes known at compile time, expressed as a tuple of four [`Axis`] markers.
pub type Axes4<const I: isize, const J: isize, const K: isize, const L: isize> =
(Axis<I>, Axis<J>, Axis<K>, Axis<L>);

/// An NdArray that has an `I`th axis
pub trait HasAxis<const I: isize> {
/// The size of the axis. E.g. an nd array of shape (M, N, O):
Expand Down
363 changes: 0 additions & 363 deletions src/devices/broadcast.rs

This file was deleted.

124 changes: 124 additions & 0 deletions src/devices/broadcast_reduce/accumulator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
use super::indexing::{IndexMut, IndexRef};

/// Accumulates a sequence of values into a single value. Used
/// for reductions & broadcasts.
///
/// Implementors are stateless: both items are associated (no `self`), so the
/// running value is owned by the caller and passed in on every step.
pub trait Accumulator<T> {
/// The initial value to set the accumulator to.
const INIT: T;

/// Folds `item` into the running value `accum`, updating `accum` in place.
fn accum(accum: &mut T, item: &T);
}

/// Accumulator that keeps the larger of the running value and each new item.
pub(crate) struct MaxAccum;

impl Accumulator<f32> for MaxAccum {
    /// Starts at negative infinity, which is below every finite `f32`.
    const INIT: f32 = f32::NEG_INFINITY;

    /// Replaces `accum` with `max(accum, item)`.
    fn accum(accum: &mut f32, item: &f32) {
        *accum = f32::max(*accum, *item);
    }
}

/// Accumulator that keeps the smaller of the running value and each new item.
pub(crate) struct MinAccum;

impl Accumulator<f32> for MinAccum {
    /// Starts at positive infinity, which is above every finite `f32`.
    const INIT: f32 = f32::INFINITY;

    /// Replaces `accum` with `min(accum, item)`.
    fn accum(accum: &mut f32, item: &f32) {
        *accum = f32::min(*accum, *item);
    }
}

/// Accumulator that sums every item into the running value.
pub(crate) struct AddAccum;

impl Accumulator<f32> for AddAccum {
    /// Additive identity.
    const INIT: f32 = 0.0;

    /// Adds `item` onto `accum`.
    fn accum(accum: &mut f32, item: &f32) {
        *accum += *item;
    }
}

/// Accumulator that subtracts every item from the running value.
pub(crate) struct SubAccum;

impl Accumulator<f32> for SubAccum {
    /// Starts from zero.
    const INIT: f32 = 0.0;

    /// Subtracts `item` from `accum`.
    fn accum(accum: &mut f32, item: &f32) {
        *accum -= *item;
    }
}

/// Accumulator that multiplies the running value by every item.
pub(crate) struct MulAccum;

impl Accumulator<f32> for MulAccum {
    /// Multiplicative identity.
    const INIT: f32 = 1.0;

    /// Scales `accum` by `item`.
    fn accum(accum: &mut f32, item: &f32) {
        *accum *= *item;
    }
}

/// Accumulator that discards the running value and stores the latest item.
pub(crate) struct CopyAccum;

impl Accumulator<f32> for CopyAccum {
    /// Placeholder start value; it is overwritten by the first item.
    const INIT: f32 = 0.0;

    /// Overwrites `accum` with `item`.
    fn accum(accum: &mut f32, item: &f32) {
        let latest = *item;
        *accum = latest;
    }
}

/// Accumulator that turns the running value into an equality flag:
/// `1.0` when it equals the item, `0.0` otherwise.
pub(crate) struct EqAccum;

impl Accumulator<f32> for EqAccum {
    const INIT: f32 = 0.0;

    /// Compares the current `accum` with `item` and stores the result as `1.0` / `0.0`.
    fn accum(accum: &mut f32, item: &f32) {
        let is_equal = *accum == *item;
        *accum = if is_equal { 1.0 } else { 0.0 };
    }
}

/// Applies `A::accum` position by position over a 1d index space of length `M`,
/// folding `r[m]` into `l[m]` for every `m` in `0..M`.
pub(super) fn accum1d<A, L, R, const M: usize>(l: &mut L, r: &R)
where
    L: IndexMut<Index = usize>,
    R: IndexRef<Index = usize, Element = L::Element>,
    A: Accumulator<L::Element>,
{
    (0..M).for_each(|idx| A::accum(l.index_mut(idx), r.index_ref(idx)));
}

/// Applies `A::accum` position by position over an `M x N` index space,
/// folding `r[[m, n]]` into `l[[m, n]]`. The last index varies fastest.
pub(super) fn accum2d<A, L, R, const M: usize, const N: usize>(l: &mut L, r: &R)
where
    L: IndexMut<Index = [usize; 2]>,
    R: IndexRef<Index = [usize; 2], Element = L::Element>,
    A: Accumulator<L::Element>,
{
    for row in 0..M {
        for col in 0..N {
            let idx = [row, col];
            A::accum(l.index_mut(idx), r.index_ref(idx));
        }
    }
}

/// Applies `A::accum` position by position over an `M x N x O` index space,
/// folding `r[[m, n, o]]` into `l[[m, n, o]]`. The last index varies fastest.
pub(super) fn accum3d<A, L, R, const M: usize, const N: usize, const O: usize>(l: &mut L, r: &R)
where
    L: IndexMut<Index = [usize; 3]>,
    R: IndexRef<Index = [usize; 3], Element = L::Element>,
    A: Accumulator<L::Element>,
{
    for i0 in 0..M {
        for i1 in 0..N {
            for i2 in 0..O {
                let idx = [i0, i1, i2];
                A::accum(l.index_mut(idx), r.index_ref(idx));
            }
        }
    }
}

/// Applies `A::accum` position by position over an `M x N x O x P` index space,
/// folding `r[[m, n, o, p]]` into `l[[m, n, o, p]]`. The last index varies fastest.
pub(super) fn accum4d<A, L, R, const M: usize, const N: usize, const O: usize, const P: usize>(
    l: &mut L,
    r: &R,
) where
    L: IndexMut<Index = [usize; 4]>,
    R: IndexRef<Index = [usize; 4], Element = L::Element>,
    A: Accumulator<L::Element>,
{
    for i0 in 0..M {
        for i1 in 0..N {
            for i2 in 0..O {
                for i3 in 0..P {
                    let idx = [i0, i1, i2, i3];
                    A::accum(l.index_mut(idx), r.index_ref(idx));
                }
            }
        }
    }
}
167 changes: 167 additions & 0 deletions src/devices/broadcast_reduce/indexing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
use crate::arrays::{Axes2, Axes3, Axes4, Axis};
use std::marker::PhantomData;

/// Broadcasts `&'a T` along `Axes` to enable indexing as a higher dimensional array.
///
/// Field `0` is the borrowed source value. `Axes` is carried only through
/// `PhantomData`; no `Axes` value is stored.
pub(super) struct BroadcastRef<'a, T, Axes>(pub &'a T, PhantomData<*const Axes>);

impl<'a, T, Axes> BroadcastRef<'a, T, Axes> {
/// Wraps `t`. The broadcast axes are selected through the `Axes` type parameter.
pub fn new(t: &'a T) -> Self {
Self(t, PhantomData)
}
}

/// Broadcasts `&'a mut T` along `Axes` to enable indexing as a higher dimensional array.
///
/// Field `0` is the mutably borrowed source value. `Axes` is carried only through
/// `PhantomData`; no `Axes` value is stored.
pub(super) struct BroadcastMut<'a, T, Axes>(pub &'a mut T, PhantomData<*const Axes>);

impl<'a, T, Axes> BroadcastMut<'a, T, Axes> {
/// Wraps `t`. The broadcast axes are selected through the `Axes` type parameter.
pub fn new(t: &'a mut T) -> Self {
Self(t, PhantomData)
}
}

/// Index to get a `&Self::Element`.
pub(super) trait IndexRef {
/// The index type accepted by [`IndexRef::index_ref`]; the impls in this
/// file use `usize` or `[usize; N]`.
type Index;
/// The type of the value a lookup yields a reference to.
type Element;
/// Returns a shared reference to the element at position `i`.
fn index_ref(&self, i: Self::Index) -> &Self::Element;
}

/// Index to get a `&mut Self::Element`.
pub(super) trait IndexMut {
/// The index type accepted by [`IndexMut::index_mut`]; the impls in this
/// file use `usize` or `[usize; N]`.
type Index;
/// The type of the value a lookup yields a reference to.
type Element;
/// Returns an exclusive reference to the element at position `i`.
fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element;
}

/// 1d arrays are indexed by a plain `usize`.
impl<const M: usize> IndexRef for [f32; M] {
    type Index = usize;
    type Element = f32;

    /// Returns `&self[i]`; panics if `i` is out of bounds.
    fn index_ref(&self, i: Self::Index) -> &Self::Element {
        let elements: &[f32] = self;
        &elements[i]
    }
}

/// 1d arrays are indexed by a plain `usize`.
impl<const M: usize> IndexMut for [f32; M] {
    type Index = usize;
    type Element = f32;

    /// Returns `&mut self[i]`; panics if `i` is out of bounds.
    fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element {
        let elements: &mut [f32] = self;
        &mut elements[i]
    }
}

/// 2d arrays are indexed by `[row, col]`.
impl<const M: usize, const N: usize> IndexRef for [[f32; N]; M] {
    type Index = [usize; 2];
    type Element = f32;

    /// Returns `&self[row][col]`; panics if either component is out of bounds.
    fn index_ref(&self, i: Self::Index) -> &Self::Element {
        let [row, col] = i;
        &self[row][col]
    }
}

/// 2d arrays are indexed by `[row, col]`.
impl<const M: usize, const N: usize> IndexMut for [[f32; N]; M] {
    type Index = [usize; 2];
    type Element = f32;

    /// Returns `&mut self[row][col]`; panics if either component is out of bounds.
    fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element {
        let [row, col] = i;
        &mut self[row][col]
    }
}

/// 3d arrays are indexed by a `[usize; 3]`, outermost dimension first.
impl<const M: usize, const N: usize, const O: usize> IndexRef for [[[f32; O]; N]; M] {
    type Index = [usize; 3];
    type Element = f32;

    /// Returns `&self[i0][i1][i2]`; panics if any component is out of bounds.
    fn index_ref(&self, i: Self::Index) -> &Self::Element {
        let [i0, i1, i2] = i;
        &self[i0][i1][i2]
    }
}

/// 3d arrays are indexed by a `[usize; 3]`, outermost dimension first.
impl<const M: usize, const N: usize, const O: usize> IndexMut for [[[f32; O]; N]; M] {
    type Index = [usize; 3];
    type Element = f32;

    /// Returns `&mut self[i0][i1][i2]`; panics if any component is out of bounds.
    fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element {
        let [i0, i1, i2] = i;
        &mut self[i0][i1][i2]
    }
}

/// 4d arrays are indexed by a `[usize; 4]`, outermost dimension first.
impl<const M: usize, const N: usize, const O: usize, const P: usize> IndexRef
    for [[[[f32; P]; O]; N]; M]
{
    type Index = [usize; 4];
    type Element = f32;

    /// Returns `&self[i0][i1][i2][i3]`; panics if any component is out of bounds.
    fn index_ref(&self, i: Self::Index) -> &Self::Element {
        let [i0, i1, i2, i3] = i;
        &self[i0][i1][i2][i3]
    }
}

/// 4d arrays are indexed by a `[usize; 4]`, outermost dimension first.
impl<const M: usize, const N: usize, const O: usize, const P: usize> IndexMut
    for [[[[f32; P]; O]; N]; M]
{
    type Index = [usize; 4];
    type Element = f32;

    /// Returns `&mut self[i0][i1][i2][i3]`; panics if any component is out of bounds.
    fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element {
        let [i0, i1, i2, i3] = i;
        &mut self[i0][i1][i2][i3]
    }
}

// Generates an `IndexRef` impl for `BroadcastRef<$ArrTy, $AxisTy>` and an
// `IndexMut` impl for `BroadcastMut<$ArrTy, $AxisTy>`, both with `Element = f32`.
//
// Arguments, in order:
// - `$ArrTy`: the wrapped (smaller) type, e.g. `f32` or `[f32; M]`.
// - `[$Idx, ...]`: which components of the incoming index are applied, in the
//   order listed, to the wrapped value. Components not listed are ignored.
// - `$AxisTy`: the axes marker type the generated impls are keyed on.
// - `$IdxTy`: the `Index` type accepted by the generated impls.
// - `{$CVars, ...}`: names of the `const usize` generics that `$ArrTy` mentions.
macro_rules! impl_bcast {
($ArrTy:ty, [$($Idx:expr),*], $AxisTy:ty, $IdxTy:ty, {$($CVars:tt),*}) => {
impl<'a, $(const $CVars: usize, )*> IndexRef for BroadcastRef<'a, $ArrTy, $AxisTy> {
type Index = $IdxTy;
type Element = f32;
// `i` is unused when the index list is empty (a wrapped scalar).
#[allow(unused_variables)]
fn index_ref(&self, i: Self::Index) -> &Self::Element {
// Expands to `&self.0[i[a]][i[b]]...`, one subscript per listed `$Idx`.
&self.0 $([i[$Idx]])*
}
}
impl<'a, $(const $CVars: usize, )*> IndexMut for BroadcastMut<'a, $ArrTy, $AxisTy> {
type Index = $IdxTy;
type Element = f32;
// `i` is unused when the index list is empty (a wrapped scalar).
#[allow(unused_variables)]
fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element {
// Expands to `&mut self.0[i[a]][i[b]]...`, one subscript per listed `$Idx`.
&mut self.0 $([i[$Idx]])*
}
}
};
}

// 0d -> nd
// The index list is empty, so the wrapped `f32` is returned for every index.
impl_bcast!(f32, [], Axis<-1>, usize, {});
impl_bcast!(f32, [], Axis<0>, usize, {});
impl_bcast!(f32, [], Axes2<0, 1>, [usize; 2], {});
impl_bcast!(f32, [], Axes3<0, 1, 2>, [usize; 3], {});
impl_bcast!(f32, [], Axes4<0, 1, 2, 3>, [usize; 4], {});

// 1d -> 2d
// The listed component of the 2-element index selects into `[f32; M]`; the other
// component (the broadcast axis) is ignored. `Axis<-1>` is given the same mapping
// as `Axis<1>`.
impl_bcast!([f32; M], [0], Axis<-1>, [usize; 2], { M });
impl_bcast!([f32; M], [0], Axis<1>, [usize; 2], { M });
impl_bcast!([f32; M], [1], Axis<0>, [usize; 2], { M });

// 1d -> 3d
// Two axes are broadcast; the one index component not named in `Axes2` selects
// into `[f32; M]`.
impl_bcast!([f32; M], [2], Axes2<0, 1>, [usize; 3], { M });
impl_bcast!([f32; M], [1], Axes2<0, 2>, [usize; 3], { M });
impl_bcast!([f32; M], [0], Axes2<1, 2>, [usize; 3], { M });

// 1d -> 4d
// Three axes are broadcast; the one index component not named in `Axes3` selects
// into `[f32; M]`.
impl_bcast!([f32; M], [3], Axes3<0, 1, 2>, [usize; 4], { M });
impl_bcast!([f32; M], [2], Axes3<0, 1, 3>, [usize; 4], { M });
impl_bcast!([f32; M], [1], Axes3<0, 2, 3>, [usize; 4], { M });
impl_bcast!([f32; M], [0], Axes3<1, 2, 3>, [usize; 4], { M });

// 2d -> 3d
// One axis is broadcast; the two remaining index components, in order, select
// into `[[f32; N]; M]`. `Axis<-1>` is given the same mapping as `Axis<2>`.
impl_bcast!([[f32; N]; M], [0, 1], Axis<-1>, [usize; 3], {M, N});
impl_bcast!([[f32; N]; M], [0, 1], Axis<2>, [usize; 3], {M, N});
impl_bcast!([[f32; N]; M], [0, 2], Axis<1>, [usize; 3], {M, N});
impl_bcast!([[f32; N]; M], [1, 2], Axis<0>, [usize; 3], {M, N});

// 2d -> 4d
// Two axes are broadcast; the two index components not named in `Axes2`, in
// order, select into `[[f32; N]; M]`.
impl_bcast!([[f32; N]; M], [2, 3], Axes2<0, 1>, [usize; 4], {M, N});
impl_bcast!([[f32; N]; M], [1, 3], Axes2<0, 2>, [usize; 4], {M, N});
impl_bcast!([[f32; N]; M], [1, 2], Axes2<0, 3>, [usize; 4], {M, N});
impl_bcast!([[f32; N]; M], [0, 3], Axes2<1, 2>, [usize; 4], {M, N});
impl_bcast!([[f32; N]; M], [0, 2], Axes2<1, 3>, [usize; 4], {M, N});
impl_bcast!([[f32; N]; M], [0, 1], Axes2<2, 3>, [usize; 4], {M, N});

// 3d -> 4d
// One axis is broadcast; the three remaining index components, in order, select
// into `[[[f32; O]; N]; M]`. `Axis<-1>` is given the same mapping as `Axis<3>`.
impl_bcast!([[[f32; O]; N]; M], [0, 1, 2], Axis<-1>, [usize; 4], {M, N, O});
impl_bcast!([[[f32; O]; N]; M], [0, 1, 2], Axis<3>, [usize; 4], {M, N, O});
impl_bcast!([[[f32; O]; N]; M], [0, 1, 3], Axis<2>, [usize; 4], {M, N, O});
impl_bcast!([[[f32; O]; N]; M], [0, 2, 3], Axis<1>, [usize; 4], {M, N, O});
impl_bcast!([[[f32; O]; N]; M], [1, 2, 3], Axis<0>, [usize; 4], {M, N, O});
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载