# Implementing abs/exp/div/sum_to cuda kernels #331
Changes from all commits
@@ -0,0 +1,29 @@

```cpp
struct AbsKernelOp {};

extern "C" __global__ void abs_forward(
    const AbsKernelOp op,
    const size_t numel,
    const float *inp,
    float *out
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= numel) {
        return;
    }
    // fabsf avoids accidentally selecting an integer abs overload.
    out[i] = fabsf(inp[i]);
}

extern "C" __global__ void abs_backward(
    const AbsKernelOp op,
    const size_t numel,
    const float *inp,
    float *grad_inp,
    const float *grad_out
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= numel) {
        return;
    }
    // d|x|/dx is -1 for negative x, +1 for positive x, defined here as 0 at x == 0.
    float dx = inp[i] == 0.0 ? 0.0 : (signbit(inp[i]) ? -1.0 : 1.0);
    grad_inp[i] += dx * grad_out[i];
}
```
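As a sanity check on `abs_backward`, here is the same rule as a plain-Rust reference (a hypothetical helper for testing, not part of this PR):

```rust
/// Hypothetical CPU reference for abs_backward: d|x|/dx is -1 for x < 0,
/// +1 for x > 0, and defined as 0 at x == 0, accumulating like the kernel's +=.
fn abs_backward_ref(inp: &[f32], grad_out: &[f32], grad_inp: &mut [f32]) {
    for ((gi, &x), &go) in grad_inp.iter_mut().zip(inp).zip(grad_out) {
        let dx = if x == 0.0 {
            0.0
        } else if x.is_sign_negative() {
            -1.0
        } else {
            1.0
        };
        *gi += dx * go;
    }
}
```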
```diff
@@ -1,20 +1,10 @@
-use crate::{shapes::Shape, tensor::Cuda, tensor_ops::ops::UnaryKernel};
-
-impl UnaryKernel<super::AbsKernelOp, f32> for Cuda {
-    fn forward<S: Shape>(
-        &self,
-        op: super::AbsKernelOp,
-        inp: &Self::Storage<S, f32>,
-    ) -> Result<Self::Storage<S, f32>, Self::Err> {
-        todo!()
-    }
-    fn backward<S: Shape>(
-        &self,
-        op: super::AbsKernelOp,
-        inp: &Self::Storage<S, f32>,
-        grad_inp: &mut Self::Storage<S, f32>,
-        grad_out: &Self::Storage<S, f32>,
-    ) -> Result<(), Self::Err> {
-        todo!()
-    }
-}
+use crate::tensor_ops::cuda_kernels::UnaryOpCudaKernel;
+
+unsafe impl cudarc::device::AsKernelParam for super::AbsKernelOp {}
+
+impl UnaryOpCudaKernel for super::AbsKernelOp {
+    const PTX_SRC: &'static str = include_str!(concat!(env!("OUT_DIR"), "/abs.ptx"));
+    const MODULE_NAME: &'static str = "abs";
+    const FWD_FN_NAME: &'static str = "abs_forward";
+    const BWD_FN_NAME: &'static str = "abs_backward";
+}
```
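The other unary ops in this PR (`exp`, for one) presumably plug in the same way; a sketch of what that could look like, assuming an `ExpKernelOp` struct and an `exp.ptx` emitted by build.rs (neither shown in this diff):

```rust
// Sketch only: ExpKernelOp and exp.ptx are assumptions mirroring the abs impl.
unsafe impl cudarc::device::AsKernelParam for super::ExpKernelOp {}

impl UnaryOpCudaKernel for super::ExpKernelOp {
    const PTX_SRC: &'static str = include_str!(concat!(env!("OUT_DIR"), "/exp.ptx"));
    const MODULE_NAME: &'static str = "exp";
    const FWD_FN_NAME: &'static str = "exp_forward";
    const BWD_FN_NAME: &'static str = "exp_backward";
}
```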
@@ -0,0 +1,172 @@

```rust
use crate::{
    shapes::Shape,
    tensor::cuda::{Cuda, CudaArray},
    tensor_ops::ops::{BinaryKernel, UnaryKernel},
};
use cudarc::device::{AsKernelParam, CudaSlice, LaunchAsync, LaunchConfig};
use std::sync::Arc;

pub trait UnaryOpCudaKernel {
    /// Compiled by build.rs
    const PTX_SRC: &'static str;

    /// Unique name for the kernel
    const MODULE_NAME: &'static str;

    /// Name of the forward function in the .cu file
    const FWD_FN_NAME: &'static str;

    /// Name of the backward function in the .cu file
    const BWD_FN_NAME: &'static str;

    const ALL_FN_NAMES: [&'static str; 2] = [Self::FWD_FN_NAME, Self::BWD_FN_NAME];
}

impl<K: UnaryOpCudaKernel + AsKernelParam> UnaryKernel<K, f32> for Cuda {
    fn forward<S: Shape>(
        &self,
        op: K,
        inp: &Self::Storage<S, f32>,
    ) -> Result<Self::Storage<S, f32>, Self::Err> {
        // Lazily JIT-load the module the first time any of its kernels is used.
        if !self.dev.has_func(K::MODULE_NAME, K::FWD_FN_NAME) {
            self.dev
                .load_ptx(K::PTX_SRC.into(), K::MODULE_NAME, &K::ALL_FN_NAMES)?;
        }

        let numel = inp.data.len();
        let mut storage = self.dev.alloc_zeros_async::<f32>(numel)?;

        let fwd_fn = self.dev.get_func(K::MODULE_NAME, K::FWD_FN_NAME).unwrap();
        let cfg = LaunchConfig::for_num_elems(numel as u32);
```

Review comment on the `LaunchConfig::for_num_elems` call above: Should probably make a helper method for computing a good version of this launch config - need to take advantage of threads & blocks.
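A sketch of what such a helper might look like (the 128-thread block size is an assumed default, not a tuned value):

```rust
// Sketch of the suggested helper: pick a block size and derive the grid size.
fn launch_cfg(numel: u32) -> LaunchConfig {
    const NUM_THREADS: u32 = 128; // assumed default, would want tuning
    let num_blocks = (numel + NUM_THREADS - 1) / NUM_THREADS; // ceiling division
    LaunchConfig {
        grid_dim: (num_blocks, 1, 1),
        block_dim: (NUM_THREADS, 1, 1),
        shared_mem_bytes: 0,
    }
}
```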
```rust
        let params = (
            op,
            numel,             // const size_t numel,
            inp.data.as_ref(), // const float *inp,
            &mut storage,      // float *out
        );
        unsafe { fwd_fn.launch_async(cfg, params) }?;

        Ok(CudaArray {
            data: Arc::new(storage),
            shape: inp.shape,
            strides: inp.strides,
        })
    }

    fn backward<S: Shape>(
        &self,
        op: K,
        inp: &Self::Storage<S, f32>,
        grad_inp: &mut Self::Storage<S, f32>,
        grad_out: &Self::Storage<S, f32>,
    ) -> Result<(), Self::Err> {
        // `forward` always runs first, so the module is already loaded here.
        let bwd_fn = self.dev.get_func(K::MODULE_NAME, K::BWD_FN_NAME).unwrap();
        let numel = inp.data.len();
        let cfg = LaunchConfig::for_num_elems(numel as u32);
        let params = (
            op,
            numel,                             // const size_t numel,
            inp.data.as_ref(),                 // const float *inp,
            Arc::make_mut(&mut grad_inp.data), // float *grad_inp,
            grad_out.data.as_ref(),            // const float *grad_out
        );
        unsafe { bwd_fn.launch_async(cfg, params) }?;
        Ok(())
    }
}

pub trait BinaryOpCudaKernel {
    /// Compiled by build.rs
    const PTX_SRC: &'static str;

    /// Unique name for the kernel
    const MODULE_NAME: &'static str;

    /// Name of the forward function in the .cu file
    const FWD_FN_NAME: &'static str;

    /// Name of the backward function in the .cu file
    const BWD_FN_NAME: &'static str;

    const ALL_FN_NAMES: [&'static str; 2] = [Self::FWD_FN_NAME, Self::BWD_FN_NAME];
}

impl<K: BinaryOpCudaKernel> BinaryKernel<K, f32> for Cuda {
    fn forward<S: Shape>(
        &self,
        _: K,
        lhs: &Self::Storage<S, f32>,
        rhs: &Self::Storage<S, f32>,
    ) -> Result<Self::Storage<S, f32>, Self::Err> {
        if !self.dev.has_func(K::MODULE_NAME, K::FWD_FN_NAME) {
            self.dev
                .load_ptx(K::PTX_SRC.into(), K::MODULE_NAME, &K::ALL_FN_NAMES)?;
        }

        let shape = lhs.shape;
        let strides = lhs.shape.strides();
        let numel = shape.num_elements();

        let mut storage = self.dev.alloc_zeros_async::<f32>(numel)?;

        // Upload shape/stride metadata so the kernel can index broadcast inputs.
        let dims: CudaSlice<usize> = self.dev.take_async(shape.concrete().into())?;
        let lhs_strides: CudaSlice<usize> = self.dev.take_async(lhs.strides.into())?;
        let rhs_strides: CudaSlice<usize> = self.dev.take_async(rhs.strides.into())?;
        let out_strides: CudaSlice<usize> = self.dev.take_async(strides.into())?;

        let fwd_fn = self.dev.get_func(K::MODULE_NAME, K::FWD_FN_NAME).unwrap();
        let cfg = LaunchConfig::for_num_elems(numel as u32);
        let params = (
            numel,             // const size_t numel,
            S::NUM_DIMS,       // const size_t num_dims,
            &dims,             // const size_t *dims,
            lhs.data.as_ref(), // const float *lhs,
            &lhs_strides,      // const size_t *lhs_strides,
            rhs.data.as_ref(), // const float *rhs,
            &rhs_strides,      // const size_t *rhs_strides,
            &mut storage,      // float *out,
            &out_strides,      // const size_t *out_strides
        );
        unsafe { fwd_fn.launch_async(cfg, params) }?;
        Ok(CudaArray {
            data: Arc::new(storage),
            shape,
            strides,
        })
    }

    fn backward<S: Shape>(
        &self,
        _: K,
        lhs: &Self::Storage<S, f32>,
        grad_lhs: &mut Self::Storage<S, f32>,
        rhs: &Self::Storage<S, f32>,
        grad_rhs: &mut Self::Storage<S, f32>,
        grad_out: &Self::Storage<S, f32>,
    ) -> Result<(), Self::Err> {
        let bwd_fn = self.dev.get_func(K::MODULE_NAME, K::BWD_FN_NAME).unwrap();
        let numel = lhs.shape.num_elements();

        let dims: CudaSlice<usize> = self.dev.take_async(lhs.shape.concrete().into())?;
        let lhs_strides: CudaSlice<usize> = self.dev.take_async(lhs.strides.into())?;
        let rhs_strides: CudaSlice<usize> = self.dev.take_async(rhs.strides.into())?;
        let out_strides: CudaSlice<usize> = self.dev.take_async(grad_out.strides.into())?;
```

Review comment on lines +150 to +153 (the four `take_async` uploads above): These same values were also allocated in the forward call - a potential improvement for the future is pre-allocating them. Though these are only used in binary ops - if a tensor is only ever used in a unary op then it doesn't need to allocate these.
```rust
        let cfg = LaunchConfig::for_num_elems(numel as u32);
        let params = (
            numel,                             // const size_t numel,
            S::NUM_DIMS,                       // const size_t num_dims,
            &dims,                             // const size_t *dims,
            lhs.data.as_ref(),                 // const float *lhs,
            Arc::make_mut(&mut grad_lhs.data), // float *grad_lhs,
            &lhs_strides,                      // const size_t *lhs_strides,
            rhs.data.as_ref(),                 // const float *rhs,
            Arc::make_mut(&mut grad_rhs.data), // float *grad_rhs,
            &rhs_strides,                      // const size_t *rhs_strides,
            grad_out.data.as_ref(),            // const float *grad_out,
            &out_strides,                      // const size_t *out_strides
        );
        unsafe { bwd_fn.launch_async(cfg, params) }?;
        Ok(())
    }
}
```
@@ -0,0 +1,70 @@

```cpp
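// get_strided_index (below) maps a logical, contiguous element index onto a
// physical offset via a tensor's strides, which is what lets these kernels
// read broadcast (stride-0) inputs in place.
// Worked example: dims = {2, 3}, contiguous strides = {3, 1}, idx = 4:
//   col = 4 % 3 = 1, row = (4 / 3) % 2 = 1, so strided_i = 1*1 + 1*3 = 4.
// With broadcast strides = {0, 1}, the same idx = 4 maps to strided_i = 1.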
__device__ unsigned int get_strided_index(
    unsigned int idx,
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    unsigned int strided_i = 0;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

extern "C" __global__ void binary_div_forward(
    const size_t numel,
    const size_t num_dims,
    const size_t *dims,
    const float *lhs,
    const size_t *lhs_strides,
    const float *rhs,
    const size_t *rhs_strides,
    float *out,
    const size_t *out_strides
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= numel) {
        return;
    }

    unsigned int lhs_i = get_strided_index(i, num_dims, dims, lhs_strides);
    unsigned int rhs_i = get_strided_index(i, num_dims, dims, rhs_strides);
    unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides);

    out[out_i] = lhs[lhs_i] / rhs[rhs_i];
}

extern "C" __global__ void binary_div_backward(
    const size_t numel,
    const size_t num_dims,
    const size_t *dims,
    const float *lhs,
    float *grad_lhs,
    const size_t *lhs_strides,
    const float *rhs,
    float *grad_rhs,
    const size_t *rhs_strides,
    const float *grad_out,
    const size_t *out_strides
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= numel) {
        return;
    }

    unsigned int lhs_i = get_strided_index(i, num_dims, dims, lhs_strides);
    unsigned int rhs_i = get_strided_index(i, num_dims, dims, rhs_strides);
    unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides);

    auto x = lhs[lhs_i];
    auto y = rhs[rhs_i];
    auto go = grad_out[out_i];

    float dfdx = 1.0 / y;
    grad_lhs[lhs_i] += dfdx * go;

    float dfdy = -x / (y * y);
    grad_rhs[rhs_i] += dfdy * go;
}
```
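For reference, the two factors in `binary_div_backward` are just the quotient rule: for $f(x, y) = x / y$,

$$\frac{\partial f}{\partial x} = \frac{1}{y}, \qquad \frac{\partial f}{\partial y} = -\frac{x}{y^2},$$

which is what `dfdx` and `dfdy` compute before scaling by `grad_out` and accumulating.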
Review comment: I think this can be done later, I'm not even sure if it's necessary. Once we have all the kernels in place we can see, but no need to complicate something that's pretty simple atm.