Replacing conv2d implementation with matmuls #237
Conversation
Did you run the benchmarks for the previous 'naive' implementation?
Just added - they are really slow lol. Didn't scale up the number of channels enough in my original benchmarks of convs, but it was a really simple algorithm to debug and test. All the unit tests made it really easy to verify that this new implementation worked! 😁
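For context, this is roughly the shape of a naive direct convolution for a single image: six nested loops with little data reuse compared to a blocked matmul. Stride 1, no padding; the names and layout here are illustrative, not the repo's actual code:

```rust
// Direct "naive" convolution sketch: O(out_ch * in_ch * out_h * out_w * k * k)
// work, with poor cache behavior on the weight accesses.
// Stride 1, no padding; illustrative only.
fn conv2d_naive(
    img: &[f32],     // in_ch * h * w, CHW layout
    weight: &[f32],  // out_ch * in_ch * k * k
    out: &mut [f32], // out_ch * out_h * out_w
    in_ch: usize,
    h: usize,
    w: usize,
    out_ch: usize,
    k: usize, // square kernel size
) {
    let (out_h, out_w) = (h - k + 1, w - k + 1);
    for oc in 0..out_ch {
        for oy in 0..out_h {
            for ox in 0..out_w {
                let mut acc = 0.0;
                for ic in 0..in_ch {
                    for ky in 0..k {
                        for kx in 0..k {
                            acc += img[ic * h * w + (oy + ky) * w + (ox + kx)]
                                * weight[((oc * in_ch + ic) * k + ky) * k + kx];
                        }
                    }
                }
                out[(oc * out_h + oy) * out_w + ox] = acc;
            }
        }
    }
}
```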
```rust
#[cfg(not(feature = "cblas"))]
unsafe {
    matrixmultiply::sgemm(
        m, k, n, 1.0, a, k as isize, 1, b, n as isize, 1, 1.0, c, n as isize, 1,
    )
}

#[cfg(feature = "cblas")]
unsafe {
    let (m, n, k) = (m as libc::c_int, n as libc::c_int, k as libc::c_int);
    sgemm(RowMajor, NoTr, NoTr, m, n, k, 1.0, a, k, b, n, 1.0, c, n)
}
```
Noting that this is copied from the matmul implementation - using Cpu::mm requires casting the inputs to 2d arrays, which in turn requires generic_const_expr bounds added to the trait, which I wanted to avoid.
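For readers unfamiliar with the strided-GEMM signature, here is a tiny self-contained usage of the non-cblas path. With row-major, contiguous buffers the row stride is the row length and the column stride is 1; this is an illustrative example, not code from the PR:

```rust
// Computes C (m x n) = 1.0 * A (m x k) * B (k x n) + 1.0 * C
// for row-major contiguous matrices. Requires the `matrixmultiply` crate.
fn main() {
    let (m, k, n) = (2, 3, 2);
    let a: Vec<f32> = vec![1., 2., 3., 4., 5., 6.]; // 2x3
    let b: Vec<f32> = vec![7., 8., 9., 10., 11., 12.]; // 3x2
    let mut c: Vec<f32> = vec![0.; m * n]; // 2x2, starts at zero
    unsafe {
        matrixmultiply::sgemm(
            m, k, n,
            1.0,
            a.as_ptr(), k as isize, 1, // A: row stride k, col stride 1
            b.as_ptr(), n as isize, 1, // B: row stride n, col stride 1
            1.0,
            c.as_mut_ptr(), n as isize, 1, // C: row stride n, col stride 1
        );
    }
    assert_eq!(c, vec![58., 64., 139., 154.]);
}
```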
Damn, I'd hoped for some LLVM intrinsics magic that kept the naive way somewhat competitive! Looks like Fortran will be with us for a lot longer 😝
Hah yeah I was hoping that too. I'm pretty sure it was auto vectorizing the forward at least, but I think all the matmul stuff is suuuuper cache optimized.
This addresses some slowness in the conv2d implementation by leveraging matrix multiplication algorithms.

Here are some numbers from my local benchmarks with `Conv2D<128, 256, 4>` on `Tensor4D<64, 128, 28, 28>`:

There is probably still more work that can be done as far as reducing allocations/copying when creating the patches buffers used below.
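As a rough illustration of the patches idea, here is a minimal im2col sketch: unfold the input into a patch matrix so that the whole convolution becomes a single GEMM, which is the call the sgemm dispatch above performs. Stride 1, no padding, single image; all names are illustrative and this is not the PR's actual code:

```rust
// Unfold a CHW image into a (in_ch*k*k) x (out_h*out_w) patch matrix,
// where each column is one flattened k x k receptive field.
fn im2col(img: &[f32], in_ch: usize, h: usize, w: usize, k: usize) -> (Vec<f32>, usize, usize) {
    let (out_h, out_w) = (h - k + 1, w - k + 1);
    let rows = in_ch * k * k;
    let cols = out_h * out_w;
    let mut patches = vec![0.0; rows * cols]; // the extra allocation/copy
    for c in 0..in_ch {
        for ky in 0..k {
            for kx in 0..k {
                let row = (c * k + ky) * k + kx;
                for oy in 0..out_h {
                    for ox in 0..out_w {
                        patches[row * cols + (oy * out_w + ox)] =
                            img[c * h * w + (oy + ky) * w + (ox + kx)];
                    }
                }
            }
        }
    }
    (patches, rows, cols)
}

// Convolution then reduces to one GEMM: reshape the kernel to
// (out_ch) x (in_ch*k*k) and multiply by the patch matrix; the result is
// already the (out_ch) x (out_h*out_w) output image.
fn conv2d_via_matmul(
    img: &[f32],
    weight: &[f32], // out_ch * in_ch * k * k
    in_ch: usize,
    h: usize,
    w: usize,
    out_ch: usize,
    k: usize,
) -> Vec<f32> {
    let (patches, rows, cols) = im2col(img, in_ch, h, w, k);
    let mut out = vec![0.0; out_ch * cols];
    unsafe {
        matrixmultiply::sgemm(
            out_ch, rows, cols,
            1.0,
            weight.as_ptr(), rows as isize, 1,
            patches.as_ptr(), cols as isize, 1,
            0.0,
            out.as_mut_ptr(), cols as isize, 1,
        );
    }
    out
}
```

The trade-off is visible in the sketch: the `patches` buffer duplicates each input pixel up to k*k times, which is the allocation/copying overhead the description above flags as remaining work.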