Merged
21 changes: 21 additions & 0 deletions src/tensor/ghost.rs
@@ -29,10 +29,31 @@ impl<S: Shape, E: Unit, D: DeviceStorage, T> Tensor<S, E, D, T> {
}
}

+impl<S: Shape, E: Unit, D: DeviceStorage> Clone for GhostTensor<S, E, D> {
+    fn clone(&self) -> Self {
+        Self {
+            id: self.id,
+            len: self.len,
+            shape: self.shape,
+            strides: self.strides,
+            dev: self.dev.clone(),
+            marker: self.marker,
+        }
+    }
+}
+
 impl<S: Shape, E: Unit, D: DeviceStorage> super::storage_traits::HasErr for GhostTensor<S, E, D> {
     type Err = D::Err;
 }
 
+impl<S: Shape, E: Unit, D: DeviceStorage> HasShape for GhostTensor<S, E, D> {
+    type WithShape<New: Shape> = GhostTensor<New, E, D>;
+    type Shape = S;
+    fn shape(&self) -> &Self::Shape {
+        &self.shape
+    }
+}
+
impl<S: Shape, E: Unit, D: DeviceStorage> super::storage_traits::AllocGrad
for GhostTensor<S, E, D>
{
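Together with the existing AllocGrad impl, the new HasShape impl gives GhostTensor all three supertraits that the Tensorlike trait introduced below (see src/tensor/tensorlike.rs) demands; Clone is written by hand, likely because derive would impose Clone bounds on S and E. A minimal sketch of what HasShape buys — `log_dims` is a hypothetical helper, not part of this PR:

use crate::prelude::HasShape;

// Accepts a Tensor or a GhostTensor once both implement HasShape.
// `concrete()` lowers a shape to its plain [usize; N] array, just as the
// cudnn kernel below does with `out.shape().concrete()`.
fn log_dims<T: HasShape>(t: &T) {
    println!("dims = {:?}", t.shape().concrete());
}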
45 changes: 21 additions & 24 deletions src/tensor/gradients.rs
@@ -4,11 +4,8 @@
use std::collections::{BTreeMap, BTreeSet};
use std::{boxed::Box, vec::Vec};

-use super::ghost::GhostTensor;
-use super::{
-    storage_traits::{AllocGrad, DeviceStorage},
-    unique_id, Tensor, UniqueId,
-};
+use super::tensorlike::Tensorlike;
+use super::{storage_traits::DeviceStorage, unique_id, Tensor, UniqueId};
use crate::shapes::{Shape, Unit};

/// A generic container for keeping gradients of tensors keyed by the
@@ -59,9 +56,9 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// Inserts a gradient for `t`
pub(crate) fn try_alloc_for<S: Shape>(
&mut self,
-        t: &GhostTensor<S, E, D>,
+        t: &impl Tensorlike<S, E, D>,
) -> Result<(), D::Err> {
-        if let std::collections::btree_map::Entry::Vacant(e) = self.gradient_by_id.entry(t.id) {
+        if let std::collections::btree_map::Entry::Vacant(e) = self.gradient_by_id.entry(t.id()) {
e.insert(t.try_alloc_grad()?);
}
Ok(())
@@ -94,15 +91,15 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// Returns a mutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
-    pub(crate) fn get_mut<S: Shape>(&mut self, t: &GhostTensor<S, E, D>) -> &mut D::Vec<E> {
-        self.gradient_by_id.get_mut(&t.id).unwrap()
+    pub(crate) fn get_mut<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &mut D::Vec<E> {
+        self.gradient_by_id.get_mut(&t.id()).unwrap()
}

/// Returns an immutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
-    pub(crate) fn get_ref<S: Shape>(&mut self, t: &GhostTensor<S, E, D>) -> &D::Vec<E> {
-        self.gradient_by_id.get(&t.id).unwrap()
+    pub(crate) fn get_ref<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &D::Vec<E> {
+        self.gradient_by_id.get(&t.id()).unwrap()
}

/// Clones the gradient and transforms it into a tensor.
@@ -128,10 +125,10 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// **Panics** if `l` and `r` have the same id.
pub(crate) fn mut_and_ref<L: Shape, R: Shape>(
&mut self,
-        l: &GhostTensor<L, E, D>,
-        r: &GhostTensor<R, E, D>,
+        l: &impl Tensorlike<L, E, D>,
+        r: &impl Tensorlike<R, E, D>,
) -> (&mut D::Vec<E>, &D::Vec<E>) {
-        assert_ne!(l.id, r.id);
+        assert_ne!(l.id(), r.id());
let l_ptr = self.get_mut(l) as *mut _;
let r_ptr = self.get_ref(r) as *const _;
let l_ref = unsafe { &mut *l_ptr };
@@ -142,13 +139,13 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// Borrows a triplet of gradients `(&mut L1, &mut L2, &R)`.
pub(crate) fn muts_and_ref<L1: Shape, L2: Shape, R: Shape>(
&mut self,
-        l1: &GhostTensor<L1, E, D>,
-        l2: &GhostTensor<L2, E, D>,
-        r: &GhostTensor<R, E, D>,
+        l1: &impl Tensorlike<L1, E, D>,
+        l2: &impl Tensorlike<L2, E, D>,
+        r: &impl Tensorlike<R, E, D>,
) -> (&mut D::Vec<E>, &mut D::Vec<E>, &D::Vec<E>) {
-        assert_ne!(l1.id, l2.id);
-        assert_ne!(l1.id, r.id);
-        assert_ne!(l2.id, r.id);
+        assert_ne!(l1.id(), l2.id());
+        assert_ne!(l1.id(), r.id());
+        assert_ne!(l2.id(), r.id());
let l1_ptr = self.get_mut(l1) as *mut _;
let l2_ptr = self.get_mut(l2) as *mut _;
let r_ptr = self.get_ref(r) as *const _;
@@ -161,13 +158,13 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
#[inline]
pub(crate) fn many_and_ref<L: Shape, R: Shape>(
&mut self,
-        ls: &Vec<GhostTensor<L, E, D>>,
-        r: &GhostTensor<R, E, D>,
+        ls: &Vec<impl Tensorlike<L, E, D>>,
+        r: &impl Tensorlike<R, E, D>,
) -> (Vec<&mut D::Vec<E>>, &D::Vec<E>) {
for i in 0..ls.len() {
-            assert_ne!(ls[i].id, r.id);
+            assert_ne!(ls[i].id(), r.id());
for j in (i + 1)..ls.len() {
-                assert_ne!(ls[i].id, ls[j].id);
+                assert_ne!(ls[i].id(), ls[j].id());
}
}
let l_refs: Vec<&mut D::Vec<E>> = ls
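Every borrow-splitting helper above follows the same recipe: assert that the UniqueIds differ, then hand out one mutable and one shared reference through raw pointers. The asserts are what make the unsafe blocks sound, since distinct ids key distinct allocations in gradient_by_id. A standalone sketch of the pattern, reduced to a plain BTreeMap (hypothetical code, not from this PR):

use std::collections::BTreeMap;

// Borrow one entry mutably and another immutably from the same map.
fn mut_and_ref(map: &mut BTreeMap<u64, Vec<f32>>, l: u64, r: u64) -> (&mut Vec<f32>, &Vec<f32>) {
    assert_ne!(l, r); // distinct keys => distinct allocations
    let l_ptr = map.get_mut(&l).unwrap() as *mut Vec<f32>;
    let r_ptr = map.get(&r).unwrap() as *const Vec<f32>;
    // SAFETY: l != r, so the two pointers can never alias.
    (unsafe { &mut *l_ptr }, unsafe { &*r_ptr })
}

many_and_ref extends the same argument to n mutable borrows, which is why it checks every (i, j) pair of ids before building its Vec of references.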
2 changes: 2 additions & 0 deletions src/tensor/mod.rs
@@ -134,13 +134,15 @@ mod masks;
pub(crate) mod numpy;
#[cfg(feature = "safetensors")]
pub mod safetensors;
+mod tensorlike;
mod unique_id;

pub(crate) mod storage_traits;
mod tensor_impls;

pub(crate) use ghost::GhostTensor;
pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage};
+pub(crate) use tensorlike::Tensorlike;

pub use cpu::{Cpu, CpuError};
#[cfg(not(feature = "cuda"))]
65 changes: 65 additions & 0 deletions src/tensor/tensorlike.rs
@@ -0,0 +1,65 @@
use crate::{
prelude::{HasErr, HasShape},
shapes::{Shape, Unit},
tensor::DeviceStorage,
};

use super::{storage_traits::AllocGrad, GhostTensor, Tensor, UniqueId};

/// Contains everything that comprises a tensor, except possibly for the actual data. This really
/// exists to unify handling of [Tensor] and [GhostTensor].
///
/// *If it looks like a tensor and barks like a tensor, then pet it like a tensor.*
pub trait Tensorlike<S: Shape, E: Unit, D: DeviceStorage>:
AllocGrad<Gradient = D::Vec<E>> + HasErr<Err = D::Err> + HasShape<Shape = S>
{
fn id(&self) -> UniqueId;
fn len(&self) -> usize;
fn strides(&self) -> S::Concrete;
fn dev(&self) -> &D;
fn data(&self) -> Option<&D::Vec<E>>;
}

impl<S: Shape, E: Unit, D: DeviceStorage, T> Tensorlike<S, E, D> for Tensor<S, E, D, T> {
fn id(&self) -> UniqueId {
self.id
}

fn len(&self) -> usize {
self.device.len(&self.data)
}

fn strides(&self) -> S::Concrete {
self.strides
}

fn dev(&self) -> &D {
&self.device
}

fn data(&self) -> Option<&<D as DeviceStorage>::Vec<E>> {
Some(self.data.as_ref())
}
}

impl<S: Shape, E: Unit, D: DeviceStorage> Tensorlike<S, E, D> for GhostTensor<S, E, D> {
fn id(&self) -> UniqueId {
self.id
}

fn len(&self) -> usize {
self.len
}

fn strides(&self) -> S::Concrete {
self.strides
}

fn dev(&self) -> &D {
&self.dev
}

fn data(&self) -> Option<&<D as DeviceStorage>::Vec<E>> {
None
}
}
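With both impls in place, gradient bookkeeping can take impl Tensorlike and no longer care whether the caller kept the tensor alive; data() is the one method where the two types differ. A minimal sketch (`summarize` is a hypothetical helper, not part of this PR):

use crate::shapes::{Shape, Unit};
use crate::tensor::{DeviceStorage, Tensorlike};
use std::string::String;

// Metadata is always available; the buffer only exists for a real Tensor.
fn summarize<S: Shape, E: Unit, D: DeviceStorage>(t: &impl Tensorlike<S, E, D>) -> String {
    match t.data() {
        Some(_) => format!("live tensor, {} stored elements", t.len()),
        None => format!("ghost tensor, {} elements (no data)", t.len()),
    }
}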
7 changes: 4 additions & 3 deletions src/tensor_ops/conv2d/cpu_kernel.rs
@@ -1,5 +1,6 @@
+use crate::prelude::Tensorlike;
use crate::shapes::{Dtype, Shape};
-use crate::tensor::{cpu::*, GhostTensor, Tensor, ZerosTensor};
+use crate::tensor::{cpu::*, Tensor, ZerosTensor};
use crate::tensor_ops::matmul::cpu_kernel::MatMulImpl;

use super::{Conv2DKernel, Conv2DOp};
@@ -209,7 +210,7 @@ where
grad_lhs: &mut Self::Vec<E>,
rhs: &Tensor<R, E, Self>,
grad_rhs: &mut Self::Vec<E>,
-        out: &GhostTensor<O, E, Self>,
+        out: &impl Tensorlike<O, E, Self>,
grad_out: &Self::Vec<E>,
) -> Result<(), Self::Err> {
let f_tr_shape = op.filters_tr_shape();
@@ -232,7 +233,7 @@

let [lstride, ostride] = match L::NUM_DIMS {
3 => [0; 2],
-            4 => [lhs.strides[0], out.strides[0]],
+            4 => [lhs.strides[0], out.strides()[0]],
_ => unreachable!(),
};
let lhs = lhs.data.as_ref();
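The [lstride, ostride] match above is the stride-0 broadcast trick: for unbatched (3-D) inputs the batch stride is 0, so every batch index resolves to the same image, while the 4-D case reads the real batch stride from strides. A tiny standalone illustration (`batch_offsets` is hypothetical, not in the PR):

// The same indexing idea as the kernel's batch loop: with stride 0,
// every batch element starts at offset 0, i.e. one image is shared.
fn batch_offsets(batch_stride: usize, batch: usize) -> Vec<usize> {
    (0..batch).map(|b| b * batch_stride).collect()
}

// batch_offsets(0, 3) == [0, 0, 0]   (L::NUM_DIMS == 3: one shared image)
// batch_offsets(9, 3) == [0, 9, 18]  (L::NUM_DIMS == 4: real batch stride)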
4 changes: 2 additions & 2 deletions src/tensor_ops/conv2d/cuda_kernel.rs
@@ -3,7 +3,7 @@ use cudarc::driver::{DeviceRepr, LaunchAsync, ValidAsZeroBits};

use crate::{
shapes::*,
-    tensor::{launch_cfg, unique_id, Cuda, GhostTensor, Tensor},
+    tensor::{launch_cfg, unique_id, Cuda, Tensor, Tensorlike},
};

use std::sync::Arc;
@@ -112,7 +112,7 @@ where
grad_lhs: &mut Self::Vec<E>,
rhs: &Tensor<R, E, Self>,
grad_rhs: &mut Self::Vec<E>,
-        _: &GhostTensor<O, E, Self>,
+        _: &impl Tensorlike<O, E, Self>,
grad_out: &Self::Vec<E>,
) -> Result<(), Self::Err> {
let patches_item_numel = op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in;
8 changes: 4 additions & 4 deletions src/tensor_ops/conv2d/cudnn_kernel.rs
@@ -3,7 +3,7 @@ use cudarc::driver::DeviceSlice;

use crate::{
shapes::*,
-    tensor::{unique_id, Cuda, GhostTensor, Tensor},
+    tensor::{unique_id, Cuda, Tensor, Tensorlike},
};

use std::sync::Arc;
@@ -96,7 +96,7 @@ where
grad_lhs: &mut Self::Vec<E>,
rhs: &Tensor<R, E, Self>,
grad_rhs: &mut Self::Vec<E>,
-        out: &GhostTensor<O, E, Self>,
+        out: &impl Tensorlike<O, E, Self>,
grad_out: &Self::Vec<E>,
) -> Result<(), Self::Err> {
let conv = self.cudnn.create_conv2d::<E>(
@@ -114,8 +114,8 @@
make_4d::<R>(rhs.shape.concrete(), 1).map(|x| x as i32),
)?;
let out = self.cudnn.create_4d_tensor_ex::<E>(
-            make_4d::<O>(out.shape.concrete(), 1).map(|x| x as i32),
-            make_4d::<O>(out.strides, 0).map(|x| x as i32),
+            make_4d::<O>(out.shape().concrete(), 1).map(|x| x as i32),
+            make_4d::<O>(out.strides(), 0).map(|x| x as i32),
)?;

{
2 changes: 1 addition & 1 deletion src/tensor_ops/conv2d/mod.rs
@@ -73,7 +73,7 @@ pub(super) trait Conv2DKernel<E: Dtype>: DeviceStorage {
grad_lhs: &mut Self::Vec<E>,
rhs: &Tensor<R, E, Self>,
grad_rhs: &mut Self::Vec<E>,
-        out: &GhostTensor<O, E, Self>,
+        out: &impl Tensorlike<O, E, Self>,
grad_out: &Self::Vec<E>,
) -> Result<(), Self::Err>;
}
7 changes: 4 additions & 3 deletions src/tensor_ops/convtrans2d/cpu_kernel.rs
@@ -1,5 +1,6 @@
+use crate::prelude::Tensorlike;
use crate::shapes::{Dtype, Shape};
-use crate::tensor::{cpu::*, GhostTensor, Tensor};
+use crate::tensor::{cpu::*, Tensor};
use crate::tensor_ops::matmul::cpu_kernel::MatMulImpl;

use std::sync::Arc;
@@ -200,7 +201,7 @@ where
grad_lhs: &mut Self::Vec<E>,
rhs: &Tensor<R, E, Self>,
grad_rhs: &mut Self::Vec<E>,
-        out: &GhostTensor<O, E, Self>,
+        out: &impl Tensorlike<O, E, Self>,
grad_out: &Self::Vec<E>,
) -> Result<(), Self::Err> {
let f_tr_shape = op.filters_tr_shape();
@@ -223,7 +224,7 @@

let [lstride, ostride] = match L::NUM_DIMS {
3 => [0; 2],
-            4 => [lhs.strides[0], out.strides[0]],
+            4 => [lhs.strides[0], out.strides()[0]],
_ => unreachable!(),
};
let lhs = lhs.data.as_ref();
4 changes: 2 additions & 2 deletions src/tensor_ops/convtrans2d/cuda_kernel.rs
@@ -3,7 +3,7 @@ use cudarc::driver::{DeviceRepr, LaunchAsync, ValidAsZeroBits};

use crate::{
shapes::*,
-    tensor::{launch_cfg, Cuda, GhostTensor, Tensor},
+    tensor::{launch_cfg, Cuda, Tensor, Tensorlike},
};

use std::sync::Arc;
@@ -99,7 +99,7 @@ where
grad_lhs: &mut Self::Vec<E>,
rhs: &Tensor<R, E, Self>,
grad_rhs: &mut Self::Vec<E>,
-        _: &GhostTensor<O, E, Self>,
+        _: &impl Tensorlike<O, E, Self>,
grad_out: &Self::Vec<E>,
) -> Result<(), Self::Err> {
let patches_numel = op.batch * op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in;
2 changes: 1 addition & 1 deletion src/tensor_ops/convtrans2d/mod.rs
@@ -68,7 +68,7 @@ pub(super) trait ConvTrans2DKernel<E: Dtype>: DeviceStorage {
grad_lhs: &mut Self::Vec<E>,
rhs: &Tensor<R, E, Self>,
grad_rhs: &mut Self::Vec<E>,
-        out: &GhostTensor<O, E, Self>,
+        out: &impl Tensorlike<O, E, Self>,
grad_out: &Self::Vec<E>,
) -> Result<(), Self::Err>;
}
4 changes: 2 additions & 2 deletions src/tensor_ops/pow/cuda_kernel.rs
@@ -33,9 +33,9 @@ where
fn backward<S: Shape>(
&self,
op: super::PowiKernelOp,
-        inp: Result<&Tensor<S, E, Self>, &GhostTensor<S, E, Self>>,
+        inp: &impl Tensorlike<S, E, Self>,
         grad_inp: &mut Self::Vec<E>,
-        out: Result<&Tensor<S, E, Self>, &GhostTensor<S, E, Self>>,
+        out: &impl Tensorlike<S, E, Self>,
grad_out: &Self::Vec<E>,
) -> Result<(), Self::Err> {
self.backward(
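Note the shape of this change: the old backward took Result<&Tensor<S, E, Self>, &GhostTensor<S, E, Self>> as an ad-hoc either-type for "output kept" vs. "output dropped". Tensorlike collapses that into one bound, and data() recovers the distinction where a kernel actually needs the saved values. A hedged sketch (`has_saved_output` is hypothetical, not dfdx API):

use crate::shapes::{Shape, Unit};
use crate::tensor::{DeviceStorage, Tensorlike};

// Under the old API this was `out.is_ok()`; now it is a plain Option check.
fn has_saved_output<S: Shape, E: Unit, D: DeviceStorage>(out: &impl Tensorlike<S, E, D>) -> bool {
    out.data().is_some()
}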
10 changes: 5 additions & 5 deletions src/tensor_ops/sum_to/cpu_kernel.rs
@@ -1,6 +1,6 @@
use crate::{
shapes::{Axes, Dtype, HasAxes, ReduceShapeTo, Shape},
-    tensor::{Cpu, GhostTensor, Tensor, ZerosTensor},
+    tensor::{Cpu, Tensor, Tensorlike, ZerosTensor},
tensor_ops::utilities::reduction_utils::index_for_reductions,
};

@@ -39,7 +39,7 @@ impl<E: Dtype> super::SumKernel<E> for Cpu {
fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
_dst: Dst,
-        inp: &GhostTensor<Src, E, Self>,
+        inp: &impl Tensorlike<Src, E, Self>,
grad_inp: &mut Self::Vec<E>,
grad_out: &Self::Vec<E>,
) -> Result<(), Self::Err>
@@ -49,13 +49,13 @@ impl<E: Dtype> super::SumKernel<E> for Cpu {
if Dst::NUM_DIMS == 0 {
debug_assert_eq!(grad_out.len(), 1);
let v = grad_out[0];
-            let scale = E::from_usize(inp.shape.num_elements() / inp.len).unwrap();
+            let scale = E::from_usize(inp.shape().num_elements() / inp.len()).unwrap();
for i in grad_inp.iter_mut() {
*i += v * scale;
}
} else {
-            let num_elems_reduced = <Src as HasAxes<Ax>>::size(&inp.shape);
-            let mut idx = index_for_reductions::<Src, Ax>(inp.shape, inp.strides);
+            let num_elems_reduced = <Src as HasAxes<Ax>>::size(inp.shape());
+            let mut idx = index_for_reductions::<Src, Ax>(*inp.shape(), inp.strides());
for &o in grad_out.iter() {
for _ in 0..num_elems_reduced {
grad_inp[idx.next().unwrap()] += o;
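The Dst::NUM_DIMS == 0 branch above scales by num_elements / len to handle broadcast inputs: a broadcast tensor physically stores fewer elements than its logical shape implies, so each stored element backs several logical positions and must collect all of their gradients. Worked arithmetic (values are illustrative):

// A logical [2, 3] tensor broadcast from one row stores 3 physical values.
fn broadcast_scale(num_elements: usize, physical_len: usize) -> usize {
    num_elements / physical_len
}

fn main() {
    // 6 logical positions / 3 stored values => each value's gradient doubles.
    assert_eq!(broadcast_scale(6, 3), 2);
}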