-
-
Notifications
You must be signed in to change notification settings - Fork 104
accurate-gelu #813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
accurate-gelu #813
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
9e140be
accurate-gelu
jcrist1 04eb169
accurate-gelu - make it actually available
jcrist1 cb8a99c
accurate-gelu - trying to tweak cuda kernel
jcrist1 b42b5e3
accurate-gelu - rename
jcrist1 8fbf249
accurate-gelu - more documentation, fix tests
jcrist1 d201f4a
accurate-gelu - describe corresponding pytorch algos for gelus
jcrist1 4e11335
accurate-gelu - rename gelu -> fast_gelu + bwd compat and deprecation…
jcrist1 f30a37f
accurate-gelu - correct cuda struct name
jcrist1 858d176
accurate-gelu - GuLU -> GeLU X-D
jcrist1 360c381
accurate-gelu - create backwards compatibility `gelu(t)` function
jcrist1 9fd8685
accurate-gelu - mark backwards compatible gelu import as allow(deprec…
jcrist1 9f3c7bb
accurate-gelu - fix backwards-compatibility for GeLU, rename test to …
jcrist1 03b5397
accurate-gelu - added doc links
jcrist1 2c2b3d9
accurat-gelu - last link
jcrist1 d00d2df
accurate-gelu - correct method links
jcrist1 ec5f25d
Merge branch 'main' into accurate-gelu
coreylowman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#include "unary_op_macros.cuh" | ||
#define _USE_MATH_DEFINES | ||
#include <math.h> | ||
|
||
// Empty tag type naming the exact (erf-based) GeLU op; it carries no data and
// is handed to UNARY_OP below as the op type for each instantiation.
struct AccurateGeLUKernelOp {
};
|
||
// Forward pass of the exact GeLU:
//
//   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
//
// `erfg` is not defined in this file; presumably it is the dtype-generic erf
// helper pulled in via unary_op_macros.cuh -- TODO confirm.
template <typename T> __device__ T accurate_gelu_fwd(T x) {
  T one = 1.0;
  T half = 0.5;
  T alpha = M_SQRT1_2; // 1 / sqrt(2)
  return half * x * (one + erfg(x * alpha));
}
|
||
// Derivative of the exact GeLU with respect to its input.
//
// With gelu(x) = left(x) * right(x), left = 0.5 * x and
// right = 1 + erf(alpha * x), alpha = 1 / sqrt(2), the product rule gives
//
//   gelu'(x) = 0.5 * right + left * d/dx erf(alpha * x)
//   d/dx erf(alpha * x) = (2 / sqrt(pi)) * alpha * exp(-x^2 / 2)
//
// which is the closed form Phi(x) + x * phi(x) (standard normal cdf / pdf).
//
// `erfg` / `expg` are not defined in this file; presumably they are the
// dtype-generic math helpers from unary_op_macros.cuh -- TODO confirm.
template <typename T> __device__ T accurate_gelu_bwd(T x) {
  T one = 1.0;
  T half = 0.5;
  T alpha = M_SQRT1_2;            // 1 / sqrt(2)
  T two_over_sqrt_pi = M_2_SQRTPI; // 2 / sqrt(pi), converted to T once
  T x_sq = x * x;

  // Chain rule: the inner factor `alpha` must be kept, and the exponent is
  // -(alpha * x)^2 = -x^2 / 2 (negative, with no 2/sqrt(pi) inside it).
  T erf_derivative = two_over_sqrt_pi * alpha * expg(-half * x_sq);

  T left = half * x;
  T right = one + erfg(alpha * x);

  T left_derivative = half * right;
  T right_derivative = left * erf_derivative;

  return left_derivative + right_derivative;
}
|
||
// One UNARY_OP expansion per supported element type. The second and third
// arguments are the kernel names (`accurate_gelu_{fwd,bwd}_{f16,f32,f64}`);
// the last two are the forward and backward expressions evaluated on `x`.
// UNARY_OP itself comes from unary_op_macros.cuh and is not visible here.

// half precision
UNARY_OP(__half, accurate_gelu_fwd_f16, accurate_gelu_bwd_f16, AccurateGeLUKernelOp,
         accurate_gelu_fwd(x),
         accurate_gelu_bwd(x))

// single precision
UNARY_OP(float, accurate_gelu_fwd_f32, accurate_gelu_bwd_f32, AccurateGeLUKernelOp,
         accurate_gelu_fwd(x),
         accurate_gelu_bwd(x))

// double precision
UNARY_OP(double, accurate_gelu_fwd_f64, accurate_gelu_bwd_f64, AccurateGeLUKernelOp,
         accurate_gelu_fwd(x),
         accurate_gelu_bwd(x))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
use crate::tensor_ops::cpu_kernels::UnaryDerivative; | ||
#[cfg(feature = "f16")] | ||
use half::f16; | ||
use libm::{erf, erff}; | ||
use num_traits::{Float, FloatConst}; | ||
|
||
/// Error function `erf`, abstracted over the float types the GeLU kernel is
/// generic over.
trait Erf {
    /// Returns `erf(self)`.
    fn erf(self) -> Self;
}
|
||
/// Half precision: widen to `f32`, evaluate `erff` there, then narrow back.
#[cfg(feature = "f16")]
impl Erf for f16 {
    fn erf(self) -> Self {
        let widened = self.to_f32();
        Self::from_f32(erff(widened))
    }
}
|
||
impl Erf for f64 { | ||
fn erf(self) -> Self { | ||
erf(self) | ||
} | ||
} | ||
|
||
impl Erf for f32 { | ||
fn erf(self) -> Self { | ||
erff(self) | ||
} | ||
} | ||
coreylowman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
impl<F: Float + FloatConst + Erf> UnaryDerivative<F> for super::AccurateGeLUKernelOp { | ||
const DF_USES_FX: bool = false; | ||
const HAS_CONST_DF: bool = false; | ||
#[inline(always)] | ||
fn f(&self, &x: &F) -> F { | ||
let alpha = F::FRAC_1_SQRT_2(); | ||
F::from(0.5).unwrap() * x * (F::one() + (x * alpha).erf()) | ||
} | ||
|
||
#[inline(always)] | ||
fn df(&self, &x: &F) -> F { | ||
let half = F::from(0.5).unwrap(); | ||
let alpha = F::FRAC_1_SQRT_2(); | ||
let x_sq = x * x; | ||
let normal_dist = F::FRAC_2_SQRT_PI() * (F::from(0.5).unwrap() * x_sq.neg()).exp(); | ||
|
||
let left = half * x; | ||
let right = F::one() + (alpha * x).erf(); | ||
|
||
let left_derivative = half * right; | ||
|
||
let right_derivative = left * normal_dist; | ||
|
||
left_derivative + right_derivative | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
use super::AccurateGeLUKernelOp; | ||
use crate::tensor_ops::cuda_kernels::cuda_unary; | ||
|
||
unsafe impl cudarc::driver::DeviceRepr for super::AccurateGeLUKernelOp {} | ||
|
||
const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/accurate_gelu.ptx")); | ||
|
||
#[cfg(feature = "f16")] | ||
cuda_unary!( | ||
AccurateGeLUKernelOp, | ||
half::f16, | ||
PTX, | ||
"accurate_gelu_fwd_f16", | ||
"accurate_gelu_bwd_f16" | ||
); | ||
cuda_unary!( | ||
AccurateGeLUKernelOp, | ||
f32, | ||
PTX, | ||
"accurate_gelu_fwd_f32", | ||
"accurate_gelu_bwd_f32" | ||
); | ||
cuda_unary!( | ||
AccurateGeLUKernelOp, | ||
f64, | ||
PTX, | ||
"accurate_gelu_fwd_f64", | ||
"accurate_gelu_bwd_f64" | ||
); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
mod cpu_kernel; | ||
|
||
#[cfg(feature = "cuda")] | ||
mod cuda_kernel; | ||
|
||
use super::ops::{try_unary_op, UnaryKernel}; | ||
use crate::{shapes::*, tensor::*}; | ||
|
||
/// Zero-field marker type identifying the exact (erf-based) GeLU unary op.
#[derive(Debug, Default, Copy, Clone)]
#[repr(C)]
pub struct AccurateGeLUKernelOp;
|
||
/// [Accurate Gaussian Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). This is defined as `x * Phi(x)` where `Phi(x)` is the cumulative | ||
/// distribution function of a standard normal distribution. This can be calculated via the Error | ||
/// Function `erf(x)` using | ||
/// ```text | ||
/// 0.5 * x * (1.0 + erf(x / 2.0.sqrt())) | ||
/// ``` | ||
/// As an accurate error function is [computationally expensive](https://en.wikipedia.org/wiki/Error_function#Numerical_approximations) it is | ||
/// possible to approximate the Gaussian Linear Unit with a hyperbolic tangent function `tanh` | ||
/// | ||
/// ```text | ||
/// GeLU(x) ~ 0.5 ∗ x ∗ (1.0 + tanh((sqrt(2.0/π) ∗ (x + 0.044715 ∗ x^3))) | ||
/// ``` | ||
/// | ||
/// See [fast_gelu](super::fast_gelu::fast_gelu) to use this approximation | ||
/// | ||
/// | ||
/// Examples: | ||
/// ```rust | ||
/// # use dfdx::prelude::*; | ||
/// # let dev: Cpu = Default::default(); | ||
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]); | ||
/// let r = t.accurate_gelu(); | ||
/// ``` | ||
pub fn accurate_gelu<S: Shape, E: Dtype, D: UnaryKernel<AccurateGeLUKernelOp, E>, T: Tape<E, D>>( | ||
coreylowman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
t: Tensor<S, E, D, T>, | ||
) -> Tensor<S, E, D, T> { | ||
t.accurate_gelu() | ||
} | ||
|
||
impl<S: Shape, E: Dtype, D: UnaryKernel<AccurateGeLUKernelOp, E>, T: Tape<E, D>> | ||
Tensor<S, E, D, T> | ||
{ | ||
/// See [accurate_gelu] | ||
pub fn accurate_gelu(self) -> Self { | ||
self.try_accurate_gelu().unwrap() | ||
} | ||
/// See [accurate_gelu] | ||
pub fn try_accurate_gelu(self) -> Result<Self, D::Err> { | ||
try_unary_op(AccurateGeLUKernelOp, self) | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use crate::{tensor::*, tensor_ops::*, tests::*};

    /// Checks forward values and gradients of the exact GeLU at x = -2..=2.
    #[test]
    fn test_accurate_gelu() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();
        let r = x.leaky_trace().accurate_gelu();

        // Forward reference: gelu(x) = x * Phi(x), Phi = standard normal cdf.
        assert_close_to_literal!(r, [-0.04550027, -0.15865525, 0.0, 0.84134471, 1.9544997]);

        // NOTE: call .exp() to make sure we cover cases where .accurate_gelu() uses the
        // result's gradient
        let g = r.exp().mean().backward();

        // Backward reference: d/dx mean(exp(gelu(x))) = exp(gelu(x)) * gelu'(x) / 5, with the
        // analytic derivative gelu'(x) = Phi(x) + x * phi(x) (phi = standard normal pdf).
        assert_close_to_literal!(
            g.get(&x),
            [-0.016288127, -0.01421846, 0.1, 0.50254658, 1.5324311]
        );
    }
}
2 changes: 1 addition & 1 deletion
2
src/tensor_ops/gelu/cpu_kernel.rs → src/tensor_ops/fast_gelu/cpu_kernel.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
use super::FastGeLUKernelOp; | ||
use crate::tensor_ops::cuda_kernels::cuda_unary; | ||
|
||
unsafe impl cudarc::driver::DeviceRepr for super::FastGeLUKernelOp {} | ||
|
||
const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/fast_gelu.ptx")); | ||
|
||
#[cfg(feature = "f16")] | ||
cuda_unary!( | ||
FastGeLUKernelOp, | ||
half::f16, | ||
PTX, | ||
"fast_gelu_fwd_f16", | ||
"fast_gelu_bwd_f16" | ||
); | ||
cuda_unary!( | ||
FastGeLUKernelOp, | ||
f32, | ||
PTX, | ||
"fast_gelu_fwd_f32", | ||
"fast_gelu_bwd_f32" | ||
); | ||
cuda_unary!( | ||
FastGeLUKernelOp, | ||
f64, | ||
PTX, | ||
"fast_gelu_fwd_f64", | ||
"fast_gelu_bwd_f64" | ||
); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs to link to its non-deprecated counterpart.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
link is 3 lines up