+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 42 additions & 49 deletions src/tensor_ops/conv2d/conv2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,31 @@ __device__ void unfold_input_into_patches(
T *patches // 6d (Batch, Channels, KernelSize, KernelSize, HeightOut, WidthOut)
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
size_t item_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
if (i >= item_numel) {
if (i >= op.batch * op.chan_in * op.h_out * op.w_out) {
return;
}

// patches shape is (B, C, K, K, h_out, w_out)
unsigned int idx = i;
const size_t ow = idx % op.w_out;
idx /= op.w_out;
const size_t oh = idx % op.h_out;
idx /= op.h_out;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t b = idx % op.batch;

const size_t y_plus_p = oh * op.stride + k1;
const size_t y = y_plus_p - op.padding;
const size_t x_plus_p = ow * op.stride + k2;
const size_t x = x_plus_p - op.padding;

if (y >= op.h_in || x >= op.w_in) {
patches[i] = 0.0;
} else {
const size_t i_image = b * strides[0] + c * strides[1] + y * strides[2] + x * strides[3];
patches[i] = image[i_image];
image += b * strides[0] + c * strides[1];
patches += oh * op.w_out + ow;
patches += c * (op.kernel * op.kernel * op.h_out * op.w_out);
patches += b * (op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out);

for (int k1 = 0;k1 < op.kernel;k1++) {
const size_t y = oh * op.stride + k1 - op.padding;
for (int k2 = 0;k2 < op.kernel;k2++) {
const size_t x = ow * op.stride + k2 - op.padding;
*patches = (y >= op.h_in || x >= op.w_in) ? 0.0 : image[y * strides[2] + x * strides[3]];
patches += op.h_out * op.w_out;
}
}
}

Expand All @@ -58,8 +54,7 @@ __device__ void unfold_output_into_patches(
T *patches // 6d (Batch, ChanOut, KernelSize, KernelSize, HeightIn, WidthIn)
) {
const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
const size_t item_numel = op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in;
if (i >= item_numel) {
if (i >= op.batch * op.chan_out * op.h_in * op.w_in) {
return;
}

Expand All @@ -68,32 +63,29 @@ __device__ void unfold_output_into_patches(
idx /= op.w_in;
const size_t y = idx % op.h_in;
idx /= op.h_in;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t o = idx % op.chan_out;
idx /= op.chan_out;
const size_t b = idx % op.batch;

const size_t oh_ks = y + op.padding;
const size_t oh_s = oh_ks - k1;
const size_t oh = oh_s / op.stride;
const size_t ow_ks = x + op.padding;
const size_t ow_s = ow_ks - k2;
const size_t ow = ow_s / op.stride;

if (
(oh_ks < k1 || oh_s % op.stride != 0 || oh >= op.h_out)
|| (ow_ks < k2 || ow_s % op.stride != 0 || ow >= op.w_out)
) {
for (auto b = 0; b < op.batch; b++) {
patches[b * item_numel + i] = 0.0;
image_out += b * (op.chan_out * op.h_out * op.w_out) + o * (op.h_out * op.w_out);
patches += y * op.w_in + x;
patches += o * (op.kernel * op.kernel * op.h_in * op.w_in);
patches += b * (op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in);

for (int k1 = 0;k1 < op.kernel;k1++) {
const size_t oh_ks = y + op.padding;
const size_t oh_s = oh_ks - k1;
const size_t oh = oh_s / op.stride;
const bool k1_invalid = (oh_ks < k1 || oh_s % op.stride != 0 || oh >= op.h_out);
for (int k2 = 0;k2 < op.kernel;k2++) {
const size_t ow_ks = x + op.padding;
const size_t ow_s = ow_ks - k2;
const size_t ow = ow_s / op.stride;

const bool invalid = k1_invalid || (ow_ks < k2 || ow_s % op.stride != 0 || ow >= op.w_out);
*patches = invalid ? 0.0 : image_out[oh * op.w_out + ow];
patches += op.h_in * op.w_in;
}
return;
}

for (auto b = 0; b < op.batch; b++) {
size_t image_i = b * (op.chan_out * op.h_out * op.w_out) + o * (op.h_out * op.w_out) + oh * (op.w_out) + ow;
patches[b * item_numel + i] = image_out[image_i];
}
}

Expand All @@ -105,8 +97,7 @@ __device__ void transpose_filters(
T *filters_tr // 4d (ChanIn, ChanOut, KernelSize, KernelSize)
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
auto numel = op.chan_in * op.chan_out * op.kernel * op.kernel;
if (i >= numel) {
if (i >= op.chan_in * op.chan_out * op.kernel * op.kernel) {
return;
}

Expand All @@ -115,14 +106,13 @@ __device__ void transpose_filters(
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t o = idx % op.chan_out;
idx /= op.chan_out;
const size_t c = idx % op.chan_in;

auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;
auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];

filters_tr[i_tr] = filters[i_no];
filters_tr[i] = filters[i_no];
}

template<typename T>
Expand Down Expand Up @@ -151,9 +141,12 @@ __device__ void sum_transposed_filters(
auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;
auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];

const T *ptr = filters_tr + i_tr;

T tmp = 0.0;
for (auto b = 0; b < op.batch; b++) {
tmp += filters_tr[b * numel + i_tr];
tmp += *ptr;
ptr += numel;
}

filters[i_no] += tmp;
Expand Down
4 changes: 2 additions & 2 deletions src/tensor_ops/conv2d/cuda_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ where

let img_strides = self.dev.htod_copy(make_4d::<L>(lhs.strides).into())?;
let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
let cfg = launch_cfg(patches_numel as u32);
let cfg = launch_cfg((op.batch * op.chan_in * op.h_out * op.w_out) as u32);
let params = (op, lhs.data.as_ref(), &img_strides, &mut patches);
unsafe { unfold_fn.launch(cfg, params) }?;

Expand Down Expand Up @@ -131,7 +131,7 @@ where
{
// unfold grad_out into patches
let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[1]).unwrap();
let cfg = launch_cfg(patches_item_numel as u32);
let cfg = launch_cfg((op.batch * op.chan_out * op.h_in * op.w_in) as u32);
unsafe { unfold_fn.launch(cfg, (op, grad_out, &mut patches)) }?;
}

Expand Down
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载