+
Skip to content

Optimizing conv kernels a bit #605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 9 commits on Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ matrixmultiply = { version = "0.3.2", default-features = false, optional = true
zip = { version = "0.6.2", default-features = false, optional = true }
cblas-sys = { version = "0.1.4", default-features = false, optional = true }
libc = { version = "0.2", default-features = false, optional = true }
cudarc = { version = "0.9.0", default-features = false, optional = true, features = ["driver", "cublas"] }
cudarc = { version = "0.9.2", default-features = false, optional = true, features = ["driver", "cublas"] }
num-traits = { version = "0.2.15", default-features = false }
safetensors = { version = "0.3", default-features = false, optional = true }
memmap2 = { version = "0.5", default-features = false, optional = true }
Expand Down
28 changes: 27 additions & 1 deletion src/tensor/cuda/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ use cudarc::{
cublas::{result::CublasError, CudaBlas},
driver::{CudaDevice, CudaSlice, CudaStream, DeviceSlice, DriverError},
};
use std::{sync::Arc, vec::Vec};

use std::sync::MutexGuard;
use std::{
sync::{Arc, Mutex},
vec::Vec,
};

/// A Cuda device that enables constructing tensors on GPUs
/// & running GPU kernels.
Expand All @@ -17,6 +22,7 @@ pub struct Cuda {
pub(crate) blas: Arc<CudaBlas>,
/// A second stream for kernels to optionally execute on.
pub(crate) par_stream: Arc<CudaStream>,
pub(crate) workspace: Arc<Mutex<CudaSlice<u8>>>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -67,11 +73,13 @@ impl Cuda {
let dev = CudaDevice::new(ordinal)?;
let blas = Arc::new(CudaBlas::new(dev.clone())?);
let par_stream = Arc::new(dev.fork_default_stream()?);
let workspace = Arc::new(Mutex::new(dev.alloc_zeros::<u8>(0)?));
Ok(Self {
cpu,
dev,
blas,
par_stream,
workspace,
})
}

Expand All @@ -90,6 +98,24 @@ impl Cuda {
}
}

impl Cuda {
    /// Returns a lock guard over the device-wide scratch workspace, grown
    /// (never shrunk) so it holds at least `len` elements of type `E`.
    ///
    /// The workspace is shared across kernels via a `Mutex`, so holding the
    /// returned guard serializes other users of the workspace.
    ///
    /// # Safety
    /// When the buffer is re-allocated it is NOT zeroed — `CudaDevice::alloc`
    /// returns uninitialized device memory. The caller must fully overwrite
    /// (e.g. memset) the workspace before reading from it.
    pub(crate) unsafe fn get_workspace<E>(
        &self,
        len: usize,
    ) -> Result<MutexGuard<CudaSlice<u8>>, CudaError> {
        let num_bytes_required = len * std::mem::size_of::<E>();
        // `Arc` auto-derefs to the inner `Mutex`; `as_ref()` was redundant.
        let mut workspace = self.workspace.lock().unwrap();

        // Grow-only policy: re-allocate only when the current buffer is too
        // small; otherwise reuse the existing allocation as-is.
        if workspace.num_bytes() < num_bytes_required {
            // SAFETY: the new allocation is uninitialized; per this
            // function's contract, the caller overwrites it before use.
            *workspace = unsafe { self.dev.alloc::<u8>(num_bytes_required) }?;
        }

        Ok(workspace)
    }
}

impl std::fmt::Display for CudaError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{self:?}")
Expand Down
94 changes: 35 additions & 59 deletions src/tensor_ops/conv2d/conv2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ __device__ void unfold_input_into_patches(
T *patches // 6d (Batch, Channels, KernelSize, KernelSize, HeightOut, WidthOut)
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
const auto patches_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
if (i >= patches_numel) {
size_t item_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
if (i >= item_numel) {
return;
}

Expand All @@ -37,28 +37,18 @@ __device__ void unfold_input_into_patches(
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t b = idx % op.batch;
idx /= op.batch;

const size_t y_plus_p = oh * op.stride + k1;
if (y_plus_p < op.padding) {
return;
}
const size_t y = y_plus_p - op.padding;
if (y >= op.h_in) {
return;
}

const size_t x_plus_p = ow * op.stride + k2;
if (x_plus_p < op.padding) {
return;
}
const size_t x = x_plus_p - op.padding;
if (x >= op.w_in) {
return;
}

const size_t i_image = b * strides[0] + c * strides[1] + y * strides[2] + x * strides[3];
patches[i] = image[i_image];
if (y >= op.h_in || x >= op.w_in) {
patches[i] = 0.0;
} else {
const size_t i_image = b * strides[0] + c * strides[1] + y * strides[2] + x * strides[3];
patches[i] = image[i_image];
}
}

template<typename T>
Expand All @@ -67,9 +57,9 @@ __device__ void unfold_output_into_patches(
const T *image_out, // 4d (Batch, ChanOut, HeightOut, WidthOut)
T *patches // 6d (Batch, ChanOut, KernelSize, KernelSize, HeightIn, WidthIn)
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
const auto patches_numel = op.batch * op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in;
if (i >= patches_numel) {
const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
const size_t item_numel = op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in;
if (i >= item_numel) {
return;
}

Expand All @@ -83,46 +73,36 @@ __device__ void unfold_output_into_patches(
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t o = idx % op.chan_out;
idx /= op.chan_out;
const size_t b = idx % op.batch;
idx /= op.batch;

size_t oh = y + op.padding;
if (oh < k1) {
return;
}
oh -= k1;
if (oh % op.stride != 0) {
return;
}
oh /= op.stride;
if (oh >= op.h_out) {
return;
}

size_t ow = x + op.padding;
if (ow < k2) {
return;
}
ow -= k2;
if (ow % op.stride != 0) {
return;
}
ow /= op.stride;
if (ow >= op.w_out) {
const size_t oh_ks = y + op.padding;
const size_t oh_s = oh_ks - k1;
const size_t oh = oh_s / op.stride;
const size_t ow_ks = x + op.padding;
const size_t ow_s = ow_ks - k2;
const size_t ow = ow_s / op.stride;

if (
(oh_ks < k1 || oh_s % op.stride != 0 || oh >= op.h_out)
|| (ow_ks < k2 || ow_s % op.stride != 0 || ow >= op.w_out)
) {
for (auto b = 0; b < op.batch; b++) {
patches[b * item_numel + i] = 0.0;
}
return;
}

size_t image_i = b * (op.chan_out * op.h_out * op.w_out) + o * (op.h_out * op.w_out) + oh * (op.w_out) + ow;
patches[i] = image_out[image_i];
for (auto b = 0; b < op.batch; b++) {
size_t image_i = b * (op.chan_out * op.h_out * op.w_out) + o * (op.h_out * op.w_out) + oh * (op.w_out) + ow;
patches[b * item_numel + i] = image_out[image_i];
}
}

template<typename T>
__device__ void transpose_and_broadcast_filters(
__device__ void transpose_filters(
const Conv2DOp op,
const T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
const size_t *strides, // 4d filters strides
T *filters_tr // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize)
T *filters_tr // 4d (ChanIn, ChanOut, KernelSize, KernelSize)
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
auto numel = op.chan_in * op.chan_out * op.kernel * op.kernel;
Expand All @@ -138,15 +118,11 @@ __device__ void transpose_and_broadcast_filters(
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t o = idx % op.chan_out;
idx /= op.chan_out;

auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;
auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];

const T f = filters[i_no];
for (auto b = 0; b < op.batch; b++) {
filters_tr[b * numel + i_tr] = f;
}
filters_tr[i_tr] = filters[i_no];
}

template<typename T>
Expand Down Expand Up @@ -205,7 +181,7 @@ extern "C" __global__ void TR_FILTERS( \
const size_t *strides, \
TYPENAME *filters_tr \
) { \
transpose_and_broadcast_filters(op, filters, strides, filters_tr); \
transpose_filters(op, filters, strides, filters_tr); \
} \
extern "C" __global__ void SUM_TR_FILTERS( \
const Conv2DOp op, \
Expand All @@ -220,13 +196,13 @@ CONV_OP(
float,
unfold_input_into_patches_f32,
unfold_output_into_patches_f32,
transpose_and_broadcast_filters_f32,
transpose_filters_f32,
sum_transposed_filters_f32
);
CONV_OP(
double,
unfold_input_into_patches_f64,
unfold_output_into_patches_f64,
transpose_and_broadcast_filters_f64,
transpose_filters_f64,
sum_transposed_filters_f64
);
6 changes: 5 additions & 1 deletion src/tensor_ops/conv2d/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::shapes::{Dtype, Shape};
use crate::tensor::{cpu::*, Tensor};
use crate::tensor::{cpu::*, Tensor, ZerosTensor};
use crate::tensor_ops::matmul::cpu_kernel::MatMulImpl;

use super::{Conv2DKernel, Conv2DOp};
Expand Down Expand Up @@ -163,6 +163,10 @@ impl<E: Dtype> Conv2DKernel<E> for Cpu
where
Self: MatMulImpl<E>,
{
/// Allocates an output tensor of shape `s`, zero-initialized via
/// `try_zeros_like` so conv kernels can accumulate into it directly.
fn alloc<S: Shape>(&self, s: S) -> Result<Tensor<S, E, Self>, Self::Err> {
    self.try_zeros_like(&s)
}

fn forward<L: Shape, R: Shape, O: Shape>(
&self,
op: Conv2DOp,
Expand Down
Loading