src/nn/convtrans.rs (80 additions, 19 deletions)
@@ -13,17 +13,29 @@ pub mod builder {
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
const GROUPS: usize = 1,
>;
}
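
Since the two new const parameters are defaulted, existing five-parameter uses of the builder keep compiling. A compile-time sketch, assuming the builder alias is in scope (e.g. via dfdx::prelude):

    use dfdx::prelude::*;

    // These aliases name the same builder type, because the new
    // DILATION and GROUPS parameters default to 1:
    type A = ConvTrans2D<3, 8, 3, 2, 1>;
    type B = ConvTrans2D<3, 8, 3, 2, 1, 1, 1>;

    // This coerces only if A and B are the same type.
    const _: fn(A) -> B = |x| x;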

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
BuildOnDevice<D, E> for builder::ConvTrans2D<I, O, K, S, P>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> BuildOnDevice<D, E> for builder::ConvTrans2D<I, O, K, S, P, L, G>
where
E: Dtype,
D: Device<E>,
ConvTrans2D<I, O, K, S, P, E, D>: BuildModule<D, E>,
Const<{ O / G }>: Sized,
ConvTrans2D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = ConvTrans2D<I, O, K, S, P, E, D>;
type Built = ConvTrans2D<I, O, K, S, P, L, G, E, D>;
fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
Self::Built::try_build(device)
}
@@ -45,26 +57,43 @@ where
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`.
/// - `GROUPS`: Controls the connections between inputs and outputs.
/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`. For example,
/// with `GROUPS=2`, the first half of the input channels is connected only to the
/// first half of the output channels, and the second half to the second half.
#[derive(Debug, Clone)]
pub struct ConvTrans2D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize,
const PADDING: usize,
const DILATION: usize,
const GROUPS: usize,
E: Dtype,
D: Storage<E>,
> {
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
> where
Const<{ OUT_CHAN / GROUPS }>: Sized,
{
pub weight: Tensor<Rank4<IN_CHAN, { OUT_CHAN / GROUPS }, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}
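
For orientation, a minimal usage sketch of the new signature; the concrete sizes are invented, and it assumes the usual dfdx entry points (DeviceBuildExt::build_module, sample_normal, the nightly feature) rather than anything introduced by this diff:

    use dfdx::prelude::*;

    fn main() {
        let dev: Cpu = Default::default();
        // IN=8, OUT=4, KERNEL=3, STRIDE=2, PADDING=1, DILATION=1, GROUPS=2;
        // both channel counts are divisible by GROUPS, as required.
        type M = ConvTrans2D<8, 4, 3, 2, 1, 1, 2>;
        let m = dev.build_module::<M, f32>();
        // The weight is Rank4<8, 2, 3, 3>: (IN_CHAN, OUT_CHAN / GROUPS, K, K).
        let x: Tensor<Rank3<8, 16, 16>, f32, _> = dev.sample_normal();
        // (16 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 = 31, so y is Rank3<4, 31, 31>.
        let y = m.forward(x);
        assert_eq!(y.array().len(), 4);
    }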

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype + Float + SampleUniform,
D: Device<E>,
Const<{ O / G }>: Sized,
{
type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, E2, D2>;
type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, L, G, E2, D2>;

fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
@@ -85,26 +114,58 @@ where
}

#[cfg(feature = "nightly")]
impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img>
Module<Img> for ConvTrans2D<C, O, K, S, P, E, D>
impl<
const C: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
Img,
> Module<Img> for ConvTrans2D<C, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Device<E>,
Img: TryConvTrans2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> + HasErr<Err = D::Err>,
Const<{ O / G }>: Sized,
(Img, Tensor<Rank4<C, { O / G }, K, K>, E, D>):
TryConvTrans2D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = Img::Output;
type Error = D::Err;
type Output = <(Img, Tensor<Rank4<C, { O / G }, K, K>, E, D>) as TryConvTrans2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank4<C, { O / G }, K, K>, E, D>) as TryConvTrans2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, D::Err> {
x.try_convtrans2d_to(self.weight.clone())
fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
(x, self.weight.clone()).try_convtrans2d(Const, Const, Const, Const)
}
}
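
The Convolved associated type should resolve to the standard transposed-convolution output size, (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1. A standalone check of that arithmetic in plain Rust, independent of the crate:

    /// Spatial output size of a transposed convolution (output_padding = 0).
    fn convtrans_out(i: usize, k: usize, s: usize, p: usize, l: usize) -> usize {
        (i - 1) * s - 2 * p + l * (k - 1) + 1
    }

    fn main() {
        // 16x16 input, 3x3 kernel, stride 2, padding 1, dilation 1 -> 31x31.
        assert_eq!(convtrans_out(16, 3, 2, 1, 1), 31);
        // Dilation 2 spreads the taps (effective kernel 5) -> 33x33.
        assert_eq!(convtrans_out(16, 3, 2, 1, 2), 33);
    }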

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
NonMutableModule for ConvTrans2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> NonMutableModule for ConvTrans2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Storage<E>,
Const<{ O / G }>: Sized,
{
}

@@ -187,7 +248,7 @@ mod tests {

assert_ne!(
g.get(&m.weight).array(),
[[[[TestDtype::zero(); 3]; 3]; 2]; 4]
[[[[TestDtype::zero(); 3]; 3]; 4]; 2]
);

opt.update(&mut m, &g).expect("unused params");
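
The flipped nesting in the assertion above reflects the new weight layout: (IN_CHAN, OUT_CHAN / GROUPS, K, K) instead of the old (OUT_CHAN, IN_CHAN, K, K). A shape-only sketch with the test's dimensions (IN_CHAN = 2, OUT_CHAN = 4, GROUPS = 1, KERNEL_SIZE = 3):

    // In [T; N] nesting the outermost length is written last, so
    // [[[[f32; 3]; 3]; 4]; 2] is indexed [2][4][3][3], i.e. shape (2, 4, 3, 3).
    type OldW = [[[[f32; 3]; 3]; 2]; 4]; // (OUT_CHAN, IN_CHAN, K, K) = (4, 2, 3, 3)
    type NewW = [[[[f32; 3]; 3]; 4]; 2]; // (IN_CHAN, OUT_CHAN / GROUPS, K, K) = (2, 4, 3, 3)

    fn main() {
        // Same parameter count, different shape.
        assert_eq!(std::mem::size_of::<OldW>(), std::mem::size_of::<NewW>());
    }
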
src/nn/mod.rs (1 addition, 0 deletions)
@@ -190,6 +190,7 @@ mod batchnorm2d;
mod bias2d;
#[cfg(feature = "nightly")]
mod conv;
#[cfg(feature = "nightly")]
mod convtrans;
mod dropout;
mod ema;
src/tensor_ops/conv2d/conv2d.cu (24 additions, 22 deletions)
@@ -22,29 +22,29 @@ __device__ void unfold_input_into_patches(
const size_t *strides, // 4d image strides
T *patches // 6d (Batch, Channels, KernelSize, KernelSize, HeightOut, WidthOut)
) {
const size_t n = op.batch * op.groups * op.chan_in * op.h_out * op.w_out;
const size_t n = op.batch * op.chan_in * op.h_out * op.w_out;
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
unsigned int idx = i;
const size_t ow = idx % op.w_out;
idx /= op.w_out;
const size_t oh = idx % op.h_out;
idx /= op.h_out;
const size_t c = idx % (op.chan_in * op.groups);
idx /= (op.chan_in * op.groups);
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t b = idx % op.batch;

const T *image_i = image + b * strides[0] + c * strides[1];
T *patches_i = patches + oh * op.w_out + ow;
patches_i += c * (op.kernel * op.kernel * op.h_out * op.w_out);
patches_i += b * (op.groups * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out);
patches_i += b * (op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out);

T zero = 0.0;

for (int k1 = 0;k1 < op.kernel;k1++) {
const size_t y = oh * op.stride + op.dilation * k1 - op.padding;
for (int k2 = 0;k2 < op.kernel;k2++) {
const size_t x = ow * op.stride + op.dilation * k2 - op.padding;
*patches_i = (y >= op.h_in || x >= op.w_in) ? zero : image[y * strides[2] + x * strides[3]];
*patches_i = (y >= op.h_in || x >= op.w_in) ? zero : image_i[y * strides[2] + x * strides[3]];
patches_i += op.h_out * op.w_out;
}
}
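
The flat-index decomposition above now runs over chan_in alone, since after this change chan_in already counts the channels of all groups. A plain-Rust mirror of the modulo/divide chain, with invented sizes:

    /// Decompose a flat index into (batch, channel, out_y, out_x),
    /// mirroring unfold_input_into_patches.
    fn decompose(mut idx: usize, chan_in: usize, h_out: usize, w_out: usize)
        -> (usize, usize, usize, usize)
    {
        let ow = idx % w_out;
        idx /= w_out;
        let oh = idx % h_out;
        idx /= h_out;
        let c = idx % chan_in;
        idx /= chan_in;
        (idx, c, oh, ow)
    }

    fn main() {
        let (chan_in, h_out, w_out) = (8, 5, 5);
        // Flat index of (b = 1, c = 2, oh = 3, ow = 4):
        let i = ((1 * chan_in + 2) * h_out + 3) * w_out + 4;
        assert_eq!(decompose(i, chan_in, h_out, w_out), (1, 2, 3, 4));
    }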
@@ -86,7 +86,7 @@ __device__ void unfold_output_into_patches(
const size_t ow = ow_s / op.stride;

const bool invalid = k1_invalid || (ow_ks < op.dilation * k2 || ow_s % op.stride != 0 || ow >= op.w_out);
*patches_i = invalid ? zero : image_out[oh * op.w_out + ow];
*patches_i = invalid ? zero : image_i[oh * op.w_out + ow];
patches_i += op.h_in * op.w_in;
}
}
@@ -96,64 +96,66 @@
template<typename T>
__device__ void transpose_filters(
const Conv2DOp op,
const T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
const T *filters, // 4d (ChanOut, ChanIn/Groups, KernelSize, KernelSize)
const size_t *strides, // 4d filters strides
T *filters_tr // 5d (Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
T *filters_tr // 5d (Groups, ChanIn/Groups, ChanOut/Groups, KernelSize, KernelSize)
) {
const size_t n = op.chan_in * op.chan_out * op.kernel * op.kernel;
const size_t c_per_g = op.chan_in / op.groups;
const size_t o_per_g = op.chan_out / op.groups;
const size_t n = c_per_g * op.chan_out * op.kernel * op.kernel;

for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
unsigned int idx = i;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t cg = idx % c_per_g;
idx /= c_per_g;
const size_t o = idx % op.chan_out;
const size_t og = o % o_per_g;
const size_t g = o / o_per_g;

auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];
auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2] + k2 * strides[3];
T *filters_tr_i = filters_tr + k2;
filters_tr_i += k1 * op.kernel;
filters_tr_i += og * (op.kernel * op.kernel);
filters_tr_i += c * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel);
filters_tr_i += cg * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (c_per_g * o_per_g * op.kernel * op.kernel);
*filters_tr_i = filters[i_no];
}
}

template<typename T>
__device__ void sum_transposed_filters(
const Conv2DOp op,
const T *filters_tr, // 6d (Batch, Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
const T *filters_tr, // 6d (Batch, Groups, ChanIn/Groups, ChanOut/Groups, KernelSize, KernelSize)
T *filters, // 4d (ChanOut, ChanIn/Groups, KernelSize, KernelSize)
const size_t *strides // 4d filter strides
) {
const size_t n = op.chan_out * op.chan_in * op.kernel * op.kernel;
const size_t o_per_g = op.chan_out / op.groups;
const size_t c_per_g = op.chan_in / op.groups;
const size_t n = op.chan_out * c_per_g * op.kernel * op.kernel;

for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
unsigned int idx = i;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t cg = idx % c_per_g;
idx /= c_per_g;
const size_t o = idx % op.chan_out;
const size_t og = o % o_per_g;
const size_t g = o / o_per_g;

auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];
auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2] + k2 * strides[3];

const T *filters_tr_i = filters_tr + k2;
filters_tr_i += k1 * op.kernel;
filters_tr_i += og * (op.kernel * op.kernel);
filters_tr_i += c * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel);
filters_tr_i += cg * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (c_per_g * o_per_g * op.kernel * op.kernel);

T tmp = 0.0;
for (int b = 0; b < op.batch; b++) {
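
Both kernels split the output-channel index o into a group g = o / o_per_g and a within-group index og = o % o_per_g, then lay the transposed filters out as (Groups, ChanIn/Groups, ChanOut/Groups, K, K). A standalone Rust mirror of that offset computation (names invented, layout as in the comments above):

    /// Offset into the transposed-filter buffer, mirroring the CUDA indexing:
    /// g * (c_per_g * o_per_g * K * K) + cg * (o_per_g * K * K) + og * K * K + k1 * K + k2
    fn tr_offset(o: usize, cg: usize, k1: usize, k2: usize,
                 o_per_g: usize, c_per_g: usize, kernel: usize) -> usize {
        let g = o / o_per_g;  // group this output channel belongs to
        let og = o % o_per_g; // channel index within the group
        ((g * c_per_g + cg) * o_per_g + og) * kernel * kernel + k1 * kernel + k2
    }

    fn main() {
        // 4 output channels in 2 groups (o_per_g = 2), 3 in-channels per group, 3x3 kernel.
        // Output channel 3 is the second channel (og = 1) of group g = 1.
        assert_eq!(tr_offset(3, 0, 0, 0, 2, 3, 3), ((1 * 3 + 0) * 2 + 1) * 9);
    }
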
src/tensor_ops/conv2d/cpu_kernel.rs (17 additions, 21 deletions)
@@ -54,7 +54,7 @@ impl Cpu {
{
{
let mut i = 0;
for c in 0..(op.groups * op.chan_in) {
for c in 0..op.chan_in {
for k1 in 0..op.kernel {
for k2 in 0..op.kernel {
for oh in 0..op.h_out {
@@ -73,9 +73,11 @@ }
}
}

// (G, O / G, C * K * K) * (G, C * K * K, OH * OW) = (G, O / G, OH * OW)
// filters: (G, O/G, C/G*K*K)
// buf: (G, C/G*K*K, OH*OW)
// output: (G, O/G, OH*OW)
let m = op.chan_out / op.groups;
let k = op.chan_in * op.kernel * op.kernel;
let k = (op.chan_in / op.groups) * op.kernel * op.kernel;
let n = op.w_out * op.h_out;
for g in 0..op.groups {
Self::matmul(
@@ -128,8 +130,8 @@

{
// img_g += filters^T * unfold(grad_out)
// (G, C, H * W) += (G, C, O/G * K * K) * (G, O/G * K * K, H * W)
let m = op.chan_in;
// (G, C/G, H * W) += (G, C/G, O/G * K * K) * (G, O/G * K * K, H * W)
let m = op.chan_in / op.groups;
let k = (op.chan_out / op.groups) * op.kernel * op.kernel;
let n = op.h_in * op.w_in;
for g in 0..op.groups {
Expand All @@ -148,8 +150,8 @@ impl Cpu {

{
// weight_g^T += img * unfold(patches)^T
// (G, C, O/G * K * K) += (G, C, H * W) * (G, H * W, O/G * K * K)
let m = op.chan_in;
// (G, C/G, O/G * K * K) += (G, C/G, H * W) * (G, H * W, O/G * K * K)
let m = op.chan_in / op.groups;
let k = op.h_in * op.w_in;
let n = (op.chan_out / op.groups) * op.kernel * op.kernel;
for g in 0..op.groups {
@@ -184,13 +186,7 @@ where
rhs: &Tensor<R, E, Self>,
out: &mut Tensor<O, E, Self>,
) -> Result<(), Self::Err> {
let patches = (
op.groups * op.chan_in,
op.kernel,
op.kernel,
op.h_out,
op.w_out,
);
let patches = (op.chan_in, op.kernel, op.kernel, op.h_out, op.w_out);
let mut patches = self.try_alloc_zeros::<E>(patches.num_elements())?;
let [lstride, ostride] = match L::NUM_DIMS {
3 => [0; 2],
@@ -224,7 +220,7 @@
) -> Result<(), Self::Err> {
let f_tr_shape = [
op.groups,
op.chan_in,
op.chan_in / op.groups,
op.chan_out / op.groups,
op.kernel,
op.kernel,
@@ -238,9 +234,9 @@
// transpose filters in f1023
let buf = rhs.data.as_ref();
let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides());
while let Some((i, [g, c, o, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0]
+ c * rhs.strides[1]
while let Some((i, [g, c_over_g, o_over_g, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o_over_g) * rhs.strides[0]
+ c_over_g * rhs.strides[1]
+ k1 * rhs.strides[2]
+ k2 * rhs.strides[3];
f1023[i] = buf[idx];
@@ -269,9 +265,9 @@
{
// untranspose filters
let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides());
while let Some((i, [g, c, o, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0]
+ c * rhs.strides[1]
while let Some((i, [g, c_over_g, o_over_g, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o_over_g) * rhs.strides[0]
+ c_over_g * rhs.strides[1]
+ k1 * rhs.strides[2]
+ k2 * rhs.strides[3];
grad_rhs[idx] += grad_f1023[i];
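
As a concrete check of the grouped GEMM shapes annotated in the hunks above, with invented sizes chan_in = 8, chan_out = 4, groups = 2, kernel = 3:

    fn main() {
        let (chan_in, chan_out, groups, kernel) = (8usize, 4usize, 2usize, 3usize);
        let (h_out, w_out) = (31usize, 31usize);

        // Forward, per group: filters (m x k) times patches (k x n).
        let m = chan_out / groups;                    // 2
        let k = (chan_in / groups) * kernel * kernel; // 4 * 9 = 36
        let n = h_out * w_out;                        // 961
        assert_eq!((m, k, n), (2, 36, 961));

        // Backward w.r.t. the image, per group: (C/G) x (O/G * K * K).
        let m_bwd = chan_in / groups;                      // 4
        let k_bwd = (chan_out / groups) * kernel * kernel; // 2 * 9 = 18
        assert_eq!((m_bwd, k_bwd), (4, 18));

        // Filter element count: ChanOut * ChanIn/Groups * K * K.
        assert_eq!(chan_out * (chan_in / groups) * kernel * kernel, 144);
    }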