Adding attention_reshape (inference only) kernels. #497

Merged 3 commits on Feb 28, 2023
63 changes: 63 additions & 0 deletions src/tensor_ops/attention_reshape/attention_reshape.cu
@@ -0,0 +1,63 @@
#include "cuda_utils.cuh"
extern "C" __global__ void attention_reshape_f32(
const size_t numel,
const size_t num_heads,
const size_t head_dim,
const size_t sequence_length,
const size_t past_length,
const float *qkv,
const float *past_key,
const float *past_value,
float *query,
float *key,
float *value
) {
size_t n = blockIdx.x * blockDim.x + threadIdx.x;
if (n >= numel) {
return;
}
const size_t hidden_dim = num_heads * head_dim;
const size_t total_length = sequence_length + past_length;
const size_t q_length = hidden_dim * sequence_length;
const size_t k_length = hidden_dim * total_length;
if (n < q_length){
const size_t k = n % head_dim;
const size_t j = (n / head_dim) % sequence_length;
const size_t i = n / head_dim / sequence_length;
const size_t qkv_index = j * hidden_dim * 3 + i * head_dim + k;
const size_t q_index = n;
query[q_index] = qkv[qkv_index];
} else if (n < q_length + k_length){
const size_t k_index = n - q_length;
size_t k = k_index % total_length;
const size_t j = (k_index / total_length) % head_dim;
const size_t i = k_index / head_dim / total_length;

if (k < past_length){
const size_t past_key_index = i * past_length * head_dim + j * past_length + k;
key[k_index] = past_key[past_key_index];
// key[k_index] = 0;
}else{
k -= past_length;
const size_t qkv_index = k * hidden_dim * 3 + i * head_dim + j + hidden_dim;
key[k_index] = qkv[qkv_index];
// key[k_index] = 0;
}
} else{
const size_t v_index = n - k_length - q_length;
const size_t k = v_index % head_dim;
size_t j = (v_index / head_dim) % total_length;
const size_t i = v_index / head_dim / total_length;

if (j < past_length){
const size_t past_value_index = i * past_length * head_dim + j * head_dim + k;
value[v_index] = past_value[past_value_index];
// value[v_index] = 0;
}else{
j -= past_length;
const size_t qkv_index = j * hidden_dim * 3 + i * head_dim + k + 2 * hidden_dim;
value[v_index] = qkv[qkv_index];
// value[v_index] = 0;
}
}
}
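
The kernel above maps one flat thread index per output element: indices in [0, q_length) fill `query`, indices in [q_length, q_length + k_length) fill `key`, and the remainder fills `value`. Within each range the flat index is decomposed into (head, dim, position) coordinates for that output's layout, then mapped back to either the cached tensor or the matching slice of the fused `qkv` row. Below is a minimal host-side Rust sketch of the key-branch index math only; the helper names are hypothetical and not part of the diff above.

// Hypothetical helpers mirroring the key branch of attention_reshape_f32.
// Key layout is (num_heads, head_dim, total_length).
fn key_coords(k_index: usize, head_dim: usize, total_length: usize) -> (usize, usize, usize) {
    let pos = k_index % total_length;              // position along total_length
    let dim = (k_index / total_length) % head_dim; // index within head_dim
    let head = k_index / head_dim / total_length;  // attention head
    (head, dim, pos)
}

// Source index in the fused qkv buffer for a new key element (pos >= past_length);
// the key block starts at offset hidden_dim within each qkv row.
fn key_qkv_index(head: usize, dim: usize, new_pos: usize, head_dim: usize, hidden_dim: usize) -> usize {
    new_pos * hidden_dim * 3 + head * head_dim + dim + hidden_dim
}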
93 changes: 93 additions & 0 deletions src/tensor_ops/attention_reshape/cpu_kernel.rs
@@ -0,0 +1,93 @@
use super::*;
use crate::{gradients::NoneTape, tensor::cpu::Cpu};
use std::vec;

impl<E: Dtype> super::AttentionReshapeKernel<E> for Cpu {
fn forward<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
) -> Result<
(
Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
),
Self::Err,
> {
let sequence_length = qkv.shape().0;
let past_sequence_length = past_key.shape().2;
let total_length = sequence_length.size() + past_sequence_length.size();
let dev = qkv.device.clone();

let mut q: Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self, NoneTape> =
dev.zeros_like(&(Const, sequence_length, Const));
let mut k: Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self, NoneTape> =
dev.zeros_like(&(Const, Const, total_length));
let mut v: Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self, NoneTape> =
dev.zeros_like(&(Const, total_length, Const));
let mut q_vec = vec![E::default(); q.shape().num_elements()];
let mut k_vec = vec![E::default(); k.shape().num_elements()];
let mut v_vec = vec![E::default(); v.shape().num_elements()];
let mut past_key_vec = vec![E::default(); past_key.shape().num_elements()];
let mut past_value_vec = vec![E::default(); past_value.shape().num_elements()];
let mut qkv_vec = vec![E::default(); qkv.shape().num_elements()];
past_key.copy_into(past_key_vec.as_mut_slice());
past_value.copy_into(&mut past_value_vec);
qkv.copy_into(&mut qkv_vec);

let head_dim = HEAD_DIM;
let hidden_dim = THREE_HIDDEN_DIM / 3;
let num_heads = NUM_HEADS;
(0..num_heads).for_each(|i| {
(0..sequence_length.size()).for_each(|j| {
(0..head_dim).for_each(|k| {
let index = j * hidden_dim * 3 + i * head_dim + k;
let out_index = i * sequence_length.size() * head_dim + j * head_dim + k;
let value = qkv_vec[index];
q_vec[out_index] = value;
});
});
});
(0..num_heads).for_each(|i| {
(0..past_sequence_length.size() + sequence_length.size()).for_each(|j| {
(0..head_dim).for_each(|k| {
let in_index_k =
i * (past_sequence_length.size() + sequence_length.size()) * head_dim
+ k * (past_sequence_length.size() + sequence_length.size())
+ j;

let in_index_v =
i * (past_sequence_length.size() + sequence_length.size()) * head_dim
+ j * head_dim
+ k;
if j < past_sequence_length.size() {
let k_index = i * past_sequence_length.size() * head_dim
+ k * past_sequence_length.size()
+ j;
let k_value = past_key_vec[k_index];
k_vec[in_index_k] = k_value;

let v_index = i * past_sequence_length.size() * head_dim + j * head_dim + k;
let v_value = past_value_vec[v_index];
v_vec[in_index_v] = v_value;
} else {
let sj = j - past_sequence_length.size();
let k_index = sj * hidden_dim * 3 + i * head_dim + hidden_dim + k;
let k_value = qkv_vec[k_index];
k_vec[in_index_k] = k_value;

let v_index = sj * hidden_dim * 3 + i * head_dim + hidden_dim * 2 + k;
let v_value = qkv_vec[v_index];
v_vec[in_index_v] = v_value;
}
});
});
});
q.copy_from(&q_vec);
k.copy_from(&k_vec);
v.copy_from(&v_vec);
Ok((q, k, v))
}
}
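
The CPU path makes the fused-row layout explicit: each row of `qkv` is [query | key | value], each block `hidden_dim` wide (with `hidden_dim = NUM_HEADS * HEAD_DIM`), and elements within a block are laid out head by head. A minimal sketch of that layout in plain Rust, independent of the code above; the function name is illustrative only.

// Splits one fused qkv row of width 3 * hidden_dim into its q, k, and v blocks.
fn split_qkv_row(row: &[f32], hidden_dim: usize) -> (&[f32], &[f32], &[f32]) {
    assert_eq!(row.len(), 3 * hidden_dim);
    let (q, rest) = row.split_at(hidden_dim);
    let (k, v) = rest.split_at(hidden_dim);
    (q, k, v)
}

// Element (head i, dim k) of the key block then lives at k_block[i * head_dim + k],
// which is what the `sj * hidden_dim * 3 + i * head_dim + hidden_dim + k` index above computes.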
82 changes: 82 additions & 0 deletions src/tensor_ops/attention_reshape/cuda_kernel.rs
@@ -0,0 +1,82 @@
use super::*;
use crate::tensor::cuda::{Cuda, CudaArray};
use cudarc::driver::{LaunchAsync, LaunchConfig};
use std::sync::Arc;

const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/attention_reshape.ptx"));

impl super::AttentionReshapeKernel<f32> for Cuda {
fn forward<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), f32, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), f32, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), f32, Self>,
) -> Result<
(
Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), f32, Self>,
Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), f32, Self>,
Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), f32, Self>,
),
Self::Err,
> {
let mod_ = "attention_reshape_f32";
let fns = "attention_reshape_f32";
if !self.dev.has_func(mod_, fns) {
self.dev.load_ptx(PTX.into(), mod_, &[fns])?;
}
let f = self.dev.get_func(mod_, fns).unwrap();
let seq = qkv.shape().0;
let sequence_length = seq.size();
let past_length = past_key.shape().2;
let total_length = sequence_length + past_length;
let head_dim = HEAD_DIM;
let num_heads = NUM_HEADS;

let q_shape = (Const, seq, Const);
let mut q_storage = self.dev.alloc_zeros_async::<f32>(q_shape.num_elements())?;

let k_shape = (Const, Const, total_length);
let mut k_storage = self.dev.alloc_zeros_async::<f32>(k_shape.num_elements())?;

let v_shape = (Const, total_length, Const);
let mut v_storage = self.dev.alloc_zeros_async::<f32>(v_shape.num_elements())?;

let numel = q_shape.num_elements() + k_shape.num_elements() + v_shape.num_elements();
let cfg = LaunchConfig::for_num_elems(numel as u32);
let params = (
numel, // const size_t numel,
num_heads, // const size_t num_heads,
head_dim, // const size_t head_dim,
sequence_length, // const size_t sequence_length,
past_length, // const size_t past_length,
qkv.storage.data.as_ref(), // const float *qkv,
past_key.storage.data.as_ref(), // const float *past_key,
past_value.storage.data.as_ref(), // const float *past_value,
&mut q_storage, // float *q,
&mut k_storage, // float *k,
&mut v_storage, // float *v
);

unsafe { f.launch_async(cfg, params) }?;
let device = qkv.device.clone();
let q: Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), f32, Self> =
device.upgrade(CudaArray {
data: Arc::new(q_storage),
shape: q_shape,
strides: q_shape.strides(),
});
let k: Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), f32, Self> =
device.upgrade(CudaArray {
data: Arc::new(k_storage),
shape: k_shape,
strides: k_shape.strides(),
});
let v: Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), f32, Self> =
device.upgrade(CudaArray {
data: Arc::new(v_storage),
shape: v_shape,
strides: v_shape.strides(),
});
Ok((q, k, v))
}
}
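
The wrapper loads the compiled PTX module on first use, allocates zeroed output buffers, and launches a single 1-D grid that covers all query, key, and value elements at once. The sizing below is only a conceptual sketch of what `LaunchConfig::for_num_elems` does, under an assumed block size (cudarc may choose a different one).

// Conceptual 1-D launch sizing for `numel` elements, assuming 256 threads per block.
fn grid_for(numel: u32) -> (u32, u32) {
    let block_dim = 256;
    let grid_dim = (numel + block_dim - 1) / block_dim; // ceiling division
    (grid_dim, block_dim)
}

// With NUM_HEADS = 2, HEAD_DIM = 3, sequence_length = 1, past_length = 3 (as in the
// test below): numel = 6 (q) + 24 (k) + 24 (v) = 54, so one block covers everything.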
136 changes: 136 additions & 0 deletions src/tensor_ops/attention_reshape/mod.rs
@@ -0,0 +1,136 @@
use crate::{shapes::*, tensor::*};

mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;

pub type Query<const NUM_HEADS: usize, const HEAD_DIM: usize, E, D> =
Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, D>;
pub type Key<const NUM_HEADS: usize, const HEAD_DIM: usize, E, D> =
Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, D>;
pub type Value<const NUM_HEADS: usize, const HEAD_DIM: usize, E, D> =
Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, D>;

type QkvTuple<const NUM_HEADS: usize, const HEAD_DIM: usize, E, D> = (
Query<NUM_HEADS, HEAD_DIM, E, D>,
Key<NUM_HEADS, HEAD_DIM, E, D>,
Value<NUM_HEADS, HEAD_DIM, E, D>,
);

/// Reshapes `qkv` + `past_key` + `past_value` into the `(q, k, v)` tensors used
/// in an attention layer
pub trait TryAttentionReshape<E: Dtype>: DeviceStorage {
/// This is an inference-only kernel:
/// Within the `transformers` architecture, a core component is the `attention`
/// layer, which can be written in many forms.
///
/// This particular version expects a `qkv` tensor (obtained from a single
/// `Linear` layer, corresponding to the stacked `query`, `key`, and `value`),
/// along with `past_key` and `past_value`, which are the values cached within
/// attention (this speeds up inference).
/// For the first pass, when no cache exists yet, just send zero-width tensors.
///
/// Having a single kernel instead of many `cat`, `reshape`, and `permute` calls
/// makes this operation very efficient on GPU.
fn attention_reshape<
const THREE_HIDDEN_DIM: usize,
const NUM_HEADS: usize,
const HEAD_DIM: usize,
>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
) -> QkvTuple<NUM_HEADS, HEAD_DIM, E, Self> {
self.try_attention_reshape(qkv, past_key, past_value)
.unwrap()
}

/// Fallible version of [TryAttentionReshape::attention_reshape]
fn try_attention_reshape<
const THREE_HIDDEN_DIM: usize,
const NUM_HEADS: usize,
const HEAD_DIM: usize,
>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
) -> Result<QkvTuple<NUM_HEADS, HEAD_DIM, E, Self>, Self::Err>;
}

pub trait AttentionReshapeKernel<E: Dtype>: DeviceStorage {
fn forward<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
) -> Result<QkvTuple<NUM_HEADS, HEAD_DIM, E, Self>, Self::Err>;
}

impl<E: Dtype, D: AttentionReshapeKernel<E>> TryAttentionReshape<E> for D {
/// Fallible version of [TryAttentionReshape::attention_reshape]
fn try_attention_reshape<
const THREE_HIDDEN_DIM: usize,
const NUM_HEADS: usize,
const HEAD_DIM: usize,
>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>,
) -> Result<QkvTuple<NUM_HEADS, HEAD_DIM, E, Self>, Self::Err> {
let device = qkv.device.clone();
device.forward(qkv, past_key, past_value)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::tests::*;

#[test]
fn test_attention_reshape() {
let dev: TestDevice = Default::default();

const NUM_HEADS: usize = 2;
const HEAD_DIM: usize = 3;
let sequence_length = 1;
let past_length = 3;

{
let qkv: Tensor<(usize, Const<{ NUM_HEADS * HEAD_DIM * 3 }>), TestDtype, _> =
dev.zeros_like(&(sequence_length, Const)) + 1.0;
let past_key: Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), TestDtype, _> =
dev.zeros_like(&(Const, Const, past_length)) + 2.0;
let past_value: Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), TestDtype, _> =
dev.zeros_like(&(Const, past_length, Const)) + 3.0;

let (q, k, v) = dev.attention_reshape(&qkv, &past_key, &past_value);

assert_eq!(q.as_vec(), std::vec![1.0; 6]);
#[rustfmt::skip]
assert_eq!(
k.as_vec(),
std::vec![
2.0, 2.0, 2.0, 1.0,
2.0, 2.0, 2.0, 1.0,
2.0, 2.0, 2.0, 1.0,
2.0, 2.0, 2.0, 1.0,
2.0, 2.0, 2.0, 1.0,
2.0, 2.0, 2.0, 1.0
]
);
#[rustfmt::skip]
assert_eq!(
v.as_vec(),
std::vec![
3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 1.0, 1.0, 1.0,
3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 1.0, 1.0, 1.0
]
);
}
}
}
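
A usage sketch for the first decoding step, adapted from the test above rather than taken from the crate's documentation: start with zero-width caches, call `attention_reshape`, and in a real model feed the returned `k` and `v` back in as the cache for the next step. This assumes the `Cpu` device, `f32`, and that `TryAttentionReshape` is in scope.

const NUM_HEADS: usize = 2;
const HEAD_DIM: usize = 3;

let dev: Cpu = Default::default();
let sequence_length = 1;

// Fused projection output for the current tokens: (seq, 3 * hidden_dim) = (1, 18).
let qkv: Tensor<(usize, Const<18>), f32, _> = dev.zeros_like(&(sequence_length, Const));

// First pass: no cache yet, so the past tensors have zero width.
let past_key: Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), f32, _> =
    dev.zeros_like(&(Const, Const, 0));
let past_value: Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), f32, _> =
    dev.zeros_like(&(Const, 0, Const));

let (q, k, v) = dev.attention_reshape(&qkv, &past_key, &past_value);
// Shapes: q = (2, 1, 3), k = (2, 3, 1), v = (2, 1, 3); k and v become the next step's cache.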
2 changes: 2 additions & 0 deletions src/tensor_ops/mod.rs
@@ -144,6 +144,7 @@ pub use utilities::*;

mod abs;
mod add;
mod attention_reshape;
mod bce;
mod boolean;
mod broadcast_to;
@@ -188,6 +189,7 @@ mod var_to;

pub use abs::abs;
pub use add::{add, TryAdd};
pub use attention_reshape::TryAttentionReshape;
pub use bce::bce_with_logits;
pub use boolean::{bool_and, bool_not, bool_or, bool_xor};
pub use broadcast_to::BroadcastTo;