-
-
Notifications
You must be signed in to change notification settings - Fork 105
[Breaking] Redesign broadcasts/reductions to enable multi axis reductions #190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
3e9a098
Rework broadcast/reduce
coreylowman 639eec6
Renaming accums
coreylowman 8d7f05d
Fixing unit tests
coreylowman ef7cf4a
Reworking to use Axis
coreylowman eea2207
Adding -1 to 1d reductions
coreylowman 655337e
moving axes to arrays
coreylowman 0d12bc7
Reworking tensor_ops/broadcast
coreylowman 90816e0
Building
coreylowman 9a43472
Tests passing
coreylowman dc918b2
Removing old broadcast/reduce
coreylowman 0c777e5
Removing commented code
coreylowman 24e310e
Fixing docstrings
coreylowman 374ada9
Update export of DeviceReduce
coreylowman dc161f3
Renaming broadcast to impl_broadcast_reduce
coreylowman df8ba4e
Updating docstrings of broadcast/reduce
coreylowman a7dd473
Add docstrings
coreylowman a8f3b86
Adding multi axis reduce tests
coreylowman a710d27
Adding docstring to accumulator
coreylowman 9202135
Renaming IndexRef and IndexMut
coreylowman 2698814
Simplifying index macros
coreylowman 1ac7f32
Renaming macros
coreylowman 9b07562
Try to fix flaky broadcast test
coreylowman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
use super::indexing::{IndexMut, IndexRef}; | ||
|
||
/// Folds a sequence of values into one running value.
///
/// Used for reductions & broadcasts.
pub trait Accumulator<T> {
    /// The value the running accumulator is set to before any item is folded in.
    const INIT: T;

    /// Folds `item` into the running value `accum`, updating `accum` in place.
    fn accum(accum: &mut T, item: &T);
}
|
||
pub(crate) struct MaxAccum; | ||
impl Accumulator<f32> for MaxAccum { | ||
const INIT: f32 = f32::NEG_INFINITY; | ||
fn accum(accum: &mut f32, item: &f32) { | ||
*accum = accum.max(*item); | ||
} | ||
} | ||
|
||
pub(crate) struct MinAccum; | ||
impl Accumulator<f32> for MinAccum { | ||
const INIT: f32 = f32::INFINITY; | ||
fn accum(accum: &mut f32, item: &f32) { | ||
*accum = accum.min(*item); | ||
} | ||
} | ||
|
||
pub(crate) struct AddAccum; | ||
impl Accumulator<f32> for AddAccum { | ||
const INIT: f32 = 0.0; | ||
fn accum(accum: &mut f32, item: &f32) { | ||
*accum += item; | ||
} | ||
} | ||
|
||
pub(crate) struct SubAccum; | ||
impl Accumulator<f32> for SubAccum { | ||
const INIT: f32 = 0.0; | ||
fn accum(accum: &mut f32, item: &f32) { | ||
*accum -= item; | ||
} | ||
} | ||
|
||
pub(crate) struct MulAccum; | ||
impl Accumulator<f32> for MulAccum { | ||
const INIT: f32 = 1.0; | ||
fn accum(accum: &mut f32, item: &f32) { | ||
*accum *= item; | ||
} | ||
} | ||
|
||
pub(crate) struct CopyAccum; | ||
impl Accumulator<f32> for CopyAccum { | ||
const INIT: f32 = 0.0; | ||
fn accum(accum: &mut f32, item: &f32) { | ||
*accum = *item; | ||
} | ||
} | ||
|
||
pub(crate) struct EqAccum; | ||
impl Accumulator<f32> for EqAccum { | ||
const INIT: f32 = 0.0; | ||
fn accum(accum: &mut f32, item: &f32) { | ||
*accum = if accum == item { 1.0 } else { 0.0 }; | ||
} | ||
} | ||
|
||
pub(super) fn accum1d<A, L, R, const M: usize>(l: &mut L, r: &R) | ||
where | ||
L: IndexMut<Index = usize>, | ||
R: IndexRef<Index = usize, Element = L::Element>, | ||
A: Accumulator<L::Element>, | ||
{ | ||
for m in 0..M { | ||
A::accum(l.index_mut(m), r.index_ref(m)); | ||
} | ||
} | ||
|
||
pub(super) fn accum2d<A, L, R, const M: usize, const N: usize>(l: &mut L, r: &R) | ||
where | ||
L: IndexMut<Index = [usize; 2]>, | ||
R: IndexRef<Index = [usize; 2], Element = L::Element>, | ||
A: Accumulator<L::Element>, | ||
{ | ||
for m in 0..M { | ||
for n in 0..N { | ||
A::accum(l.index_mut([m, n]), r.index_ref([m, n])); | ||
} | ||
} | ||
} | ||
|
||
pub(super) fn accum3d<A, L, R, const M: usize, const N: usize, const O: usize>(l: &mut L, r: &R) | ||
where | ||
L: IndexMut<Index = [usize; 3]>, | ||
R: IndexRef<Index = [usize; 3], Element = L::Element>, | ||
A: Accumulator<L::Element>, | ||
{ | ||
for m in 0..M { | ||
for n in 0..N { | ||
for o in 0..O { | ||
A::accum(l.index_mut([m, n, o]), r.index_ref([m, n, o])); | ||
} | ||
} | ||
} | ||
} | ||
|
||
pub(super) fn accum4d<A, L, R, const M: usize, const N: usize, const O: usize, const P: usize>( | ||
l: &mut L, | ||
r: &R, | ||
) where | ||
L: IndexMut<Index = [usize; 4]>, | ||
R: IndexRef<Index = [usize; 4], Element = L::Element>, | ||
A: Accumulator<L::Element>, | ||
{ | ||
for m in 0..M { | ||
for n in 0..N { | ||
for o in 0..O { | ||
for p in 0..P { | ||
A::accum(l.index_mut([m, n, o, p]), r.index_ref([m, n, o, p])); | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
use crate::arrays::{Axes2, Axes3, Axes4, Axis}; | ||
use std::marker::PhantomData; | ||
|
||
/// Broadcasts `&'a T` along `Axes` to enable indexing as a higher dimensional array. | ||
pub(super) struct BroadcastRef<'a, T, Axes>(pub &'a T, PhantomData<*const Axes>); | ||
coreylowman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
impl<'a, T, Axes> BroadcastRef<'a, T, Axes> { | ||
pub fn new(t: &'a T) -> Self { | ||
Self(t, PhantomData) | ||
} | ||
} | ||
|
||
/// Broadcasts `&'a mut T` along `Axes` to enable indexing as a higher dimensional array. | ||
pub(super) struct BroadcastMut<'a, T, Axes>(pub &'a mut T, PhantomData<*const Axes>); | ||
|
||
impl<'a, T, Axes> BroadcastMut<'a, T, Axes> { | ||
pub fn new(t: &'a mut T) -> Self { | ||
Self(t, PhantomData) | ||
} | ||
} | ||
|
||
/// Index to get a `&Self::Element`. | ||
pub(super) trait IndexRef { | ||
type Index; | ||
type Element; | ||
fn index_ref(&self, i: Self::Index) -> &Self::Element; | ||
} | ||
|
||
/// Index to get a `&mut Self::Element`. | ||
pub(super) trait IndexMut { | ||
type Index; | ||
type Element; | ||
fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element; | ||
} | ||
|
||
impl<const M: usize> IndexRef for [f32; M] { | ||
type Index = usize; | ||
type Element = f32; | ||
fn index_ref(&self, i: Self::Index) -> &Self::Element { | ||
&self[i] | ||
} | ||
} | ||
|
||
impl<const M: usize> IndexMut for [f32; M] { | ||
type Index = usize; | ||
type Element = f32; | ||
fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element { | ||
&mut self[i] | ||
} | ||
} | ||
|
||
impl<const M: usize, const N: usize> IndexRef for [[f32; N]; M] { | ||
type Index = [usize; 2]; | ||
type Element = f32; | ||
fn index_ref(&self, i: Self::Index) -> &Self::Element { | ||
&self[i[0]][i[1]] | ||
} | ||
} | ||
|
||
impl<const M: usize, const N: usize> IndexMut for [[f32; N]; M] { | ||
type Index = [usize; 2]; | ||
type Element = f32; | ||
fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element { | ||
&mut self[i[0]][i[1]] | ||
} | ||
} | ||
|
||
impl<const M: usize, const N: usize, const O: usize> IndexRef for [[[f32; O]; N]; M] { | ||
type Index = [usize; 3]; | ||
type Element = f32; | ||
fn index_ref(&self, i: Self::Index) -> &Self::Element { | ||
&self[i[0]][i[1]][i[2]] | ||
} | ||
} | ||
|
||
impl<const M: usize, const N: usize, const O: usize> IndexMut for [[[f32; O]; N]; M] { | ||
type Index = [usize; 3]; | ||
type Element = f32; | ||
fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element { | ||
&mut self[i[0]][i[1]][i[2]] | ||
} | ||
} | ||
|
||
impl<const M: usize, const N: usize, const O: usize, const P: usize> IndexRef | ||
for [[[[f32; P]; O]; N]; M] | ||
{ | ||
type Index = [usize; 4]; | ||
type Element = f32; | ||
fn index_ref(&self, i: Self::Index) -> &Self::Element { | ||
&self[i[0]][i[1]][i[2]][i[3]] | ||
} | ||
} | ||
|
||
impl<const M: usize, const N: usize, const O: usize, const P: usize> IndexMut | ||
for [[[[f32; P]; O]; N]; M] | ||
{ | ||
type Index = [usize; 4]; | ||
type Element = f32; | ||
fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element { | ||
&mut self[i[0]][i[1]][i[2]][i[3]] | ||
} | ||
} | ||
|
||
/// Implements `IndexRef` for `BroadcastRef` and `IndexMut` for `BroadcastMut`
/// for one combination of wrapped type and broadcast axes.
///
/// Arguments, in order:
/// - `$Src`: the wrapped type that is being broadcast.
/// - `[$($Pos),*]`: the positions of the incoming index `i` that are applied,
///   in order, to index into the wrapped value. Positions not listed are ignored.
/// - `$Axes`: the axes tag type used as the wrapper's `Axes` parameter.
/// - `$Index`: the index type the generated impls accept.
/// - `{$($Dim),*}`: names of the `const usize` generics that `$Src` mentions.
macro_rules! impl_bcast {
    ($Src:ty, [$($Pos:expr),*], $Axes:ty, $Index:ty, {$($Dim:tt),*}) => {
        impl<'a, $(const $Dim: usize, )*> IndexRef for BroadcastRef<'a, $Src, $Axes> {
            type Index = $Index;
            type Element = f32;
            // `i` is never read when the position list is empty.
            #[allow(unused_variables)]
            fn index_ref(&self, i: Self::Index) -> &Self::Element {
                &self.0 $([i[$Pos]])*
            }
        }
        impl<'a, $(const $Dim: usize, )*> IndexMut for BroadcastMut<'a, $Src, $Axes> {
            type Index = $Index;
            type Element = f32;
            // `i` is never read when the position list is empty.
            #[allow(unused_variables)]
            fn index_mut(&mut self, i: Self::Index) -> &mut Self::Element {
                &mut self.0 $([i[$Pos]])*
            }
        }
    };
}
|
||
// 0d -> nd | ||
impl_bcast!(f32, [], Axis<-1>, usize, {}); | ||
impl_bcast!(f32, [], Axis<0>, usize, {}); | ||
impl_bcast!(f32, [], Axes2<0, 1>, [usize; 2], {}); | ||
impl_bcast!(f32, [], Axes3<0, 1, 2>, [usize; 3], {}); | ||
impl_bcast!(f32, [], Axes4<0, 1, 2, 3>, [usize; 4], {}); | ||
|
||
// 1d -> 2d | ||
impl_bcast!([f32; M], [0], Axis<-1>, [usize; 2], { M }); | ||
impl_bcast!([f32; M], [0], Axis<1>, [usize; 2], { M }); | ||
impl_bcast!([f32; M], [1], Axis<0>, [usize; 2], { M }); | ||
|
||
// 1d -> 3d | ||
impl_bcast!([f32; M], [2], Axes2<0, 1>, [usize; 3], { M }); | ||
impl_bcast!([f32; M], [1], Axes2<0, 2>, [usize; 3], { M }); | ||
impl_bcast!([f32; M], [0], Axes2<1, 2>, [usize; 3], { M }); | ||
|
||
// 1d -> 4d | ||
impl_bcast!([f32; M], [3], Axes3<0, 1, 2>, [usize; 4], { M }); | ||
impl_bcast!([f32; M], [2], Axes3<0, 1, 3>, [usize; 4], { M }); | ||
impl_bcast!([f32; M], [1], Axes3<0, 2, 3>, [usize; 4], { M }); | ||
impl_bcast!([f32; M], [0], Axes3<1, 2, 3>, [usize; 4], { M }); | ||
|
||
// 2d -> 3d | ||
impl_bcast!([[f32; N]; M], [0, 1], Axis<-1>, [usize; 3], {M, N}); | ||
impl_bcast!([[f32; N]; M], [0, 1], Axis<2>, [usize; 3], {M, N}); | ||
impl_bcast!([[f32; N]; M], [0, 2], Axis<1>, [usize; 3], {M, N}); | ||
impl_bcast!([[f32; N]; M], [1, 2], Axis<0>, [usize; 3], {M, N}); | ||
|
||
// 2d -> 4d | ||
impl_bcast!([[f32; N]; M], [2, 3], Axes2<0, 1>, [usize; 4], {M, N}); | ||
impl_bcast!([[f32; N]; M], [1, 3], Axes2<0, 2>, [usize; 4], {M, N}); | ||
impl_bcast!([[f32; N]; M], [1, 2], Axes2<0, 3>, [usize; 4], {M, N}); | ||
impl_bcast!([[f32; N]; M], [0, 3], Axes2<1, 2>, [usize; 4], {M, N}); | ||
impl_bcast!([[f32; N]; M], [0, 2], Axes2<1, 3>, [usize; 4], {M, N}); | ||
impl_bcast!([[f32; N]; M], [0, 1], Axes2<2, 3>, [usize; 4], {M, N}); | ||
|
||
// 3d -> 4d | ||
impl_bcast!([[[f32; O]; N]; M], [0, 1, 2], Axis<-1>, [usize; 4], {M, N, O}); | ||
impl_bcast!([[[f32; O]; N]; M], [0, 1, 2], Axis<3>, [usize; 4], {M, N, O}); | ||
impl_bcast!([[[f32; O]; N]; M], [0, 1, 3], Axis<2>, [usize; 4], {M, N, O}); | ||
impl_bcast!([[[f32; O]; N]; M], [0, 2, 3], Axis<1>, [usize; 4], {M, N, O}); | ||
impl_bcast!([[[f32; O]; N]; M], [1, 2, 3], Axis<0>, [usize; 4], {M, N, O}); |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.