accurate-gelu #813
Conversation
Also, I'm not sure that the CUDA kernel actually supports erf, as I can't build CUDA locally. Here's a link to the CUDA 32-bit error function.
Finally, the naming is pretty terrible. Would it make sense to make the GeLU activation an enum?

```rust
enum Gelu {
    Fast,
    Accurate,
}
```

But then probably the …
Thank you for contributing! This looks good so far, I just have a few comments.
`cargo +nightly clippy -F cuda,ci-check`
Okay @nkoppel, I updated the name. Accurate GeLU is much better, and I think distinguishing that the other GeLU is faster will help. I couldn't find an explicit citation for the fact that it's faster, but the tanh approximation seems to only require a single exponential, while the error function requires a much higher-degree polynomial and still an exponential. I also beefed up the docs. I feel like the changes to the docs for the activations don't quite fit in the code, but I'm guessing most people will use them as activations rather than as postfix operations, so I added some info there even though it breaks the nice code block.
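To make that cost comparison concrete, here is a standalone sketch of the two scalar formulas being discussed; this is not dfdx's actual kernel code, and erf is filled in with the Abramowitz & Stegun 7.1.26 approximation purely for illustration. The point is that the tanh form needs a single transcendental call, while the erf form needs a degree-5 polynomial plus an exponential.

```rust
use std::f32::consts::{FRAC_1_SQRT_2, PI};

/// Fast GeLU: the tanh approximation; the only transcendental call is tanh.
fn fast_gelu(x: f32) -> f32 {
    0.5 * x * (1.0 + ((2.0 / PI).sqrt() * (x + 0.044715 * x * x * x)).tanh())
}

/// Accurate GeLU: 0.5 * x * (1 + erf(x / sqrt(2))).
fn accurate_gelu(x: f32) -> f32 {
    0.5 * x * (1.0 + erf(x * FRAC_1_SQRT_2))
}

/// std Rust has no erf, so this uses the Abramowitz & Stegun 7.1.26
/// approximation: a degree-5 polynomial in t = 1/(1 + p*x), multiplied by
/// exp(-x^2) — the extra work relative to the tanh form.
fn erf(x: f32) -> f32 {
    let sign = x.signum();
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = t
        * (0.254829592
            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

fn main() {
    for &x in &[-2.0f32, -0.5, 0.0, 0.5, 2.0] {
        println!("x = {x:5.2}  fast = {:8.5}  accurate = {:8.5}", fast_gelu(x), accurate_gelu(x));
    }
}
```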
src/tensor_ops/gelu/gelu.cu
Outdated
Are the changes in this file just indentation? Can you revert them if so? Just for easier reviewing 😀
Should be, and will do; my LSP automatically updated them.
Just a few little fixes to documentation and this should be good to go!
Note that I don't have write access to dfdx, I've just contributed a lot, so @coreylowman gets the final say on everything.
src/tensor_ops/fast_gelu/mod.rs
Outdated
```rust
/// See [gelu]
pub fn fast_gelu(self) -> Self {
    self.try_fast_gelu().unwrap()
}

/// See [gelu]
pub fn try_fast_gelu(self) -> Result<Self, D::Err> {
    try_unary_op(FastGeLUKernelOp, self)
}

#[deprecated(since = "0.12.0", note = "Use `fast_gelu` instead")]
pub fn gelu(self) -> Self {
    self.fast_gelu()
}

#[deprecated(since = "0.12.0", note = "Use `try_fast_gelu` instead")]
pub fn try_gelu(self) -> Result<Self, D::Err> {
    self.try_fast_gelu()
}
```
Top two methods should link to `fast_gelu`, and deprecated items should have a link to their non-deprecated counterparts.
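To make the request concrete, here is a tiny standalone sketch of the rustdoc intra-doc link pattern involved; the free functions are placeholders, not dfdx's actual methods, and the same `[fast_gelu]` form is what the `/// See [gelu]` comments above would be updated to.

```rust
// Placeholder module, only to show the intra-doc link + deprecation pattern.
pub mod gelu_docs_example {
    /// Fast (tanh-based) GeLU.
    pub fn fast_gelu(x: f32) -> f32 {
        0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x)).tanh())
    }

    /// See [fast_gelu], the non-deprecated replacement for this function.
    #[deprecated(since = "0.12.0", note = "Use `fast_gelu` instead")]
    pub fn gelu(x: f32) -> f32 {
        fast_gelu(x)
    }
}
```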
src/tensor_ops/fast_gelu/mod.rs
Outdated
```rust
#[derive(Debug, Default, Copy, Clone)]
pub struct FastGeLUKernelOp;

/// [Fast Gaussian Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
```
These docs should include a link to `AccurateGeLU`.
src/tensor_ops/accurate_gelu/mod.rs
Outdated
```rust
/// GeLU(x) ~ 0.5 ∗ x ∗ (1.0 + tanh(sqrt(2.0/π) ∗ (x + 0.044715 ∗ x^3)))
/// ```
///
/// See [gelu](crate::tensor_ops::gelu::gelu) to use this approximation
```
This link needs to be fixed with the new naming
```rust
#[deprecated(since = "0.12.0", note = "please use `FastGeLU` instead")]
#[derive(Default, Debug, Clone, Copy)]
pub struct GeLU;
```
Needs to link to its non-deprecated counterpart.
link is 3 lines up
@nkoppel alright, should be addressed.
This looks good to me - any other updates planned here?
I'm not planning on anything new. Just waiting on confirmation from @nkoppel that everything is addressed.
We can open another PR if there's something else to add/change!
This PR adds an accurate GeLU function (used at least in the GPT2 model from huggingface; see Issue #804). Importantly, in order to make the operators generic over `Dtype`, I introduce an `Erf` trait that allows us to call `d_type.erf()` to get the error function of the value `d_type`. I'm currently getting a compile error with feature `cuda`
, and I have a hunch that this trait might be the issue. I'm having trouble debugging the build further as I don't have the CUDA headers anywhere (working on a mac), so I would appreciate it if someone more familiar with the code could point out the error.

For Github: Resolves #804
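The description mentions the new `Erf` trait but doesn't quote it, so here is a minimal sketch of the general idea, generic over the element type. The trait name follows the description, but the bounds, the generic `accurate_gelu` helper, and the use of the `libm` crate for the concrete erf values are all illustrative assumptions rather than dfdx's actual code.

```rust
// Sketch of an erf abstraction over the element type; the real trait may
// differ in bounds and in where the implementations come from. The bodies
// lean on the `libm` crate (add it to Cargo.toml) so the example runs.
trait Erf {
    fn erf(self) -> Self;
}

impl Erf for f32 {
    fn erf(self) -> Self {
        libm::erff(self)
    }
}

impl Erf for f64 {
    fn erf(self) -> Self {
        libm::erf(self)
    }
}

/// Accurate GeLU written against the trait, so the same code works for any
/// dtype that knows its error function.
fn accurate_gelu<T>(x: T) -> T
where
    T: Erf + Copy + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + From<f32>,
{
    let half: T = T::from(0.5);
    let one: T = T::from(1.0);
    let inv_sqrt2: T = T::from(std::f32::consts::FRAC_1_SQRT_2);
    half * x * (one + (x * inv_sqrt2).erf())
}

fn main() {
    println!("{}", accurate_gelu(1.0f32));
    println!("{}", accurate_gelu(1.0f64));
}
```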