+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 122 additions & 79 deletions src/devices/select.rs
Original file line number Diff line number Diff line change
@@ -1,107 +1,140 @@
//! Implementations of selecting either 1 or Z elements from an axis of an nd array.
//!
//! # Implementation Details
//! There are three cases to handle:
//! There are four cases to handle:
//!
//! ## Selecting 1 element from the 0th axis
//! ## Selecting 1 element from the 0th axis [select_modes::Index]
//!
//! Just index into input using the single index and assign to output.
//!
//! ## Selecting Z elements from the 0th axis
//! ## Selecting Z elements from the 0th axis [select_modes::Index]
//!
//! Just index into input for each index and assign to `output[z]`.
//!
//! ## Selecting either 1 or Z elements from a non-zero axis
//! ## Selecting either 1 or Z elements from a non-zero axis [select_modes::Recurse]
//!
//! Then all three arrays (input, indices, output) will have the same size along the 0th axis.
//! Do a for loop over the 0th axis and recurse!
//!
//! ## Broadcasted select [select_modes::Broadcast]
//!
//! In this case only the indices & output are indexed. The input is broadcasted by
//! not indexing into it.

use super::{Cpu, ForEachElement};
use crate::arrays::CountElements;

/// Select values from `T` using `Indices` and producing `R` along a single `AXIS`.
pub trait SelectAlongAxis<T: CountElements, Indices, R: CountElements, const AXIS: isize> {
/// Marker types used to disambiguate `DeviceSelect` implementations.
///
/// Callers must specify what kind of selection is occurring by naming one of
/// these modes (possibly nested, e.g. `Recurse<Recurse<Index>>`). The same
/// input/index array types can match more than one implementation, e.g.
/// `[usize; M]` indices into a `[[_; N]; M]` input fit both `Index` and
/// `Recurse<Index>`.
///
/// None of these types is ever constructed; they exist only at the type level.
pub(crate) mod select_modes {
    use std::marker::PhantomData;

    /// Select along the current (outermost) axis by indexing into it directly.
    pub struct Index;

    /// Recurse into the current axis: input, indices and output share this
    /// axis, and mode `M` is applied to each matching sub-array.
    pub struct Recurse<M>(PhantomData<*const M>);

    /// Broadcast the current axis of the input: the whole input is reused for
    /// every entry along the current axis of the indices, and mode `M` is
    /// applied to each of those entries.
    pub struct Broadcast<M>(PhantomData<*const M>);
}

use select_modes::{Broadcast, Index, Recurse};

/// Mode that selects along axis 0 by indexing the outermost array directly.
pub(crate) type SelectAx0 = select_modes::Index;
/// Mode that selects along axis 1: recurse through axis 0, then index.
pub(crate) type SelectAx1 = select_modes::Recurse<select_modes::Index>;
/// Mode that selects along axis 2: recurse through axes 0 and 1, then index.
pub(crate) type SelectAx2 = select_modes::Recurse<select_modes::Recurse<select_modes::Index>>;
/// Mode that selects along axis 3: recurse through axes 0, 1 and 2, then index.
pub(crate) type SelectAx3 =
    select_modes::Recurse<select_modes::Recurse<select_modes::Recurse<select_modes::Index>>>;
/// Broadcast mode: the whole input is reused for each entry along the indices'
/// outermost axis, and each entry indexes the input's axis 0.
pub(crate) type BSelectAx1 = select_modes::Broadcast<select_modes::Index>;

/// Select values from `T` using indices `I`. `Mode` is used to disambiguate the impl.
pub trait DeviceSelect<T, I, Mode> {
type Result;

/// Equivalent to pseudocode `out = inp[indices]`
fn select_axis(inp: &T, indices: &Indices, out: &mut R);
fn select_axis(inp: &T, indices: &I, out: &mut Self::Result);

/// `inp[indices] += out`
fn select_add(inp: &mut T, indices: &Indices, out: &R);
fn select_add(inp: &mut T, indices: &I, out: &Self::Result);
}

macro_rules! select_01 {
($Axis:expr, $SrcTy:tt, $DstTy:tt, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, usize, $DstTy, $Axis> for Cpu {
fn select_axis(inp: &$SrcTy, indices: &usize, out: &mut $DstTy) {
// Select 1 element from 0th axis.
impl<T, const M: usize> DeviceSelect<[T; M], usize, Index> for Cpu
where
Self: ForEachElement<T>,
T: Copy + CountElements,
T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>,
{
type Result = T;

fn select_axis(inp: &[T; M], indices: &usize, out: &mut Self::Result) {
*out = inp[*indices];
}
fn select_add(inp: &mut $SrcTy, indices: &usize, out: &$DstTy) {

fn select_add(inp: &mut [T; M], indices: &usize, out: &Self::Result) {
Self::foreach_mr(&mut inp[*indices], out, &mut |a, b| *a += b);
}
}
};
}

macro_rules! select_0z {
($Axis:expr, $SrcTy:tt, $DstTy:tt, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, [usize; Z], $DstTy, $Axis> for Cpu {
fn select_axis(inp: &$SrcTy, indices: &[usize; Z], out: &mut $DstTy) {
// Select Z elements from 0th axis.
impl<T, const M: usize, const Z: usize> DeviceSelect<[T; M], [usize; Z], Index> for Cpu
where
Self: ForEachElement<T>,
T: Copy + CountElements,
T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>,
{
type Result = [T; Z];

fn select_axis(inp: &[T; M], indices: &[usize; Z], out: &mut Self::Result) {
for z in 0..Z {
out[z] = inp[indices[z]];
}
}
fn select_add(inp: &mut $SrcTy, indices: &[usize; Z], out: &$DstTy) {
fn select_add(inp: &mut [T; M], indices: &[usize; Z], out: &Self::Result) {
for z in 0..Z {
Self::foreach_mr(&mut inp[indices[z]], &out[z], &mut |a, b| *a += b);
}
}
}
};
}

macro_rules! select_nz {
($Axis:expr, $SrcTy:tt, $IndTy:tt, $DstTy:tt, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, $IndTy, $DstTy, $Axis> for Cpu {
fn select_axis(inp: &$SrcTy, indices: &$IndTy, out: &mut $DstTy) {
// Select elements from non-zero axis
impl<T, I, const M: usize, SubMode> DeviceSelect<[T; M], [I; M], Recurse<SubMode>> for Cpu
where
Self: DeviceSelect<T, I, SubMode>,
{
type Result = [<Self as DeviceSelect<T, I, SubMode>>::Result; M];

fn select_axis(inp: &[T; M], indices: &[I; M], out: &mut Self::Result) {
for m in 0..M {
Self::select_axis(&inp[m], &indices[m], &mut out[m]);
}
}
fn select_add(inp: &mut $SrcTy, indices: &$IndTy, out: &$DstTy) {

fn select_add(inp: &mut [T; M], indices: &[I; M], out: &Self::Result) {
for m in 0..M {
Self::select_add(&mut inp[m], &indices[m], &out[m]);
}
}
}
};
}

// 1d
select_01!(-1, [f32; M], f32, { M });
select_0z!(-1, [f32; M], [f32; Z], {M, Z});

// 2d
select_01!(0, [[f32; N]; M], [f32; N], {M, N});
select_0z!(0, [[f32; N]; M], [[f32; N]; Z], {M, N, Z});
select_nz!(-1, [[f32; N]; M], [usize; M], [f32; M], {M, N});
select_nz!(-1, [[f32; N]; M], [[usize; Z]; M], [[f32; Z]; M], {M, N, Z});

// 3d
select_01!(0, [[[f32; O]; N]; M], [[f32; O]; N], {M, N, O});
select_0z!(0, [[[f32; O]; N]; M], [[[f32; O]; N]; Z], {M, N, O, Z});
select_nz!(1, [[[f32; O]; N]; M], [usize; M], [[f32; O]; M], {M, N, O});
select_nz!(1, [[[f32; O]; N]; M], [[usize; Z]; M], [[[f32; O]; Z]; M], {M, N, O, Z});
select_nz!(-1, [[[f32; O]; N]; M], [[usize; N]; M], [[f32; N]; M], {M, N, O});
select_nz!(-1, [[[f32; O]; N]; M], [[[usize; Z]; N]; M], [[[f32; Z]; N]; M], {M, N, O, Z});

// 4d
select_01!(0, [[[[f32; P]; O]; N]; M], [[[f32; P]; O]; N], {M, N, O, P});
select_0z!(0, [[[[f32; P]; O]; N]; M], [[[[f32; P]; O]; N]; Z], {M, N, O, P, Z});
select_nz!(1, [[[[f32; P]; O]; N]; M], [usize; M], [[[f32; P]; O]; M], {M, N, O, P});
select_nz!(1, [[[[f32; P]; O]; N]; M], [[usize; Z]; M], [[[[f32; P]; O]; Z]; M], {M, N, O, P, Z});
select_nz!(2, [[[[f32; P]; O]; N]; M], [[usize; N]; M], [[[f32; P]; N]; M], {M, N, O, P});
select_nz!(2, [[[[f32; P]; O]; N]; M], [[[usize; Z]; N]; M], [[[[f32; P]; Z]; N]; M], {M, N, O, P, Z});
select_nz!(-1, [[[[f32; P]; O]; N]; M], [[[usize; O]; N]; M], [[[f32; O]; N]; M], {M, N, O, P});
select_nz!(-1, [[[[f32; P]; O]; N]; M], [[[[usize; Z]; O]; N]; M], [[[[f32; Z]; O]; N]; M], {M, N, O, P, Z});
// Broadcast select along a non-zero axis.
//
// The input is never indexed at this level. Each entry `indices[m]` is applied
// to the *whole* input using `SubMode`, and the result lands in `out[m]`.
// `select_add` mirrors this: each `out[m]` is handed to `SubMode`'s
// `select_add` against the same input.
impl<T, I, const M: usize, SubMode> DeviceSelect<T, [I; M], Broadcast<SubMode>> for Cpu
where
    Self: DeviceSelect<T, I, SubMode>,
{
    type Result = [<Self as DeviceSelect<T, I, SubMode>>::Result; M];

    fn select_axis(inp: &T, indices: &[I; M], out: &mut Self::Result) {
        // Pair every index entry with its output slot; the input is shared.
        for (idx, dst) in indices.iter().zip(out.iter_mut()) {
            Self::select_axis(inp, idx, dst);
        }
    }

    fn select_add(inp: &mut T, indices: &[I; M], out: &Self::Result) {
        // Pair every index entry with its source slot; all write into `inp`.
        for (idx, src) in indices.iter().zip(out.iter()) {
            Self::select_add(inp, idx, src);
        }
    }
}

#[cfg(test)]
mod tests {
Expand All @@ -110,17 +143,17 @@ mod tests {

#[test]
fn test_select_1d_0() {
let a = [1.0, 2.0, 3.0];
let mut b = ZeroElements::ZEROS;
let a: [f32; 3] = [1.0, 2.0, 3.0];
let mut b: f32 = ZeroElements::ZEROS;
Cpu::select_axis(&a, &1, &mut b);
assert_eq!(b, 2.0);
}

#[test]
fn test_select_1d_0z() {
let a = [1.0, 2.0, 3.0];
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b);
let a: [f32; 3] = [1.0f32, 2.0, 3.0];
let mut b: [f32; 6] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Index>>::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b);
assert_eq!(b, [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]);
}

Expand All @@ -129,41 +162,51 @@ mod tests {
#[test]
fn test_select_2d_0() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
let mut b: [f32; 3] = ZeroElements::ZEROS;
Cpu::select_axis(&a, &0, &mut b);
assert_eq!(b, [1.0, 2.0, 3.0]);
}

#[test]
fn test_select_2d_0z() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[0, 0, 1], &mut b);
let mut b: [[f32; 3]; 3] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Index>>::select_axis(&a, &[0, 0, 1], &mut b);
assert_eq!(b, [a[0], a[0], a[1]]);
}

#[test]
fn test_select_2d_1() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, -1>>::select_axis(&a, &[0, 1], &mut b);
let mut b: [f32; 2] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[0, 1], &mut b);
assert_eq!(b, [1.0, 5.0]);
}

#[test]
fn test_select_2d_1z() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[[0, 2], [1, 1]], &mut b);
let mut b: [[f32; 2]; 2] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[[0, 2], [1, 1]], &mut b);
assert_eq!(b, [[1.0, 3.0], [5.0, 5.0]]);
}

#[test]
fn test_select_broadcast_2d() {
let a = [[1.0], [2.0]];
let i: [[usize; 3]; 4] = [[0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]];
let mut b: [[[f32; 1]; 3]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Broadcast<Index>>>::select_axis(&a, &i, &mut b);
#[rustfmt::skip]
assert_eq!(b, [[[1.], [2.], [1.]], [[2.], [2.], [2.]], [[1.], [1.], [1.]], [[2.], [1.], [2.]]]);
}

#[test]
fn test_select_add_2d() {
let mut a = [[0.0; 3]; 2];
let b = [[1.0, 3.0], [5.0, 5.0]];
let i = [[0, 2], [1, 1]];
Cpu::select_add(&mut a, &i, &b);
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_add(&mut a, &i, &b);
assert_eq!(a, [[1.0, 0.0, 3.0], [0.0, 10.0, 0.0]]);
}

Expand All @@ -177,40 +220,40 @@ mod tests {
#[test]
fn test_select_3d_0() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
let mut b: [[f32; 3]; 2] = ZeroElements::ZEROS;
Cpu::select_axis(&a, &0, &mut b);
assert_eq!(b, A_3D[0]);
}

#[test]
fn test_select_3d_0z() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b);
let mut b: [[[f32; 3]; 2]; 6] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Index>>::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b);
assert_eq!(b, [A_3D[0], A_3D[0], A_3D[1], A_3D[2], A_3D[3], A_3D[3]]);
}

#[test]
fn test_select_3d_1() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, 1>>::select_axis(&a, &[0, 0, 1, 1], &mut b);
let mut b: [[f32; 3]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[0, 0, 1, 1], &mut b);
assert_eq!(b, [A_3D[0][0], A_3D[1][0], A_3D[2][1], A_3D[3][1]]);
}

#[test]
fn test_select_3d_1z() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, 1>>::select_axis(&a, &[[0], [0], [1], [1]], &mut b);
let mut b: [[[f32; 3]; 1]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[[0], [0], [1], [1]], &mut b);
assert_eq!(b, [[A_3D[0][0]], [A_3D[1][0]], [A_3D[2][1]], [A_3D[3][1]]]);
}

#[test]
fn test_select_3d_2() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, -1>>::select_axis(
let mut b: [[f32; 2]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Recurse<Index>>>>::select_axis(
&a,
&[[1, 0], [0, 1], [0, 0], [1, 1]],
&mut b,
Expand All @@ -230,7 +273,7 @@ mod tests {
fn test_select_3d_2z() {
let a = A_3D;
let mut b: [[[f32; 1]; 2]; 4] = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, -1>>::select_axis(
<Cpu as DeviceSelect<_, _, Recurse<Recurse<Index>>>>::select_axis(
&a,
&[[[1], [0]], [[0], [1]], [[0], [0]], [[1], [1]]],
&mut b,
Expand Down
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载