Merged
8 changes: 5 additions & 3 deletions src/tensor_ops/gelu/gelu.cu
@@ -1,4 +1,6 @@
#include "unary_op_macros.cuh"
+#define _USE_MATH_DEFINES
coreylowman (Owner), Jan 27, 2023


What does this do exactly? Also, have you tried compiling all this on a non-GNU and MSVC device? (I can verify on my dev machine later.)

Contributor (Author)

From the MSVC documentation:

The math constants aren't defined in Standard C/C++. To use them, you must first define _USE_MATH_DEFINES, and then include <cmath> or <math.h>.

The #define essentially sets a flag that makes math.h define the math constants (M_2_SQRTPI, M_SQRT2, M_SQRT1_2, and so on). The relevant documentation explains how this works:

A #define without a token-string removes occurrences of identifier from the source file. The identifier remains defined and can be tested by using the #if defined and #ifdef directives.

(Emphasis mine)
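Below is a minimal host-side C++ sketch of the pattern described above. It also compiles as CUDA host code, but it is not the gelu.cu source. Some parts are assumptions added for illustration. The `#ifndef` fallbacks are hypothetical and use the values glibc uses, for toolchains whose math.h still omits these non-standard constants. The helpers `gelu_scale_forward`, `gelu_scale_backward` and `check_gelu_scales` are also hypothetical. They show that the two spellings of the constant used in the forward and backward kernels both reduce to sqrt(2/pi).

```cpp
// Per the MSVC docs quoted above: define the flag first, then include math.h.
#define _USE_MATH_DEFINES
#include <math.h>
#include <cassert>

// Hypothetical fallbacks (same values as glibc) for math.h variants that
// do not provide these non-standard constants even with the flag set.
#ifndef M_2_SQRTPI
#define M_2_SQRTPI 1.12837916709551257390 /* 2 / sqrt(pi) */
#endif
#ifndef M_SQRT2
#define M_SQRT2 1.41421356237309504880 /* sqrt(2) */
#endif
#ifndef M_SQRT1_2
#define M_SQRT1_2 0.70710678118654752440 /* 1 / sqrt(2) */
#endif

// Forward kernel spelling: (2 / sqrt(pi)) * (1 / sqrt(2)).
double gelu_scale_forward() { return M_2_SQRTPI * M_SQRT1_2; }

// Backward kernel spelling: (2 / sqrt(pi)) * sqrt(2) * 0.5.
double gelu_scale_backward() { return M_2_SQRTPI * M_SQRT2 * 0.5; }

// Both spellings should match sqrt(2 / pi) computed directly (pi = acos(-1)).
void check_gelu_scales() {
  const double expected = sqrt(2.0 / acos(-1.0));
  assert(fabs(gelu_scale_forward() - expected) < 1e-12);
  assert(fabs(gelu_scale_backward() - expected) < 1e-12);
}
```

If some earlier header has already pulled in math.h, the flag may arrive too late to take effect. That is why the define needs to come before the first math.h include in the translation unit.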


I have not tested with toolchains other than stable-x86_64-pc-windows-gnu and stable-x86_64-pc-windows-msvc.

+#include <math.h>

struct GeLUKernelOp {};

@@ -10,16 +12,16 @@ LONG_UNARY_OP(gelu_forward, gelu_backward, GeLUKernelOp,

float alpha = x + fastCoeff * x_cube;

-float y = 0.5 * x * (1.0 + tanh(M_2_SQRTPI * M_SQRT1_2 * alpha));
+float y = 0.5 * x * (1.0 + tanhf(M_2_SQRTPI * M_SQRT1_2 * alpha));
out[i] = y;
},
{
-float kBeta = M_2_SQRTPI * M_SQRT2 * 0.5;
+constexpr float kBeta = M_2_SQRTPI * M_SQRT2 * 0.5;
constexpr float fastCoeff = 0.044715;
float x_sq = x * x;
float x_cube = x_sq * x;
float inner = kBeta * (x + fastCoeff * x_cube);
-float tanh_inner = tanh(inner);
+float tanh_inner = tanhf(inner);

float left = 0.5 * x;
float right = 1.0 + tanh_inner;
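For reference, these kernels evaluate the tanh approximation of GeLU, and the backward pass needs its derivative. Both can be written as follows. Here β and c correspond to kBeta and fastCoeff. The rest of the backward body is collapsed in this view, so the last line is the standard derivative rather than a transcription of the hidden lines:

$$
y = \tfrac{1}{2}\,x\,\bigl(1 + \tanh u\bigr), \qquad u = \beta\,\bigl(x + c\,x^{3}\bigr), \qquad c = 0.044715
$$

$$
\beta = \frac{2}{\sqrt{\pi}}\cdot\frac{1}{\sqrt{2}} = \frac{2}{\sqrt{\pi}}\cdot\frac{\sqrt{2}}{2} = \sqrt{\frac{2}{\pi}}
$$

$$
\frac{dy}{dx} = \underbrace{\tfrac{1}{2}\bigl(1 + \tanh u\bigr)}_{\text{right}/2} + \underbrace{\tfrac{1}{2}\,x}_{\text{left}}\,\bigl(1 - \tanh^{2} u\bigr)\,\beta\,\bigl(1 + 3c\,x^{2}\bigr)
$$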
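To check the numerics of this change on the host, here is a C++ reference sketch of the same tanh-approximated GeLU. It rests on a few assumptions:

- The names `gelu_ref_forward`, `gelu_ref_backward` and `check_gelu_grad` are hypothetical.
- sqrt(2/pi) is written as a float literal instead of being built from the `M_*` macros.
- The backward formula is the analytic derivative of the forward pass, not a copy of the collapsed kernel lines.

`tanhf` and the `f`-suffixed literals keep every step in single precision. `constexpr` makes the constants compile-time values, as the `kBeta` change does.

```cpp
#include <math.h>
#include <cassert>

// Assumed constants: sqrt(2 / pi) and the cubic coefficient from gelu.cu.
constexpr float kBeta = 0.7978845608028654f;
constexpr float kFastCoeff = 0.044715f;

// y = 0.5 * x * (1 + tanh(kBeta * (x + c * x^3))), all in float.
float gelu_ref_forward(float x) {
  float inner = kBeta * (x + kFastCoeff * x * x * x);
  // tanhf names the single-precision function explicitly, so the result
  // does not depend on whether a float overload of tanh is visible.
  return 0.5f * x * (1.0f + tanhf(inner));
}

// dy/dx = 0.5 * (1 + tanh(inner)) + 0.5 * x * (1 - tanh^2) * d(inner)/dx
float gelu_ref_backward(float x) {
  float x_sq = x * x;
  float inner = kBeta * (x + kFastCoeff * x_sq * x);
  float tanh_inner = tanhf(inner);
  float left = 0.5f * x;
  float right = 1.0f + tanh_inner;
  float d_inner = kBeta * (1.0f + 3.0f * kFastCoeff * x_sq);
  return 0.5f * right + left * (1.0f - tanh_inner * tanh_inner) * d_inner;
}

// Central-difference check of the analytic gradient; the step and the
// tolerance are loose assumptions chosen for float arithmetic.
void check_gelu_grad(float x) {
  const float h = 1e-3f;
  float numeric =
      (gelu_ref_forward(x + h) - gelu_ref_forward(x - h)) / (2.0f * h);
  assert(fabsf(numeric - gelu_ref_backward(x)) < 1e-2f);
}
```

At x = 0 the forward pass returns exactly 0 and the gradient returns exactly 0.5. Those two values, plus a few `check_gelu_grad` calls at other points, make a quick sanity test.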