+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/shapes/replace_dim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub trait RemoveDimTo<Dst: Shape, Idx: Shape>: Shape {
let src_dims = self.concrete();
let idx_dims = idx.concrete();
for i in 0..Idx::NUM_DIMS {
assert_eq!(src_dims[i], idx_dims[i]);
assert_eq!(src_dims[i], idx_dims[i], "dimension {i} not the same");
}
}

Expand Down Expand Up @@ -49,7 +49,7 @@ pub trait ReplaceDimTo<Dst: Shape, Idx: Shape>: Shape {
let src_dims = self.concrete();
let idx_dims = idx.concrete();
for i in 0..Idx::NUM_DIMS - 1 {
assert_eq!(src_dims[i], idx_dims[i]);
assert_eq!(src_dims[i], idx_dims[i], "dimension {i} not the same");
}
} else {
// batch replace case - we actually don't need to check this case
Expand Down
1 change: 1 addition & 0 deletions src/tensor_ops/select_and_gather/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ __device__ unsigned int get_gathered_index(
// indices for dimensions before, at, and after the indexed dimension
unsigned int idx_before = index / (elem_size * row_len);
unsigned int idx_mid = idx[idx_idx];
assert(idx_mid < inp_dims[ax]);
unsigned int idx_after = index % elem_size;

// recombine
Expand Down
105 changes: 75 additions & 30 deletions src/tensor_ops/select_and_gather/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,33 +187,34 @@ mod tests {
use crate::tests::*;

#[test]
#[should_panic]
fn test_remove_wrong_index_shape_2d() {
#[should_panic = "dimension 0 not the same"]
fn test_select_wrong_index_shape_2d() {
let dev: TestDevice = Default::default();
let t: Tensor<_, TestDtype, _> = dev.sample_like(&(5, 3), rand_distr::StandardNormal);
// here we are selecting from axis 1, so the 7 should actually be a 5
let _ = t.trace().select(dev.zeros_like(&(7,)));
}

#[test]
#[should_panic]
fn test_remove_wrong_index_shape_3d() {
#[should_panic = "dimension 1 not the same"]
fn test_select_wrong_index_shape_3d() {
let dev: TestDevice = Default::default();
let t: Tensor<_, TestDtype, _> = dev.sample_like(&(7, 5, 3), rand_distr::StandardNormal);
let _ = t.trace().select(dev.zeros_like(&(7, 4)));
}

#[cfg(not(feature = "test-cuda"))]
#[test]
#[should_panic]
fn test_remove_index_out_of_bounds() {
#[should_panic = "Index out of bounds: index=[7]"]
fn test_select_index_out_of_bounds() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank1<5>, TestDtype, _> = dev.sample_normal();
let _ = t.trace().select(dev.tensor(7));
}

#[test]
#[should_panic]
fn test_replace_wrong_index_shape_3d1() {
#[should_panic = "dimension 0 not the same"]
fn test_gather_wrong_index_shape_3d1() {
let dev: TestDevice = Default::default();
let t: Tensor<_, TestDtype, _> = dev.sample_like(&(5, 3, 1), rand_distr::StandardNormal);
let r = t.trace().gather(dev.zeros_like(&(7,)));
Expand All @@ -222,23 +223,33 @@ mod tests {
}

#[test]
#[should_panic]
fn test_replace_wrong_index_shape_3d2() {
#[should_panic = "dimension 1 not the same"]
fn test_gather_wrong_index_shape_3d2() {
let dev: TestDevice = Default::default();
let t: Tensor<_, TestDtype, _> = dev.sample_like(&(5, 3, 1), rand_distr::StandardNormal);
let _ = t.trace().gather(dev.zeros_like(&(5, 4, 2)));
}

#[cfg(not(feature = "test-cuda"))]
#[test]
#[should_panic]
fn test_replace_index_out_of_bounds() {
#[should_panic = "Index out of bounds: index=[7]"]
fn test_gather_index_out_of_bounds() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank1<5>, TestDtype, _> = dev.sample_normal();
let _ = t.trace().gather(dev.tensor([7, 6, 1, 2]));
}

#[cfg(not(feature = "test-cuda"))]
#[test]
fn test_remove_1d_backward() {
#[should_panic = "Index out of bounds: index=[5, 0]"]
fn test_gather_batch_out_of_bounds() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank2<4, 5>, TestDtype, _> = dev.sample_normal();
let _ = t.trace().try_gather(dev.tensor([[5, 0, 0], [0, 0, 0]]));
}

#[test]
fn test_select_1d_backward() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank1<5>, TestDtype, _> = dev.sample_normal();
let r = t.trace().select(dev.tensor(0));
Expand All @@ -249,7 +260,7 @@ mod tests {
}

#[test]
fn test_replace_1d_backward() {
fn test_gather_1d_backward() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank1<5>, TestDtype, _> = dev.sample_normal();
let r = t.trace().gather(dev.tensor([0, 1, 1, 3]));
Expand All @@ -269,7 +280,7 @@ mod tests {
}

#[test]
fn test_replace_1d_less_backward() {
fn test_gather_1d_less_backward() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank1<5>, TestDtype, _> = dev.sample_normal();
let t_array = t.array();
Expand All @@ -280,17 +291,7 @@ mod tests {
}

#[test]
fn test_select_last_2d() {
let dev: TestDevice = Default::default();
let t: Tensor<_, TestDtype, _> = dev.tensor([[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]]);
let r = t.trace().select(dev.tensor([1, 1]));
assert_eq!(r.array(), [2.0, -2.0]);
let g = r.mean().backward();
assert_eq!(g.get(&t).array(), [[0.0, 0.5, 0.0], [0.0, 0.5, 0.0]]);
}

#[test]
fn test_replace_1d_more_backward() {
fn test_gather_1d_more_backward() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank1<5>, TestDtype, _> = dev.sample_normal();
let _t = t.array();
Expand All @@ -307,7 +308,51 @@ mod tests {
}

#[test]
fn test_remove_3d_axis_0_backward() {
fn test_select_2d_axis_0() {
let dev: TestDevice = Default::default();
let t: Tensor<_, TestDtype, _> = dev.tensor([[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]]);
let r = t.trace().select(dev.tensor(0));
assert_eq!(r.array(), [1.0, 2.0, 3.0]);
let g = r.mean().backward();
assert_eq!(g.get(&t).array(), [[1.0 / 3.0; 3], [0.0; 3]]);
}

#[test]
fn test_select_2d_axis_1() {
let dev: TestDevice = Default::default();
let t: Tensor<_, TestDtype, _> = dev.tensor([[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]]);
let r = t.trace().select(dev.tensor([1, 1]));
assert_eq!(r.array(), [2.0, -2.0]);
let g = r.mean().backward();
assert_eq!(g.get(&t).array(), [[0.0, 0.5, 0.0], [0.0, 0.5, 0.0]]);
}

#[test]
fn test_select_2d_broadcasted() {
let dev: TestDevice = Default::default();
let t: Tensor<_, TestDtype, _> = dev.tensor([1.0, 2.0, 3.0]);
let r = t
.trace()
.broadcast::<Rank2<2, 3>, _>()
.select(dev.tensor([0, 1]));
assert_eq!(r.array(), [1.0, 2.0]);
let g = r.mean().backward();
assert_eq!(g.get(&t).array(), [0.5, 0.5, 0.0]);
}

#[test]
fn test_gather_2d_broadcasted() {
let dev: TestDevice = Default::default();
let t: Tensor<_, TestDtype, _> = dev.tensor([1.0, 2.0, 3.0]);
let idx: Tensor<Rank2<2, 2>, usize, _> = dev.tensor([[0, 1], [1, 2]]);
let r: Tensor<Rank2<2, 2>, _, _, _> = t.trace().broadcast::<Rank2<2, 3>, _>().gather(idx);
assert_eq!(r.array(), [[1.0, 2.0], [2.0, 3.0]]);
let g = r.mean().backward();
assert_eq!(g.get(&t).array(), [0.25, 0.5, 0.25]);
}

#[test]
fn test_select_3d_axis_0_backward() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
let t_array = t.array();
Expand All @@ -319,7 +364,7 @@ mod tests {
}

#[test]
fn test_remove_3d_axis_1_backward() {
fn test_select_3d_axis_1_backward() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
let t_array = t.array();
Expand All @@ -339,7 +384,7 @@ mod tests {
}

#[test]
fn test_remove_3d_axis_2_backward() {
fn test_select_3d_axis_2_backward() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
let t_array = t.array();
Expand Down Expand Up @@ -370,7 +415,7 @@ mod tests {
}

#[test]
fn test_select_batch_backwards() {
fn test_gather_batch_backwards() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank2<4, 5>, TestDtype, _> = dev.sample_normal();
let t_array = t.array();
Expand Down
1 change: 1 addition & 0 deletions src/tensor_ops/select_and_gather/select.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ __device__ unsigned int get_selected_index(
// indices for dimensions before, at, and after the indexed dimension
unsigned int idx_before = index / elem_size;
unsigned int idx_mid = idx[get_strided_index(idx_before, idx_num_dims, idx_dims, idx_strides)];
assert(idx_mid < inp_dims[idx_num_dims]);
unsigned int idx_after = index % elem_size;

// recombine
Expand Down
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载