+
Skip to content

Adding Unit and HasUnitType. Reducing bounds for Dtype #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/shapes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ pub(crate) use same_numel::HasSameNumelAs;

pub use axes::{Axes2, Axes3, Axes4, Axes5, Axes6, Axis, HasAxes};
pub use shape::{Const, Dim};
pub use shape::{Dtype, HasDtype};
pub use shape::{Dtype, HasDtype, HasUnitType, Unit};
pub use shape::{HasShape, Rank0, Rank1, Rank2, Rank3, Rank4, Rank5, Rank6, Shape};
24 changes: 16 additions & 8 deletions src/shapes/shape.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
use super::{axes::*, ReduceShapeTo};

/// An element type that can be stored and compared, but is not required to
/// support arithmetic.
///
/// The supertraits only ask that a value is `'static`, `Copy`/`Clone`, has a
/// `Default`, is `Debug`-printable, partially ordered, and is `Send + Sync`.
/// No `std::ops` traits are required here; arithmetic is layered on
/// separately by [Dtype].
pub trait Unit:
'static + Copy + Clone + Default + std::fmt::Debug + PartialOrd + Send + Sync
{
}
// Built-in types that qualify as a [Unit]. `bool` is included, which is
// possible because [Unit] itself demands no arithmetic operators.
impl Unit for f32 {}
impl Unit for f64 {}
impl Unit for usize {}
impl Unit for bool {}

/// Implemented by types that carry an associated element type satisfying
/// [Unit].
pub trait HasUnitType {
/// The element type associated with the implementor.
type Unit: Unit;
}

/// Represents a data type or element of an array.
pub trait Dtype:
'static
+ Copy
+ Clone
+ Default
+ std::fmt::Debug
+ PartialOrd
+ Send
+ Sync
Unit
+ std::ops::Add<Self, Output = Self>
+ std::ops::Sub<Self, Output = Self>
+ std::ops::Mul<Self, Output = Self>
Expand Down
34 changes: 16 additions & 18 deletions src/tensor/cpu/allocate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{sync::Arc, vec::Vec};

use super::{Cpu, CpuError, LendingIterator, StridedArray};

impl<S: Shape, E: Dtype> StridedArray<S, E> {
impl<S: Shape, E: Default + Clone> StridedArray<S, E> {
#[inline]
pub(crate) fn new(shape: S) -> Result<Self, CpuError> {
Self::try_new_with(shape, Default::default())
Expand Down Expand Up @@ -48,14 +48,14 @@ impl<S: Shape, E: Dtype> StridedArray<S, E> {
}
}

impl<E: Dtype> ZerosTensor<E> for Cpu {
impl<E: Unit> ZerosTensor<E> for Cpu {
fn try_zeros_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Self::Err> {
let storage = StridedArray::try_new_with(*src.shape(), Default::default())?;
Ok(self.upgrade(storage))
}
}

impl<E: Dtype> ZeroFillStorage<E> for Cpu {
impl<E: Unit> ZeroFillStorage<E> for Cpu {
fn try_fill_with_zeros<S: Shape>(
&self,
storage: &mut Self::Storage<S, E>,
Expand Down Expand Up @@ -185,7 +185,7 @@ impl RandnFillStorage<f32> for Cpu {
}
}

impl<E: Dtype> CopySlice<E> for Cpu {
impl<E: Unit> CopySlice<E> for Cpu {
fn copy_from<S: Shape, T>(dst: &mut Tensor<S, E, Self, T>, src: &[E]) {
std::sync::Arc::make_mut(&mut dst.storage.data).copy_from_slice(src);
}
Expand All @@ -194,15 +194,15 @@ impl<E: Dtype> CopySlice<E> for Cpu {
}
}

impl<E: Dtype> TensorFromArray<E, Rank0, E> for Cpu {
impl<E: Unit> TensorFromArray<E, Rank0, E> for Cpu {
fn try_tensor(&self, src: E) -> Result<Tensor<Rank0, E, Self>, Self::Err> {
let mut storage: StridedArray<_, E> = StridedArray::new(Default::default())?;
storage[[]].clone_from(&src);
Ok(self.upgrade(storage))
}
}

impl<E: Dtype, const M: usize> TensorFromArray<[E; M], Rank1<M>, E> for Cpu {
impl<E: Unit, const M: usize> TensorFromArray<[E; M], Rank1<M>, E> for Cpu {
fn try_tensor(&self, src: [E; M]) -> Result<Tensor<Rank1<M>, E, Self>, Self::Err> {
let mut storage: StridedArray<Rank1<M>, E> = StridedArray::new(Default::default())?;
let mut iter = storage.iter_mut_with_index();
Expand All @@ -213,7 +213,7 @@ impl<E: Dtype, const M: usize> TensorFromArray<[E; M], Rank1<M>, E> for Cpu {
}
}

impl<E: Dtype, const M: usize> TensorFromArray<&[E; M], Rank1<M>, E> for Cpu {
impl<E: Unit, const M: usize> TensorFromArray<&[E; M], Rank1<M>, E> for Cpu {
fn try_tensor(&self, src: &[E; M]) -> Result<Tensor<Rank1<M>, E, Self>, Self::Err> {
let mut storage: StridedArray<Rank1<M>, E> = StridedArray::new(Default::default())?;
let mut iter = storage.iter_mut_with_index();
Expand All @@ -224,9 +224,7 @@ impl<E: Dtype, const M: usize> TensorFromArray<&[E; M], Rank1<M>, E> for Cpu {
}
}

impl<E: Dtype, const M: usize, const N: usize> TensorFromArray<[[E; N]; M], Rank2<M, N>, E>
for Cpu
{
impl<E: Unit, const M: usize, const N: usize> TensorFromArray<[[E; N]; M], Rank2<M, N>, E> for Cpu {
fn try_tensor(&self, src: [[E; N]; M]) -> Result<Tensor<Rank2<M, N>, E, Self>, Self::Err> {
let mut storage: StridedArray<Rank2<M, N>, E> = StridedArray::new(Default::default())?;
let mut iter = storage.iter_mut_with_index();
Expand All @@ -237,7 +235,7 @@ impl<E: Dtype, const M: usize, const N: usize> TensorFromArray<[[E; N]; M], Rank
}
}

impl<E: Dtype, const M: usize, const N: usize, const O: usize>
impl<E: Unit, const M: usize, const N: usize, const O: usize>
TensorFromArray<[[[E; O]; N]; M], Rank3<M, N, O>, E> for Cpu
{
fn try_tensor(
Expand All @@ -253,7 +251,7 @@ impl<E: Dtype, const M: usize, const N: usize, const O: usize>
}
}

impl<E: Dtype, const M: usize, const N: usize, const O: usize, const P: usize>
impl<E: Unit, const M: usize, const N: usize, const O: usize, const P: usize>
TensorFromArray<[[[[E; P]; O]; N]; M], Rank4<M, N, O, P>, E> for Cpu
{
fn try_tensor(
Expand All @@ -270,7 +268,7 @@ impl<E: Dtype, const M: usize, const N: usize, const O: usize, const P: usize>
}
}

impl<S: Shape, E: Dtype> AsVec for StridedArray<S, E> {
impl<S: Shape, E: Unit> AsVec for StridedArray<S, E> {
fn as_vec(&self) -> Vec<E> {
let mut out = Vec::with_capacity(self.shape.num_elements());
let mut iter = self.iter();
Expand All @@ -281,7 +279,7 @@ impl<S: Shape, E: Dtype> AsVec for StridedArray<S, E> {
}
}

impl<E: Dtype> AsArray for StridedArray<Rank0, E> {
impl<E: Unit> AsArray for StridedArray<Rank0, E> {
type Array = E;
fn array(&self) -> Self::Array {
let mut out: Self::Array = Default::default();
Expand All @@ -290,7 +288,7 @@ impl<E: Dtype> AsArray for StridedArray<Rank0, E> {
}
}

impl<E: Dtype, const M: usize> AsArray for StridedArray<Rank1<M>, E> {
impl<E: Unit, const M: usize> AsArray for StridedArray<Rank1<M>, E> {
type Array = [E; M];
fn array(&self) -> Self::Array {
let mut out: Self::Array = [Default::default(); M];
Expand All @@ -302,7 +300,7 @@ impl<E: Dtype, const M: usize> AsArray for StridedArray<Rank1<M>, E> {
}
}

impl<E: Dtype, const M: usize, const N: usize> AsArray for StridedArray<Rank2<M, N>, E> {
impl<E: Unit, const M: usize, const N: usize> AsArray for StridedArray<Rank2<M, N>, E> {
type Array = [[E; N]; M];
fn array(&self) -> Self::Array {
let mut out: Self::Array = [[Default::default(); N]; M];
Expand All @@ -316,7 +314,7 @@ impl<E: Dtype, const M: usize, const N: usize> AsArray for StridedArray<Rank2<M,
}
}

impl<E: Dtype, const M: usize, const N: usize, const O: usize> AsArray
impl<E: Unit, const M: usize, const N: usize, const O: usize> AsArray
for StridedArray<Rank3<M, N, O>, E>
{
type Array = [[[E; O]; N]; M];
Expand All @@ -330,7 +328,7 @@ impl<E: Dtype, const M: usize, const N: usize, const O: usize> AsArray
}
}

impl<E: Dtype, const M: usize, const N: usize, const O: usize, const P: usize> AsArray
impl<E: Unit, const M: usize, const N: usize, const O: usize, const P: usize> AsArray
for StridedArray<Rank4<M, N, O, P>, E>
{
type Array = [[[[E; P]; O]; N]; M];
Expand Down
12 changes: 8 additions & 4 deletions src/tensor/cpu/device.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::shapes::{Dtype, HasDtype, HasShape, Shape};
use crate::shapes::{Dtype, HasDtype, HasShape, HasUnitType, Shape, Unit};
use crate::tensor::storage_traits::*;
use rand::{rngs::StdRng, Rng, SeedableRng};
use std::{
Expand Down Expand Up @@ -35,8 +35,8 @@ impl Cpu {

/// The storage for the cpu device
#[derive(Debug, Clone)]
pub struct StridedArray<S: Shape, Elem> {
pub(crate) data: Arc<Vec<Elem>>,
pub struct StridedArray<S: Shape, E> {
pub(crate) data: Arc<Vec<E>>,
pub(crate) shape: S,
pub(crate) strides: S::Concrete,
}
Expand Down Expand Up @@ -66,6 +66,10 @@ impl<S: Shape, E> HasShape for StridedArray<S, E> {
}
}

/// A [StridedArray] exposes its element type `E` as its unit type whenever
/// `E` implements [Unit].
impl<S: Shape, E: Unit> HasUnitType for StridedArray<S, E> {
type Unit = E;
}

/// A [StridedArray] exposes its element type `E` as its dtype whenever `E`
/// implements [Dtype].
impl<S: Shape, E: Dtype> HasDtype for StridedArray<S, E> {
type Dtype = E;
}
Expand All @@ -75,7 +79,7 @@ impl HasErr for Cpu {
}

impl DeviceStorage for Cpu {
type Storage<S: Shape, E: Dtype> = StridedArray<S, E>;
type Storage<S: Shape, E: Unit> = StridedArray<S, E>;

fn try_alloc_grad<S: Shape, E: Dtype>(
&self,
Expand Down
6 changes: 3 additions & 3 deletions src/tensor/cpu/index.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::device::StridedArray;
use crate::shapes::{Dtype, Shape};
use crate::shapes::Shape;
use std::sync::Arc;

fn index_to_i<S: Shape>(shape: &S, strides: &S::Concrete, index: S::Concrete) -> usize {
Expand All @@ -16,7 +16,7 @@ fn index_to_i<S: Shape>(shape: &S, strides: &S::Concrete, index: S::Concrete) ->
.sum()
}

impl<S: Shape, E: Dtype> std::ops::Index<S::Concrete> for StridedArray<S, E> {
impl<S: Shape, E> std::ops::Index<S::Concrete> for StridedArray<S, E> {
type Output = E;
#[inline(always)]
fn index(&self, index: S::Concrete) -> &Self::Output {
Expand All @@ -25,7 +25,7 @@ impl<S: Shape, E: Dtype> std::ops::Index<S::Concrete> for StridedArray<S, E> {
}
}

impl<S: Shape, E: Dtype> std::ops::IndexMut<S::Concrete> for StridedArray<S, E> {
impl<S: Shape, E: Clone> std::ops::IndexMut<S::Concrete> for StridedArray<S, E> {
#[inline(always)]
fn index_mut(&mut self, index: S::Concrete) -> &mut Self::Output {
let i = index_to_i(&self.shape, &self.strides, index);
Expand Down
38 changes: 19 additions & 19 deletions src/tensor/cpu/iterate.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::device::StridedArray;
use crate::shapes::{BroadcastStridesTo, Dtype, Shape};
use crate::shapes::{BroadcastStridesTo, Shape};
use std::sync::Arc;
use std::vec::Vec;

Expand Down Expand Up @@ -64,66 +64,66 @@ impl<S: Shape> NdIndex<S> {
}
}

pub(crate) struct StridedRefIter<'a, S: Shape, Elem> {
data: &'a Vec<Elem>,
pub(crate) struct StridedRefIter<'a, S: Shape, E> {
data: &'a Vec<E>,
index: NdIndex<S>,
}

pub(crate) struct StridedMutIter<'a, S: Shape, Elem> {
data: &'a mut Vec<Elem>,
pub(crate) struct StridedMutIter<'a, S: Shape, E> {
data: &'a mut Vec<E>,
index: NdIndex<S>,
}

pub(crate) struct StridedRefIndexIter<'a, S: Shape, Elem> {
data: &'a Vec<Elem>,
pub(crate) struct StridedRefIndexIter<'a, S: Shape, E> {
data: &'a Vec<E>,
index: NdIndex<S>,
}

pub(crate) struct StridedMutIndexIter<'a, S: Shape, Elem> {
data: &'a mut Vec<Elem>,
pub(crate) struct StridedMutIndexIter<'a, S: Shape, E> {
data: &'a mut Vec<E>,
index: NdIndex<S>,
}

impl<S: Shape, Elem: Dtype> StridedArray<S, Elem> {
pub(crate) fn buf_iter(&self) -> std::slice::Iter<'_, Elem> {
impl<S: Shape, E: Clone> StridedArray<S, E> {
pub(crate) fn buf_iter(&self) -> std::slice::Iter<'_, E> {
self.data.iter()
}

pub(crate) fn buf_iter_mut(&mut self) -> std::slice::IterMut<'_, Elem> {
pub(crate) fn buf_iter_mut(&mut self) -> std::slice::IterMut<'_, E> {
std::sync::Arc::make_mut(&mut self.data).iter_mut()
}

pub(crate) fn iter(&self) -> StridedRefIter<S, Elem> {
pub(crate) fn iter(&self) -> StridedRefIter<S, E> {
StridedRefIter {
data: self.data.as_ref(),
index: NdIndex::new(self.shape, self.strides),
}
}

pub(crate) fn iter_mut(&mut self) -> StridedMutIter<S, Elem> {
pub(crate) fn iter_mut(&mut self) -> StridedMutIter<S, E> {
StridedMutIter {
data: std::sync::Arc::make_mut(&mut self.data),
index: NdIndex::new(self.shape, self.strides),
}
}

pub(crate) fn iter_with_index(&self) -> StridedRefIndexIter<S, Elem> {
pub(crate) fn iter_with_index(&self) -> StridedRefIndexIter<S, E> {
StridedRefIndexIter {
data: self.data.as_ref(),
index: NdIndex::new(self.shape, self.strides),
}
}

pub(crate) fn iter_mut_with_index(&mut self) -> StridedMutIndexIter<S, Elem> {
pub(crate) fn iter_mut_with_index(&mut self) -> StridedMutIndexIter<S, E> {
StridedMutIndexIter {
data: std::sync::Arc::make_mut(&mut self.data),
index: NdIndex::new(self.shape, self.strides),
}
}
}

impl<S: Shape, Elem: Dtype> StridedArray<S, Elem> {
pub(crate) fn iter_as<Axes, Dst: Shape>(&self, dst: &Dst) -> StridedRefIter<Dst, Elem>
impl<S: Shape, E: Clone> StridedArray<S, E> {
pub(crate) fn iter_as<Axes, Dst: Shape>(&self, dst: &Dst) -> StridedRefIter<Dst, E>
where
S: BroadcastStridesTo<Dst, Axes>,
{
Expand All @@ -133,7 +133,7 @@ impl<S: Shape, Elem: Dtype> StridedArray<S, Elem> {
}
}

pub(crate) fn iter_mut_as<Axes, Dst: Shape>(&mut self, dst: &Dst) -> StridedMutIter<Dst, Elem>
pub(crate) fn iter_mut_as<Axes, Dst: Shape>(&mut self, dst: &Dst) -> StridedMutIter<Dst, E>
where
S: BroadcastStridesTo<Dst, Axes>,
{
Expand Down
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载