80 changes: 80 additions & 0 deletions src/arrays.rs
@@ -62,6 +62,20 @@ pub type Axes3<const I: isize, const J: isize, const K: isize> = (Axis<I>, Axis<
pub type Axes4<const I: isize, const J: isize, const K: isize, const L: isize> =
(Axis<I>, Axis<J>, Axis<K>, Axis<L>);

/// Five axes known at compile time.
pub type Axes5<const I: isize, const J: isize, const K: isize, const L: isize, const M: isize> =
(Axis<I>, Axis<J>, Axis<K>, Axis<L>, Axis<M>);

/// Six axes known at compile time.
pub type Axes6<
const I: isize,
const J: isize,
const K: isize,
const L: isize,
const M: isize,
const N: isize,
> = (Axis<I>, Axis<J>, Axis<K>, Axis<L>, Axis<M>, Axis<N>);

/// Represents all available axes on a tensor.
pub struct AllAxes;

@@ -93,6 +107,17 @@ impl_has_axis!([[[[f32; P]; O]; N]; M], 0, M, {M, N, O, P});
impl_has_axis!([[[[f32; P]; O]; N]; M], 1, N, {M, N, O, P});
impl_has_axis!([[[[f32; P]; O]; N]; M], 2, O, {M, N, O, P});
impl_has_axis!([[[[f32; P]; O]; N]; M], 3, P, {M, N, O, P});
impl_has_axis!([[[[[f32; Q]; P]; O]; N]; M], 0, M, {M, N, O, P, Q});
impl_has_axis!([[[[[f32; Q]; P]; O]; N]; M], 1, N, {M, N, O, P, Q});
impl_has_axis!([[[[[f32; Q]; P]; O]; N]; M], 2, O, {M, N, O, P, Q});
impl_has_axis!([[[[[f32; Q]; P]; O]; N]; M], 3, P, {M, N, O, P, Q});
impl_has_axis!([[[[[f32; Q]; P]; O]; N]; M], 4, Q, {M, N, O, P, Q});
impl_has_axis!([[[[[[f32; R]; Q]; P]; O]; N]; M], 0, M, {M, N, O, P, Q, R});
impl_has_axis!([[[[[[f32; R]; Q]; P]; O]; N]; M], 1, N, {M, N, O, P, Q, R});
impl_has_axis!([[[[[[f32; R]; Q]; P]; O]; N]; M], 2, O, {M, N, O, P, Q, R});
impl_has_axis!([[[[[[f32; R]; Q]; P]; O]; N]; M], 3, P, {M, N, O, P, Q, R});
impl_has_axis!([[[[[[f32; R]; Q]; P]; O]; N]; M], 4, Q, {M, N, O, P, Q, R});
impl_has_axis!([[[[[[f32; R]; Q]; P]; O]; N]; M], 5, R, {M, N, O, P, Q, R});

impl<T: CountElements> HasAxes<AllAxes> for T {
const SIZE: usize = T::NUM_ELEMENTS;
@@ -125,6 +150,43 @@ where
* <T as HasAxes<Axis<L>>>::SIZE;
}

impl<T, const I: isize, const J: isize, const K: isize, const L: isize, const M: isize>
HasAxes<Axes5<I, J, K, L, M>> for T
where
T: HasAxes<Axis<I>> + HasAxes<Axis<J>> + HasAxes<Axis<K>> + HasAxes<Axis<L>> + HasAxes<Axis<M>>,
{
const SIZE: usize = <T as HasAxes<Axis<I>>>::SIZE
* <T as HasAxes<Axis<J>>>::SIZE
* <T as HasAxes<Axis<K>>>::SIZE
* <T as HasAxes<Axis<L>>>::SIZE
* <T as HasAxes<Axis<M>>>::SIZE;
}

impl<
T,
const I: isize,
const J: isize,
const K: isize,
const L: isize,
const M: isize,
const N: isize,
> HasAxes<Axes6<I, J, K, L, M, N>> for T
where
T: HasAxes<Axis<I>>
+ HasAxes<Axis<J>>
+ HasAxes<Axis<K>>
+ HasAxes<Axis<L>>
+ HasAxes<Axis<M>>
+ HasAxes<Axis<N>>,
{
const SIZE: usize = <T as HasAxes<Axis<I>>>::SIZE
* <T as HasAxes<Axis<J>>>::SIZE
* <T as HasAxes<Axis<K>>>::SIZE
* <T as HasAxes<Axis<L>>>::SIZE
* <T as HasAxes<Axis<M>>>::SIZE
* <T as HasAxes<Axis<N>>>::SIZE;
}

/// Holds an axis that represents the last (or right most) axis.
pub trait HasLastAxis {
type LastAxis;
@@ -153,6 +215,24 @@ impl<const M: usize, const N: usize, const O: usize, const P: usize> HasLastAxis
type LastAxis = Axis<3>;
const SIZE: usize = P;
}
impl<const M: usize, const N: usize, const O: usize, const P: usize, const Q: usize> HasLastAxis
for [[[[[f32; Q]; P]; O]; N]; M]
{
type LastAxis = Axis<4>;
const SIZE: usize = Q;
}
impl<
const M: usize,
const N: usize,
const O: usize,
const P: usize,
const Q: usize,
const R: usize,
> HasLastAxis for [[[[[[f32; R]; Q]; P]; O]; N]; M]
{
type LastAxis = Axis<5>;
const SIZE: usize = R;
}

/// Something that has compile time known zero values.
pub trait ZeroElements {
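A quick illustration of what the new arrays.rs impls provide (a minimal sketch, not part of the diff; the `dfdx::arrays` import path is an assumption): `SIZE` for a multi-axis reduction is the product of the individual axis lengths, and the new `HasLastAxis` impls report the innermost dimension of 5d/6d arrays.

```rust
// Minimal sketch (not part of the diff). Assumes Axis, Axes5, HasAxes and
// HasLastAxis are reachable at dfdx::arrays as defined in the file above.
use dfdx::arrays::{Axes5, Axis, HasAxes, HasLastAxis};

// The backing storage of a 5d tensor with shape (2, 3, 4, 5, 6).
type A5 = [[[[[f32; 6]; 5]; 4]; 3]; 2];

fn main() {
    // Reducing over all five axes covers every element: 2 * 3 * 4 * 5 * 6 = 720.
    assert_eq!(<A5 as HasAxes<Axes5<0, 1, 2, 3, 4>>>::SIZE, 720);
    // A single axis still reports just its own length (axis 3 has length 5 here).
    assert_eq!(<A5 as HasAxes<Axis<3>>>::SIZE, 5);
    // The new HasLastAxis impl points at Axis<4> and reports the innermost length Q = 6.
    assert_eq!(<A5 as HasLastAxis>::SIZE, 6);
}
```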
14 changes: 14 additions & 0 deletions src/devices/mod.rs
@@ -76,6 +76,20 @@ impl<const M: usize, const N: usize, const O: usize, const P: usize> Device<[[[[
for Cpu
{
}
impl<const M: usize, const N: usize, const O: usize, const P: usize, const Q: usize>
Device<[[[[[f32; Q]; P]; O]; N]; M]> for Cpu
{
}
impl<
const M: usize,
const N: usize,
const O: usize,
const P: usize,
const Q: usize,
const R: usize,
> Device<[[[[[[f32; R]; Q]; P]; O]; N]; M]> for Cpu
{
}

/// A [crate::arrays::HasArrayType] that has a [Device] for its [crate::arrays::HasArrayType::Array]
pub trait HasDevice: crate::arrays::HasArrayType {
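The impl blocks above are intentionally empty: `Device<Array>` for `Cpu` acts as a marker that pulls in the element-wise helpers (such as the `fill` call used by `randomize` further down) for the new 5d/6d nested arrays. A rough usage sketch; the glob import and exact trait layout are assumptions, and only the `fill(out, f)` call shape is taken from the diff below:

```rust
// Rough sketch (not part of the diff). Trait paths are assumed to be
// glob-importable from dfdx::devices.
use dfdx::devices::*;

fn main() {
    // Backing storage of a 5d tensor with shape (2, 1, 1, 1, 3), now covered by Device for Cpu.
    let mut data = [[[[[0.0f32; 3]; 1]; 1]; 1]; 2];
    // Visit every element through the Cpu device, as randomize does with a distribution.
    Cpu::fill(&mut data, &mut |v| *v = 1.0);
    assert_eq!(data[1][0][0][0][2], 1.0);
}
```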
2 changes: 2 additions & 0 deletions src/tensor/impl_default.rs
@@ -17,3 +17,5 @@ tensor_impl!(Tensor1D, [M]);
tensor_impl!(Tensor2D, [M, N]);
tensor_impl!(Tensor3D, [M, N, O]);
tensor_impl!(Tensor4D, [M, N, O, P]);
tensor_impl!(Tensor5D, [M, N, O, P, Q]);
tensor_impl!(Tensor6D, [M, N, O, P, Q, R]);
6 changes: 6 additions & 0 deletions src/tensor/impl_has_array.rs
@@ -23,3 +23,9 @@ tensor_impl!(Tensor1D, [M], [f32; M]);
tensor_impl!(Tensor2D, [M, N], [[f32; N]; M]);
tensor_impl!(Tensor3D, [M, N, O], [[[f32; O]; N]; M]);
tensor_impl!(Tensor4D, [M, N, O, P], [[[[f32; P]; O]; N]; M]);
tensor_impl!(Tensor5D, [M, N, O, P, Q], [[[[[f32; Q]; P]; O]; N]; M]);
tensor_impl!(
Tensor6D,
[M, N, O, P, Q, R],
[[[[[[f32; R]; Q]; P]; O]; N]; M]
);
2 changes: 2 additions & 0 deletions src/tensor/impl_has_device.rs
@@ -14,3 +14,5 @@ tensor_impl!(Tensor1D, [M]);
tensor_impl!(Tensor2D, [M, N]);
tensor_impl!(Tensor3D, [M, N, O]);
tensor_impl!(Tensor4D, [M, N, O, P]);
tensor_impl!(Tensor5D, [M, N, O, P, Q]);
tensor_impl!(Tensor6D, [M, N, O, P, Q, R]);
2 changes: 2 additions & 0 deletions src/tensor/impl_has_unique_id.rs
@@ -16,3 +16,5 @@ tensor_impl!(Tensor1D, [M]);
tensor_impl!(Tensor2D, [M, N]);
tensor_impl!(Tensor3D, [M, N, O]);
tensor_impl!(Tensor4D, [M, N, O, P]);
tensor_impl!(Tensor5D, [M, N, O, P, Q]);
tensor_impl!(Tensor6D, [M, N, O, P, Q, R]);
2 changes: 2 additions & 0 deletions src/tensor/impl_put_tape.rs
@@ -33,3 +33,5 @@ tensor_impl!(Tensor1D, [M]);
tensor_impl!(Tensor2D, [M, N]);
tensor_impl!(Tensor3D, [M, N, O]);
tensor_impl!(Tensor4D, [M, N, O, P]);
tensor_impl!(Tensor5D, [M, N, O, P, Q]);
tensor_impl!(Tensor6D, [M, N, O, P, Q, R]);
4 changes: 3 additions & 1 deletion src/tensor/impl_randomize.rs
@@ -11,7 +11,7 @@ macro_rules! tensor_impl {
($typename:ident, [$($Vs:tt),*]) => {
impl<$(const $Vs: usize, )* H> Randomize<f32> for $typename<$($Vs, )* H> {
/// Fills `self.mut_data()` with data from the distribution `D`
fn randomize<R: Rng, D: Distribution<f32>>(&mut self, rng: &mut R, dist: &D) {
fn randomize<RNG: Rng, D: Distribution<f32>>(&mut self, rng: &mut RNG, dist: &D) {
<Self as HasDevice>::Device::fill(self.mut_data(), &mut |v| *v = dist.sample(rng));
}
}
@@ -23,6 +23,8 @@ tensor_impl!(Tensor1D, [M]);
tensor_impl!(Tensor2D, [M, N]);
tensor_impl!(Tensor3D, [M, N, O]);
tensor_impl!(Tensor4D, [M, N, O, P]);
tensor_impl!(Tensor5D, [M, N, O, P, Q]);
tensor_impl!(Tensor6D, [M, N, O, P, Q, R]);

#[cfg(test)]
mod tests {
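Note the generic rename from `R` to `RNG` above: `R` now names the sixth tensor dimension in the macro's impl block, so the random number generator needs a non-clashing parameter name. A usage sketch of the new impls (assumptions: the new tensor types, `TensorCreator`, and `Randomize` are re-exported through the prelude like the existing ones, and a `zeros()` constructor is available):

```rust
// Usage sketch (not part of the diff). Prelude re-exports and zeros() are assumptions;
// the randomize signature matches the one shown above.
use dfdx::prelude::*;
use rand::{distributions::Standard, rngs::StdRng, SeedableRng};

fn main() {
    let mut rng = StdRng::seed_from_u64(0);
    // A small 5d tensor, filled in place with samples from the Standard distribution.
    let mut t: Tensor5D<2, 3, 1, 1, 4> = TensorCreator::zeros();
    t.randomize(&mut rng, &Standard);
}
```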
2 changes: 2 additions & 0 deletions src/tensor/impl_tensor.rs
Expand Up @@ -73,6 +73,8 @@ tensor_impl!(Tensor1D, [M]);
tensor_impl!(Tensor2D, [M, N]);
tensor_impl!(Tensor3D, [M, N, O]);
tensor_impl!(Tensor4D, [M, N, O, P]);
tensor_impl!(Tensor5D, [M, N, O, P, Q]);
tensor_impl!(Tensor6D, [M, N, O, P, Q, R]);

#[cfg(test)]
mod tests {
2 changes: 2 additions & 0 deletions src/tensor/impl_tensor_creator.rs
@@ -69,6 +69,8 @@ tensor_impl!(Tensor1D, [M]);
tensor_impl!(Tensor2D, [M, N]);
tensor_impl!(Tensor3D, [M, N, O]);
tensor_impl!(Tensor4D, [M, N, O, P]);
tensor_impl!(Tensor5D, [M, N, O, P, Q]);
tensor_impl!(Tensor6D, [M, N, O, P, Q, R]);

#[cfg(test)]
mod tests {
2 changes: 2 additions & 0 deletions src/tensor/impl_trace.rs
@@ -40,6 +40,8 @@ tensor_impl!(Tensor1D, [M]);
tensor_impl!(Tensor2D, [M, N]);
tensor_impl!(Tensor3D, [M, N, O]);
tensor_impl!(Tensor4D, [M, N, O, P]);
tensor_impl!(Tensor5D, [M, N, O, P, Q]);
tensor_impl!(Tensor6D, [M, N, O, P, Q, R]);

#[cfg(test)]
mod tests {
4 changes: 3 additions & 1 deletion src/tensor/into_tensor.rs
@@ -1,4 +1,4 @@
use super::{Tensor0D, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorCreator};
use super::{Tensor0D, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D, TensorCreator};
use std::boxed::Box;

/// Creates a tensor using the data passed in. The return type is based
@@ -46,6 +46,8 @@ impl_into_tensor!([f32; M], Tensor1D<M>, { M });
impl_into_tensor!([[f32; N]; M], Tensor2D<M, N>, {M, N});
impl_into_tensor!([[[f32; O]; N]; M], Tensor3D<M, N, O>, {M, N, O});
impl_into_tensor!([[[[f32; P]; O]; N]; M], Tensor4D<M, N, O, P>, {M, N, O, P});
impl_into_tensor!([[[[[f32; Q]; P]; O]; N]; M], Tensor5D<M, N, O, P, Q>, {M, N, O, P, Q});
impl_into_tensor!([[[[[[f32; R]; Q]; P]; O]; N]; M], Tensor6D<M, N, O, P, Q, R>, {M, N, O, P, Q, R});

#[cfg(test)]
mod tests {
33 changes: 33 additions & 0 deletions src/tensor/structs.rs
@@ -48,3 +48,36 @@ pub struct Tensor4D<const M: usize, const N: usize, const O: usize, const P: usi
pub(crate) data: std::sync::Arc<[[[[f32; P]; O]; N]; M]>,
pub(crate) tape: Tape,
}

/// A 5d [super::Tensor] with shape (M, N, O, P, Q). Backed by data `[[[[[f32; Q]; P]; O]; N]; M]`.
#[derive(Debug)]
#[allow(clippy::type_complexity)]
pub struct Tensor5D<
const M: usize,
const N: usize,
const O: usize,
const P: usize,
const Q: usize,
Tape = NoneTape,
> {
pub(crate) id: UniqueId,
pub(crate) data: std::sync::Arc<[[[[[f32; Q]; P]; O]; N]; M]>,
pub(crate) tape: Tape,
}

/// A 6d [super::Tensor] with shape (M, N, O, P, Q, R). Backed by data `[[[[[[f32; R]; Q]; P]; O]; N]; M]`.
#[derive(Debug)]
#[allow(clippy::type_complexity)]
pub struct Tensor6D<
const M: usize,
const N: usize,
const O: usize,
const P: usize,
const Q: usize,
const R: usize,
Tape = NoneTape,
> {
pub(crate) id: UniqueId,
pub(crate) data: std::sync::Arc<[[[[[[f32; R]; Q]; P]; O]; N]; M]>,
pub(crate) tape: Tape,
}
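To make the shape-to-storage mapping in the doc comments concrete: the outermost array dimension is M and the innermost is Q (or R for 6d). A small sketch; `TensorCreator::new` and `.data()` are assumed to work for the new types the same way they do for the existing ones:

```rust
// Sketch (not part of the diff). TensorCreator::new and .data() availability
// for Tensor5D is an assumption based on the existing tensor types.
use dfdx::prelude::*;

fn main() {
    // Shape (2, 1, 1, 1, 3): the outer array has length M = 2, the innermost length Q = 3.
    let data = [[[[[1.0f32, 2.0, 3.0]; 1]; 1]; 1]; 2];
    let t: Tensor5D<2, 1, 1, 1, 3> = TensorCreator::new(data);
    assert_eq!(t.data()[0][0][0][0][2], 3.0);
}
```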