+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tensor/cpu/allocate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl<E: Unit> TensorFromVec<E> for Cpu {
}
}

impl<S: Shape, E: Unit> AsVec for StridedArray<S, E> {
impl<S: Shape, E: Unit> AsVec<E> for StridedArray<S, E> {
fn as_vec(&self) -> Vec<E> {
let mut out = Vec::with_capacity(self.shape.num_elements());
let mut iter = self.iter();
Expand Down
23 changes: 20 additions & 3 deletions src/tensor/cuda/allocate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,38 @@ where

impl<E: Unit> CopySlice<E> for Cuda {
    /// Copies every element of `src` into the device buffer that backs `dst`.
    ///
    /// The copy targets the *physical* buffer (`dst.storage.data`), not the
    /// logical shape, so `src` must match that buffer's length exactly.
    ///
    /// # Panics
    /// - if `src.len()` differs from the physical buffer length;
    /// - if the device copy reports an error (it is `unwrap`ped).
    fn copy_from<S: Shape, T>(dst: &mut Tensor<S, E, Self, T>, src: &[E]) {
        let physical_len = dst.storage.data.len();
        assert_eq!(
            physical_len,
            src.len(),
            "Slices must have same number of elements as *physical* storage of tensors."
        );
        // `Arc::make_mut` clones the buffer first when the `Arc` is shared,
        // so writing into it cannot alter data held by other tensors.
        let device_buf = Arc::make_mut(&mut dst.storage.data);
        dst.device.dev.sync_copy_into(src, device_buf).unwrap();
    }

    /// Copies the device buffer that backs `src` into the host slice `dst`.
    ///
    /// Like [`copy_from`](Self::copy_from), this works on the *physical*
    /// buffer, so `dst` must match that buffer's length exactly.
    ///
    /// # Panics
    /// - if `dst.len()` differs from the physical buffer length;
    /// - if the device copy reports an error (it is `unwrap`ped).
    fn copy_into<S: Shape, T>(src: &Tensor<S, E, Self, T>, dst: &mut [E]) {
        let physical_len = src.storage.data.len();
        assert_eq!(
            physical_len,
            dst.len(),
            "Slices must have same number of elements as *physical* storage of tensors."
        );
        src.device
            .dev
            .sync_copy_from(src.storage.data.as_ref(), dst)
            .unwrap();
    }
}

impl<S: Shape, E: Unit> AsVec for CudaArray<S, E> {
impl<S: Shape, E: Unit> AsVec<E> for CudaArray<S, E> {
fn as_vec(&self) -> Vec<E> {
self.data.clone_async().unwrap().try_into().unwrap()
let buf = self.data.clone_async().unwrap().try_into().unwrap();
let a = StridedArray {
data: Arc::new(buf),
shape: self.shape,
strides: self.strides,
};
a.as_vec()
}
}

Expand Down Expand Up @@ -157,8 +173,9 @@ where
{
type Array = <StridedArray<S, E> as AsArray>::Array;
fn array(&self) -> Self::Array {
let buf = self.data.clone_async().unwrap().try_into().unwrap();
let a = StridedArray {
data: Arc::new(self.as_vec()),
data: Arc::new(buf),
shape: self.shape,
strides: self.strides,
};
Expand Down
22 changes: 10 additions & 12 deletions src/tensor/storage_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ pub trait HasErr: Sized {
type Err: std::fmt::Debug + std::fmt::Display;
}

/// Produces an owned [std::vec::Vec] holding elements of type `E`
/// (used to convert tensors into vectors).
pub trait AsVec<E> {
    /// Builds and returns a new vector of `E` from `self`.
    fn as_vec(&self) -> std::vec::Vec<E>;
}

/// Something that can store nd arrays for a given [Shape] and [Dtype]
pub trait DeviceStorage: 'static + Default + Clone + HasErr {
/// Generic storage type
Expand All @@ -19,7 +24,8 @@ pub trait DeviceStorage: 'static + Default + Clone + HasErr {
+ Clone
+ Send
+ Sync
+ HasShape<Shape = S>;
+ HasShape<Shape = S>
+ AsVec<E>;

/// Generates a random u64 number
fn random_u64(&self) -> u64;
Expand Down Expand Up @@ -61,7 +67,7 @@ pub trait CopySlice<E: Unit>: DeviceStorage {
}

impl<S: Shape, E: Unit, D: CopySlice<E>, T> Tensor<S, E, D, T> {
/// Copy data from a slice - **panics** if there are not enough elements in the slice.
/// Copy *physical* data from a slice - **panics** if there are not enough elements in the slice.
///
/// ```rust
/// # use dfdx::prelude::*;
Expand All @@ -75,7 +81,7 @@ impl<S: Shape, E: Unit, D: CopySlice<E>, T> Tensor<S, E, D, T> {
D::copy_from(self, src);
}

/// Copy data into a slice - **panics** if there are not enough elements in the tensor.
/// Copy *physical* data into a slice - **panics** if there are not enough elements in the tensor.
///
/// ```rust
/// # use dfdx::prelude::*;
Expand Down Expand Up @@ -269,15 +275,7 @@ where
}
}

/// Convert tensors to [std::vec::Vec]
pub trait AsVec: HasUnitType {
fn as_vec(&self) -> std::vec::Vec<Self::Unit>;
}

impl<S: Shape, E: Unit, D: DeviceStorage, T> AsVec for Tensor<S, E, D, T>
where
D::Storage<S, E>: HasUnitType<Unit = E> + AsVec,
{
impl<S: Shape, E: Unit, D: DeviceStorage, T> AsVec<E> for Tensor<S, E, D, T> {
fn as_vec(&self) -> std::vec::Vec<E> {
self.storage.as_vec()
}
Expand Down
23 changes: 8 additions & 15 deletions src/tensor/tensor_impls.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use rand::distributions::Distribution;

use super::storage_traits::{CopySlice, DeviceStorage, HasErr, TensorFromVec};
use super::storage_traits::{AsVec, DeviceStorage, HasErr, TensorFromVec};
use super::{Cpu, OneFillStorage, SampleTensor, ZeroFillStorage};
use crate::prelude::TensorFrom;
use crate::{
gradients::{NoneTape, OwnedTape, Tape},
shapes::*,
Expand Down Expand Up @@ -66,7 +65,7 @@ impl<S: Shape, E: Unit, D: DeviceStorage, T> HasErr for Tensor<S, E, D, T> {
type Err = D::Err;
}

impl<S: Shape, E: Dtype, D: DeviceStorage> Tensor<S, E, D, NoneTape> {
impl<S: Shape, E: Unit, D: DeviceStorage> Tensor<S, E, D, NoneTape> {
/// Clone and put a [OwnedTape] into the tensor
pub fn trace(&self) -> Tensor<S, E, D, OwnedTape<D>> {
self.clone().traced()
Expand All @@ -77,7 +76,7 @@ impl<S: Shape, E: Dtype, D: DeviceStorage> Tensor<S, E, D, NoneTape> {
}
}

impl<S: Shape, E: Dtype, D: DeviceStorage, T: Tape<D>> Tensor<S, E, D, T> {
impl<S: Shape, E: Unit, D: DeviceStorage, T> Tensor<S, E, D, T> {
/// Clone and insert a new tape of type `New` into the tensor
pub fn retaped<New: Tape<D>>(&self) -> Tensor<S, E, D, New> {
Tensor {
Expand All @@ -101,7 +100,7 @@ pub trait PutTape<T> {
fn put_tape(self, tape: T) -> Self::Output;
}

impl<S: Shape, E: Dtype, D: DeviceStorage, T> PutTape<T> for Tensor<S, E, D> {
impl<S: Shape, E: Unit, D: DeviceStorage, T> PutTape<T> for Tensor<S, E, D> {
type Output = Tensor<S, E, D, T>;
fn put_tape(self, tape: T) -> Self::Output {
Tensor {
Expand Down Expand Up @@ -235,20 +234,14 @@ pub type OnCuda<M> = OnDevice<M, crate::prelude::Cuda>;
/// Equivalent to `OnDevice<M, Cpu>`
pub type OnCpu<M> = OnDevice<M, Cpu>;

impl<
S: Shape,
E: Dtype + Unit,
T,
D1: DeviceStorage + CopySlice<E>,
D2: DeviceStorage + TensorFromVec<E>,
> ToDevice<D2> for Tensor<S, E, D1, T>
impl<S: Shape, E: Dtype + Unit, T, D1: DeviceStorage, D2: TensorFromVec<E>> ToDevice<D2>
for Tensor<S, E, D1, T>
{
type Output = Tensor<S, E, D2, NoneTape>;

fn to_device(&self, device: &D2) -> Self::Output {
let mut buf = std::vec![E::default(); self.shape().num_elements()];
self.copy_into(&mut buf);
device.tensor((buf, *self.shape()))
let buf = self.as_vec();
device.tensor_from_vec(buf, *self.shape())
}
}

Expand Down
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载