+
Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 116 additions & 31 deletions src/devices/conv.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use super::Cpu;
use super::{AllocateZeros, Cpu};
#[cfg(feature = "cblas")]
use cblas_sys::{
cblas_sgemm as sgemm, CblasNoTrans as NoTr, CblasRowMajor as RowMajor, CblasTrans as Tr,
};

/// **Requires nightly.** 2D convolution with stride and padding specified at the trait level.
///
Expand Down Expand Up @@ -35,7 +39,10 @@ pub trait DeviceConv2D<const S: usize, const P: usize> {
);
}

impl<const S: usize, const P: usize> DeviceConv2D<S, P> for Cpu {
impl<const S: usize, const P: usize> DeviceConv2D<S, P> for Cpu
where
Self: AllocateZeros,
{
fn conv_forward<
const C: usize,
const O: usize,
Expand All @@ -48,28 +55,46 @@ impl<const S: usize, const P: usize> DeviceConv2D<S, P> for Cpu {
bias: &[f32; O],
out: &mut [[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O],
) {
let mut patches: Box<
[[[[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; K]; K]; C],
> = Self::zeros();

for c in 0..C {
for oc in 0..O {
for oh in 0..((H + 2 * P - K) / S + 1) {
for ow in 0..((W + 2 * P - K) / S + 1) {
let o = &mut out[oc][oh][ow];
let mut tmp = 0.0;
for k1 in 0..K {
let y = (oh * S + k1).checked_sub(P);
for k2 in 0..K {
let x = (ow * S + k2).checked_sub(P);
if let Some((y, x)) = y.zip(x) {
if y < H && x < W {
tmp += weight[oc][c][k1][k2] * img[c][y][x];
}
}
for k1 in 0..K {
for k2 in 0..K {
for oh in 0..(H + 2 * P - K) / S + 1 {
for ow in 0..(W + 2 * P - K) / S + 1 {
let y = (oh * S + k1).wrapping_sub(P);
let x = (ow * S + k2).wrapping_sub(P);
if y < H && x < W {
patches[c][k1][k2][oh][ow] = img[c][y][x];
}
}
*o += tmp;
}
}
}
}

// (O, C * K * K) * (C * K * K, OH * OW) = (O, OH * OW)
let m = O;
let k = C * K * K;
let n = ((H + 2 * P - K) / S + 1) * ((W + 2 * P - K) / S + 1);
let a = weight.as_ptr() as *const f32;
let b = patches.as_ptr() as *const f32;
let c = out.as_mut_ptr() as *mut f32;
#[cfg(not(feature = "cblas"))]
unsafe {
matrixmultiply::sgemm(
m, k, n, 1.0, a, k as isize, 1, b, n as isize, 1, 1.0, c, n as isize, 1,
)
}

#[cfg(feature = "cblas")]
unsafe {
let (m, n, k) = (m as libc::c_int, n as libc::c_int, k as libc::c_int);
sgemm(RowMajor, NoTr, NoTr, m, n, k, 1.0, a, k, b, n, 1.0, c, n)
}
Comment on lines +85 to +96
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this is copied from the matmul implementation: using `Cpu::mm` would require casting the inputs to 2D arrays, which in turn would require adding `generic_const_exprs` bounds to the trait, which I wanted to avoid.


for oc in 0..O {
for oh in 0..((H + 2 * P - K) / S + 1) {
for ow in 0..((W + 2 * P - K) / S + 1) {
Expand Down Expand Up @@ -101,27 +126,87 @@ impl<const S: usize, const P: usize> DeviceConv2D<S, P> for Cpu {
}
}

let mut w_tr: Box<[[[[f32; K]; K]; O]; C]> = Self::zeros();
for c in 0..C {
for oc in 0..O {
for oh in 0..((H + 2 * P - K) / S + 1) {
for ow in 0..((W + 2 * P - K) / S + 1) {
let o_g = &out_g[oc][oh][ow];
for k1 in 0..K {
let y = (oh * S + k1).checked_sub(P);
for k2 in 0..K {
let x = (ow * S + k2).checked_sub(P);
if let Some((y, x)) = y.zip(x) {
if y < H && x < W {
weight_g[oc][c][k1][k2] += img[c][y][x] * o_g;
img_g[c][y][x] += weight[oc][c][k1][k2] * o_g;
}
}
for o in 0..O {
w_tr[c][o].clone_from(&weight[o][c]);
}
}

let mut patches: Box<[[[[[f32; W]; H]; K]; K]; O]> = Self::zeros();
for o in 0..O {
for oh in 0..(H + 2 * P - K) / S + 1 {
for ow in 0..(W + 2 * P - K) / S + 1 {
let g = out_g[o][oh][ow];
for k1 in 0..K {
for k2 in 0..K {
let y = (oh * S + k1).wrapping_sub(P);
let x = (ow * S + k2).wrapping_sub(P);
if y < H && x < W {
patches[o][k1][k2][y][x] = g;
}
}
}
}
}
}

{
// img_g += weight^T * patches
// (C, H * W) += (C, O * K * K) * (O * K * K, H * W)

let m = C;
let k = O * K * K;
let n = H * W;
let a = w_tr.as_ptr() as *const f32;
let b = patches.as_ptr() as *const f32;
let c = img_g.as_mut_ptr() as *mut f32;
#[cfg(not(feature = "cblas"))]
unsafe {
matrixmultiply::sgemm(
m, k, n, 1.0, a, k as isize, 1, b, n as isize, 1, 1.0, c, n as isize, 1,
)
}

#[cfg(feature = "cblas")]
unsafe {
let (m, n, k) = (m as libc::c_int, n as libc::c_int, k as libc::c_int);
sgemm(RowMajor, NoTr, NoTr, m, n, k, 1.0, a, k, b, n, 1.0, c, n)
}
}

{
// weight_g^T += img * patches^T
// (C, O * K * K) += (C, H * W) * (H * W, O * K * K)

let m = C;
let k = H * W;
let n = O * K * K;
let a = img.as_ptr() as *const f32;
let b = patches.as_ptr() as *const f32;
let c = w_tr.as_mut_ptr() as *mut f32;
#[cfg(not(feature = "cblas"))]
unsafe {
matrixmultiply::sgemm(
m, k, n, 1.0, a, k as isize, 1, b, 1, k as isize, 0.0, c, n as isize, 1,
)
}
#[cfg(feature = "cblas")]
unsafe {
let (m, n, k) = (m as libc::c_int, n as libc::c_int, k as libc::c_int);
sgemm(RowMajor, NoTr, Tr, m, n, k, 1.0, a, k, b, k, 0.0, c, n)
}

for o in 0..O {
for c in 0..C {
for k1 in 0..K {
for k2 in 0..K {
weight_g[o][c][k1][k2] += w_tr[c][o][k1][k2];
}
}
}
}
}
}
}

Expand Down
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载