Adding GeLU operator (used in Gpt2) #397

Merged · 5 commits · Jan 25, 2023
10 changes: 10 additions & 0 deletions src/nn/activations.rs
@@ -23,6 +23,7 @@ macro_rules! activation_impls {
}

activation_impls!(ReLU, relu, #[doc="Unit struct that impls [Module] as calling [relu()] on `input`."]);
activation_impls!(GeLU, gelu, #[doc="Unit struct that impls [Module] as calling [gelu()] on `input`."]);
activation_impls!(Sin, sin, #[doc="Unit struct that impls [Module] as calling [sin()] on `input`."]);
activation_impls!(Cos, cos, #[doc="Unit struct that impls [Module] as calling [cos()] on `input`."]);
activation_impls!(Ln, ln, #[doc="Unit struct that impls [Module] as calling [ln()] on `input`."]);
@@ -64,6 +65,15 @@ mod tests {
        assert_eq!(r1.array(), r2.array());
    }

    #[test]
    fn test_nn_activations_gelu() {
        let dev: TestDevice = Default::default();
        let t = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
        let r1 = GeLU.forward_mut(t.clone());
        let r2 = gelu(t);
        assert_eq!(r1.array(), r2.array());
    }

    #[test]
    fn test_nn_activations_sin() {
        let dev: TestDevice = Default::default();
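
For readers unfamiliar with `activation_impls!`: it declares a stateless unit struct and wires it into the module system so that forwarding a tensor through the layer just calls the free function. Roughly like the sketch below; this is illustrative only, not the literal macro expansion, and the exact `Module` trait signature is assumed here.

```rust
/// Unit struct that impls [Module] as calling [gelu()] on `input`.
#[derive(Default, Debug, Clone, Copy)]
pub struct GeLU;

// Illustrative expansion: the real impls are generated by `activation_impls!`.
impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<D>> Module<Tensor<S, E, D, T>> for GeLU {
    type Output = Tensor<S, E, D, T>;
    fn forward(&self, input: Tensor<S, E, D, T>) -> Self::Output {
        gelu(input)
    }
}
```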
1 change: 1 addition & 0 deletions src/tensor_ops/device.rs
@@ -58,6 +58,7 @@ pub trait Device<E: Dtype>:
    + UnaryKernel<super::nans_to::NansToKernelOp<E>, E>
    + UnaryKernel<super::negate::NegateKernelOp, E>
    + UnaryKernel<super::relu::ReLUKernelOp, E>
    + UnaryKernel<super::gelu::GeLUKernelOp, E>
    + UnaryKernel<super::sigmoid::SigmoidKernelOp, E>
    + UnaryKernel<super::sin::SinKernelOp, E>
    + UnaryKernel<super::sqrt::SqrtKernelOp, E>
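
With `GeLUKernelOp` added to the `Device` supertrait list, anything generic over a dfdx `Device` can use the new op without extra bounds. A minimal sketch; the helper `gelu_all` is hypothetical and assumes the usual `prelude` re-exports.

```rust
use dfdx::prelude::*;

// Hypothetical helper: compiles for any backend implementing `Device<E>`,
// because that trait now requires `UnaryKernel<GeLUKernelOp, E>`.
fn gelu_all<S: Shape, E: Dtype, D: Device<E>, T: Tape<D>>(
    t: Tensor<S, E, D, T>,
) -> Tensor<S, E, D, T> {
    t.gelu()
}
```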
33 changes: 33 additions & 0 deletions src/tensor_ops/gelu/cpu_kernel.rs
@@ -0,0 +1,33 @@
use crate::tensor_ops::cpu_kernels::UnaryDerivative;
use std::f32::consts::PI;

impl UnaryDerivative<f32> for super::GeLUKernelOp {
    #[inline(always)]
    fn f(&self, x: &f32) -> f32 {
        let alpha = x + 0.044715 * x.powf(3.0);
        0.5 * (*x) * (1.0 + f32::tanh((2.0f32 / PI).sqrt() * alpha))
    }

    #[inline(always)]
    fn df(&self, x: &f32) -> f32 {
        let sqrt2 = 2.0f32.sqrt();
        let sqrt_2_pi = 2.0 * (1.0 / PI).sqrt();
        let beta = sqrt2 * sqrt_2_pi * 0.5;
        let kappa = 0.044715;
        let x_sq = x * x;
        let x_cube = x_sq * x;
        let inner = beta * (x + kappa * x_cube);
        let tanh_inner = f32::tanh(inner);

        let left = 0.5 * x;
        let right = 1.0 + tanh_inner;

        let left_derivative = 0.5 * right;

        let tanh_derivative = 1.0 - tanh_inner * tanh_inner;
        let inner_derivative = beta * (1.0 + 3.0 * kappa * x_sq);
        let right_derivative = left * tanh_derivative * inner_derivative;

        left_derivative + right_derivative
    }
}
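
For reference, the derivative that `df` computes follows from the product rule applied to the tanh approximation, with $\beta = \sqrt{2/\pi}$ and $\kappa = 0.044715$:

```math
\begin{aligned}
f(x)  &= 0.5\,x\,\bigl(1 + \tanh(\beta(x + \kappa x^3))\bigr), \\
f'(x) &= \underbrace{0.5\,\bigl(1 + \tanh(\beta(x + \kappa x^3))\bigr)}_{\texttt{left\_derivative}}
       + \underbrace{0.5\,x\,\bigl(1 - \tanh^2(\beta(x + \kappa x^3))\bigr)\,\beta\,(1 + 3\kappa x^2)}_{\texttt{right\_derivative}}.
\end{aligned}
```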
10 changes: 10 additions & 0 deletions src/tensor_ops/gelu/cuda_kernel.rs
@@ -0,0 +1,10 @@
use crate::tensor_ops::cuda_kernels::UnaryOpCudaKernel;

unsafe impl cudarc::driver::AsKernelParam for super::GeLUKernelOp {}

impl UnaryOpCudaKernel for super::GeLUKernelOp {
    const PTX_SRC: &'static str = include_str!(concat!(env!("OUT_DIR"), "/gelu.ptx"));
    const MODULE_NAME: &'static str = "gelu";
    const FWD_FN_NAME: &'static str = "gelu_forward";
    const BWD_FN_NAME: &'static str = "gelu_backward";
}
54 changes: 54 additions & 0 deletions src/tensor_ops/gelu/gelu.cu
@@ -0,0 +1,54 @@
struct GeLUKernelOp {};

extern "C" __global__ void gelu_forward(
const GeLUKernelOp op,
const size_t numel,
const float *inp,
float *out
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= numel) {
return;
}

constexpr float fastCoeff = 0.044715;
float x = inp[i];
float x_sq = x * x;
float x_cube = x_sq * x;

float alpha = x + fastCoeff * x_cube;

float y = 0.5 * x * (1.0 + tanh(M_2_SQRTPI * M_SQRT1_2 * alpha));
out[i] = y;
}

extern "C" __global__ void gelu_backward(
const GeLUKernelOp op,
const size_t numel,
const float *inp,
float *grad_inp,
const float *grad_out
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= numel) {
return;
}
float kBeta = M_2_SQRTPI * M_SQRT2 * 0.5;
constexpr float fastCoeff = 0.044715;
float x = inp[i];
float x_sq = x * x;
float x_cube = x_sq * x;
float inner = kBeta * (x + fastCoeff * x_cube);
float tanh_inner = tanh(inner);

float left = 0.5 * x;
float right = 1.0 + tanh_inner;

float left_derivative = 0.5 * right;

float tanh_derivative = 1.0 - tanh_inner * tanh_inner;
float inner_derivative = kBeta * (1.0 + 3.0 * fastCoeff * x_sq);
float right_derivative = left * tanh_derivative * inner_derivative;
float dx = left_derivative + right_derivative;
grad_inp[i] += dx * grad_out[i];
}
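
A note on the CUDA math constants: both kernels reduce their constant products to the same $\sqrt{2/\pi}$ the CPU kernel calls `beta`, so forward and backward agree across backends:

```math
\texttt{M\_2\_SQRTPI} \cdot \texttt{M\_SQRT1\_2} = \frac{2}{\sqrt{\pi}} \cdot \frac{1}{\sqrt{2}} = \sqrt{\tfrac{2}{\pi}},
\qquad
\texttt{M\_2\_SQRTPI} \cdot \texttt{M\_SQRT2} \cdot 0.5 = \frac{2}{\sqrt{\pi}} \cdot \sqrt{2} \cdot \tfrac{1}{2} = \sqrt{\tfrac{2}{\pi}} \approx 0.7978846.
```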
64 changes: 64 additions & 0 deletions src/tensor_ops/gelu/mod.rs
@@ -0,0 +1,64 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use super::ops::{try_unary_op, UnaryKernel};
use crate::{gradients::Tape, shapes::*, tensor::Tensor};

#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct GeLUKernelOp;

/// [Gaussian Error Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let r = t.gelu();
/// ```
pub fn gelu<S: Shape, E: Dtype, D: UnaryKernel<GeLUKernelOp, E>, T: Tape<D>>(
    t: Tensor<S, E, D, T>,
) -> Tensor<S, E, D, T> {
    t.gelu()
}

impl<S: Shape, E: Dtype, D: UnaryKernel<GeLUKernelOp, E>, T: Tape<D>> Tensor<S, E, D, T> {
    /// See [gelu]
    pub fn gelu(self) -> Self {
        self.try_gelu().unwrap()
    }
    /// See [gelu]
    pub fn try_gelu(self) -> Result<Self, D::Err> {
        try_unary_op(GeLUKernelOp, self)
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        tensor::*,
        tensor_ops::*,
        tests::{assert_close, TestDevice},
    };

    #[test]
    fn test_gelu() {
        let dev: TestDevice = Default::default();
        let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
        let r = x.trace().gelu();
        assert_close(
            &r.array(),
            &[-0.04540229, -0.158808, 0.0, 0.841192, 1.9545977],
        );

        // NOTE: call .exp() to make sure we cover cases where .gelu() uses the result's gradient
        let g = r.exp().mean().backward();
        assert_close(
            &g.get(&x).array(),
            &[-0.016455507, -0.014156329, 0.1, 0.5023068, 1.5338063],
        );
    }
}
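
Putting the public surface of this PR together, a minimal usage sketch (assuming `dfdx::prelude` re-exports the new items the same way it does for the other activations):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);

    // Free function, infallible method, and fallible method all route to GeLUKernelOp.
    let a = gelu(x.clone());
    let b = x.clone().gelu();
    let c = x.try_gelu().unwrap();
    assert_eq!(a.array(), b.array());
    assert_eq!(a.array(), c.array());

    // Values match the expected outputs in test_gelu above (up to float rounding):
    // approximately [-0.0454, -0.1588, 0.0, 0.8412, 1.9546]
    println!("{:?}", a.array());
}
```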
2 changes: 2 additions & 0 deletions src/tensor_ops/mod.rs
@@ -149,6 +149,7 @@ mod cos;
mod div;
mod dropout;
mod exp;
mod gelu;
mod huber_error;
mod ln;
mod log_softmax;
@@ -194,6 +195,7 @@ pub use cos::cos;
pub use div::{div, TryDiv};
pub use dropout::dropout;
pub use exp::exp;
pub use gelu::gelu;
pub use huber_error::huber_error;
pub use ln::ln;
pub use log_softmax::log_softmax;