Add TensorFrom trait to create tensors from both vectors and arrays. #414
Conversation
Would it be possible to have stack-allocated tensors if they're small enough?
Currently, no, because StridedArrays store tensors as an
Looks great! Want to also update the examples/06-mnist to use this new trait?
Looking great! Just need to settle on a good name for the non-const-shape method.
examples/06-mnist.rs (outdated)
-    let mut lbl = dev.zeros();
-    lbl.copy_from(&lbl_data);
-    (img, lbl)
+    (dev.tensor_from_vec(img_data), dev.tensor_from_vec(lbl_data))
🎉
src/tensor/storage_traits.rs (outdated)
/// # let dev: Cpu = Default::default();
/// let _ = dev.dynamic_tensor_from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);
/// ```
fn dynamic_tensor_from_vec<S: Shape>(&self, src: Vec<E>, shape: S) -> Tensor<S, E, Self> {
Thoughts on these names?
- tensor_from_vec_with_shape
- tensor_from_shaped_vec
- tensor_from_vec_like (similar to zeros_like/ones_like/etc., however I do think that is less clear in this case)

I think the with_shape one is the most clear about how it's different from the other calls of this trait, and how you call it. Usage of dynamic doesn't necessarily imply you have to give a shape, but I do like that it implies it should be used with runtime shapes.
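For a sense of how these would read at the call site, a rough sketch mirroring the doc example above (the names and signatures here are assumptions about where this lands, not the final API):

// Const-shape path: the shape comes from the tensor's type annotation.
let a: Tensor<Rank2<2, 3>, ...> = dev.tensor_from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
// Runtime-shape path: the shape is passed explicitly next to the data.
let b = dev.tensor_from_vec_with_shape(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);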
Yeah, I've gone with tensor_from_vec_with_shape, but I'm not super happy that the names are as long as they are. It's kind of a shame that Rust doesn't seem to have the features necessary for us to merge the TensorFromVec and TensorFromArray traits while still allowing users to use these traits with a generic Device<E>.
Hmm, could we do something like this?

trait TryTensor<Src, S: Shape, E: Unit>: DeviceStorage {
    fn tensor(&self, src: Src) -> Tensor<S, E, Self>;
}
impl<E, S: ConstShape> TryTensor<Vec<E>, S, E> for ... {}
impl<E, S: Shape> TryTensor<(Vec<E>, S), S, E> for ... {}
impl<E, S: ConstShape, Src: RustArray<Shape=S>> TryTensor<Src, S, E> for ... {}
which would look like:
let a: Tensor<Rank1<3>, ...> = dev.tensor([1.0, 2.0, 3.0]);
let a: Tensor<Rank1<3>, ...> = dev.tensor(vec![1.0, 2.0, 3.0]);
// note: tuple
let a: Tensor<(usize,), ...> = dev.tensor((vec![1.0, 2.0, 3.0], (3,)));
That's a very interesting idea actually, but I don't know that it's possible to have all of these implemented for a generic Device, so I'll have to experiment.
Per your previous comment, I've greatly simplified the API, and it is now possible to create tensors from arrays and vectors, even with a generic Device, with only the TensorFrom trait.
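For illustration, a couple of hypothetical helpers showing what that buys with a generic device (the TensorFrom type parameters and bounds here are assumptions for the sketch, not copied from this PR):

// Build the same tensor from an array or from a Vec without naming a concrete device.
fn from_array<D: Device<f32> + TensorFrom<[f32; 3], Rank1<3>, f32>>(dev: &D) -> Tensor<Rank1<3>, f32, D> {
    dev.tensor([1.0, 2.0, 3.0])
}
fn from_vec<D: Device<f32> + TensorFrom<Vec<f32>, Rank1<3>, f32>>(dev: &D) -> Tensor<Rank1<3>, f32, D> {
    dev.tensor(vec![1.0, 2.0, 3.0])
}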
Nice, this looks great! Doc example is very clear 🚀
Draft of a solution for #399, which splits TensorFromArray into ConstShapeTensorFrom, DynamicTensorFrom, and the public-facing TensorFrom. Currently, this doesn't work how I'd like it to, as it cannot be used when the device is a generic Device, because the traits need to be parameterized by a Shape, which means they can't be added to the definition of Device.
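For context, the blocker is a general Rust limitation rather than anything specific to this draft: a supertrait bound has to be a single, fully-instantiated bound, so a trait that takes a Shape parameter can't be required "for every shape" in Device's definition. A minimal, self-contained illustration with toy stand-in types (not dfdx's actual definitions):

// Toy stand-ins, only to show the supertrait problem.
trait Shape {}
struct Rank1<const M: usize>;
impl<const M: usize> Shape for Rank1<M> {}

trait TensorFrom<Src, S: Shape> {
    fn tensor(&self, src: Src);
}

// What we'd like is "every Device can build a tensor from any const-shaped source",
// but a supertrait must name concrete parameters, so there is no single bound to write here:
// trait Device: TensorFrom<???, ???> {}
//
// Callers of a generic device are left adding the shape-specific bound themselves:
fn make<D: TensorFrom<[f32; 3], Rank1<3>>>(dev: &D) {
    dev.tensor([1.0, 2.0, 3.0]);
}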