mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -20,10 +20,10 @@
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void decode_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -62,26 +62,25 @@ __global__ void decode_absorb_cache_kernel(
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * nope_hidden_size + inner_bias;
|
||||
const uint32_t tgt_idx =
|
||||
block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size + block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx = start_token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * pe_hidden_size + inner_bias;
|
||||
const uint32_t tgt_idx =
|
||||
block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size + block_offset * all_size + nope_size +
|
||||
h_bias;
|
||||
const uint32_t ori_idx = start_token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
@@ -91,10 +90,10 @@ __global__ void decode_absorb_cache_kernel(
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void speculate_decode_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -125,8 +124,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens[ori_bi] + token_id - start_token_idx;
|
||||
const int write_seq_id = seq_lens[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
@@ -145,26 +143,25 @@ __global__ void speculate_decode_absorb_cache_kernel(
|
||||
token_id,
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * nope_hidden_size + inner_bias;
|
||||
const uint32_t tgt_idx =
|
||||
block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size + block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx = token_id * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * pe_hidden_size + inner_bias;
|
||||
const uint32_t tgt_idx =
|
||||
block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size + block_offset * all_size + nope_size +
|
||||
h_bias;
|
||||
const uint32_t ori_idx = token_id * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
@@ -174,10 +171,10 @@ __global__ void speculate_decode_absorb_cache_kernel(
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void prefill_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -206,33 +203,33 @@ __global__ void prefill_absorb_cache_kernel(
|
||||
const uint32_t bias = linear_index % hidden_size;
|
||||
const uint32_t ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
const uint32_t ori_seq_id =
|
||||
(token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
|
||||
const uint32_t block_offset = ori_seq_id % block_size;
|
||||
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * nope_hidden_size + inner_bias;
|
||||
const uint32_t tgt_idx =
|
||||
block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size + block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * pe_hidden_size + inner_bias;
|
||||
const uint32_t tgt_idx =
|
||||
block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size + block_offset * all_size + nope_size +
|
||||
h_bias;
|
||||
const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user