[Speculative Decoding] Refactor Eagle MTP hidden states copy (#6812)

* reformat eagle_get_hidden_states & eagle_get_self_hidden_states

* readability

* fix xpu bug

* fix coverage failure

* change launch params & parallelize position_map compute

* Fix MTP-related bugs in FastDeploy centralized inference

* fix

* refactor mtp hidden_states process

* fix

* add unittest & optimize kernel

* remove useless code

* fix
This commit is contained in:
huicongyao
2026-03-26 13:54:31 +08:00
committed by GitHub
parent 4fd877ed43
commit 25d64efdc4
12 changed files with 1309 additions and 383 deletions
+15 -1
View File
@@ -1016,7 +1016,17 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& step_idx);
const paddle::Tensor& seq_lens_encoder);
std::vector<paddle::Tensor> EagleGatherHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& real_output_token_num);
void MTPStepPaddle(
const paddle::Tensor& base_model_stop_flags,
@@ -1820,6 +1830,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&EagleGetSelfHiddenStates,
"eagle_get_self_hidden_states function");
m.def("eagle_gather_hidden_states",
&EagleGatherHiddenStates,
"eagle_gather_hidden_states function");
m.def("mtp_step_paddle", &MTPStepPaddle, "mtp_step_paddle function");
m.def("speculate_step_paddle",
@@ -42,16 +42,21 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
int64_t stop_flag_now_int = 0;
int tid = threadIdx.x;
if (tid < bsz) {
auto* draft_token_now = draft_tokens + tid * max_draft_token;
auto* pre_ids_now = pre_ids + tid * pre_id_length;
auto* base_model_draft_tokens_now =
if (tid < bsz && seq_lens_this_time[tid] > 0) {
int seq_len_this_time = seq_lens_this_time[tid];
int seq_len_encoder = seq_lens_encoder[tid];
int seq_len_decoder = seq_lens_decoder[tid];
int next_tokens_start_id = 0;
for (int i = 0; i < tid; i++) {
next_tokens_start_id += seq_lens_this_time[i] > 0 ? 1 : 0;
}
int64_t* draft_token_now = draft_tokens + tid * max_draft_token;
int64_t* pre_ids_now = pre_ids + tid * pre_id_length;
int64_t* base_model_draft_tokens_now =
base_model_draft_tokens + tid * max_base_model_draft_token;
const int next_tokens_start_id = cu_seqlens_q_output[tid];
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
auto seq_len_this_time = seq_lens_this_time[tid];
auto seq_len_encoder = seq_lens_encoder[tid];
auto seq_len_decoder = seq_lens_decoder[tid];
const int64_t* next_tokens_start = inter_next_tokens + next_tokens_start_id;
// 1. update step_idx && seq_lens_dec
if (!stop_flags[tid]) {
@@ -72,8 +77,8 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
base_model_draft_tokens_now[substep + 1] = -1;
} else {
seq_lens_decoder[tid] += seq_len_this_time;
token_this_time = next_tokens_start[seq_len_this_time - 1];
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
token_this_time = next_tokens_start[0];
draft_token_now[0] = token_this_time;
base_model_draft_tokens_now[substep + 1] = token_this_time;
step_idx[tid] += seq_len_this_time;
pre_ids_now[step_idx[tid]] = token_this_time;
@@ -0,0 +1,273 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cooperative_groups.h>
#include "paddle/extension.h"
#include "helper.h"
namespace cg = cooperative_groups;
// Fused kernel: block 0 computes position_map and output_token_num in parallel
// (one thread per batch element), then all blocks synchronize via
// cooperative_groups grid sync, and finally all threads perform the hidden
// states gather in parallel.
//
// Launch contract: must be launched with cudaLaunchCooperativeKernel
// (grid.sync() below is only valid for cooperative launches) and with at
// least 4 * real_bsz * sizeof(int) bytes of dynamic shared memory.
//
// Gather semantics: for every batch element t with seq_lens_this_time[t] > 0,
// the hidden state of the LAST input token of that sequence is copied into
// one compacted output row; every other input token is dropped (its
// position_map entry stays -1 from the host-side memset).
//
// NOTE(review): cu_seqlens_q, seq_lens_decoder, seq_lens_encoder,
// batch_id_per_token_output and cu_seqlens_q_output are not read anywhere in
// this kernel body; presumably kept so the host-side kernel_args layout stays
// stable — confirm before removing.
template <typename T, int VecSize>
__global__ void EagleGatherHiddenStatesKernel(
    T* output_data,        // [>= output_token_num, dim_embed] compacted rows
    int* position_map,     // [input_token_num]; -1 = token not gathered
    int* output_token_num, // [1]; number of output rows actually written
    const T* input,        // [input_token_num, dim_embed]
    const int* cu_seqlens_q,               // unused (see NOTE above)
    const int* seq_lens_this_time,         // [real_bsz]
    const int* seq_lens_decoder,           // unused (see NOTE above)
    const int* seq_lens_encoder,           // unused (see NOTE above)
    const int* batch_id_per_token_output,  // unused (see NOTE above)
    const int* cu_seqlens_q_output,        // unused (see NOTE above)
    const int dim_embed,
    const int64_t input_token_num,
    const int real_bsz) {
  cg::grid_group grid = cg::this_grid();
  // Dynamic shared memory layout: [in_count|out_count|in_offsets|out_offsets]
  extern __shared__ int smem[];
  int* in_count = smem;
  int* out_count = smem + real_bsz;
  int* in_offsets = smem + 2 * real_bsz;
  int* out_offsets = smem + 3 * real_bsz;
  // Phase 1: compute position_map (parallelized across threads in block 0)
  if (blockIdx.x == 0) {
    // Phase 1a: each thread computes counts for its batch elements
    for (int t = threadIdx.x; t < real_bsz; t += blockDim.x) {
      int cur_seq_len = seq_lens_this_time[t];
      // has seq in current batch
      if (cur_seq_len > 0) {
        in_count[t] = cur_seq_len;  // consumes cur_seq_len input tokens...
        out_count[t] = 1;           // ...but emits exactly one output row
      } else {
        in_count[t] = 0;
        out_count[t] = 0;
      }
    }
    __syncthreads();
    // Phase 1b: prefix sum (thread 0 computes exclusive prefix sums)
    if (threadIdx.x == 0) {
      int in_acc = 0, out_acc = 0;
      for (int i = 0; i < real_bsz; i++) {
        in_offsets[i] = in_acc;
        out_offsets[i] = out_acc;
        in_acc += in_count[i];
        out_acc += out_count[i];
      }
      output_token_num[0] = out_acc;
    }
    __syncthreads();
    // Phase 1c: each thread fills position_map for its batch elements
    for (int t = threadIdx.x; t < real_bsz; t += blockDim.x) {
      int in_off = in_offsets[t];
      int out_off = out_offsets[t];
      if (seq_lens_this_time[t] > 0) {
        // For gather: map input token to output position
        // Use last token of each sequence
        int last_token_idx = in_off + in_count[t] - 1;
        position_map[last_token_idx] = out_off;
      }
    }
  }
  // Phase 2: grid-wide sync to ensure position_map is ready
  grid.sync();
  // Phase 3: gather hidden states in parallel (grid-stride, vectorized I/O)
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;
  // NOTE(review): input_token_num * dim_embed is narrowed into an int here;
  // this overflows once the element count exceeds 2^31-1 — confirm upstream
  // limits guarantee smaller inputs.
  int elem_cnt = input_token_num * dim_embed;
  int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt;
       elem_idx += blockDim.x * gridDim.x * VecSize) {
    int ori_token_idx = elem_idx / dim_embed;
    int token_idx = position_map[ori_token_idx];
    if (token_idx >= 0) {
      int offset = elem_idx % dim_embed;
      Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
      Store<T, VecSize>(src_vec, output_data + token_idx * dim_embed + offset);
    }
  }
}
// Dtype-specialized implementation of eagle_gather_hidden_states.
//
// Builds position_map on device (inside the fused kernel, block 0), then
// gathers the last-token hidden state of every active sequence into a
// compacted [real_bsz, dim_embed] output. Returns {out, output_token_num};
// only the first output_token_num rows of `out` are meaningful (the rest
// stay zero from paddle::zeros).
template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
    const paddle::Tensor& input,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token_output,
    const paddle::Tensor& cu_seqlens_q_output,
    const paddle::Tensor& real_output_token_num) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  auto input_token_num = input.shape()[0];
  auto dim_embed = input.shape()[1];
  int real_bsz = seq_lens_this_time.shape()[0];
  // position_map[i] == -1 (0xFFFFFFFF byte pattern) marks token i as dropped.
  auto position_map = paddle::empty(
      {input_token_num}, seq_lens_this_time.dtype(), input.place());
  cudaMemsetAsync(position_map.data<int>(),
                  0xFF,
                  input_token_num * sizeof(int),
                  input.stream());
  // TODO(yaohuicong): not need this params in future
  auto output_token_num =
      paddle::empty({1}, seq_lens_this_time.dtype(), input.place());
  // Pre-allocate output with max possible size (real_bsz)
  auto out = paddle::zeros({real_bsz, dim_embed}, input.dtype(), input.place());
  constexpr int VecSize = 4;
  // Vectorized loads/stores require the flattened element count to be a
  // multiple of VecSize (dim_embed is expected to be a multiple of 4).
  int elem_cnt = input_token_num * dim_embed;
  assert(elem_cnt % VecSize == 0);
  // Grid size linearly related to real_bsz for cooperative launch efficiency
  // and CUDA graph capture friendliness
  constexpr int thread_per_block = 128;
  constexpr int DESIRED_BLOCKS_PER_BATCH = 4;
  int dynamic_smem_size = 4 * real_bsz * static_cast<int>(sizeof(int));
  // Cooperative launch limit: use conservative smem upper bound for caching.
  // Computed once per dtype instantiation (static local). NOTE: the occupancy
  // query budgets 8KB of smem (real_bsz <= 512); larger batches request more
  // dynamic smem than was budgeted, which the launch check below will catch.
  static const int max_grid_size = [&]() {
    int blocks_per_sm = 0;
    constexpr int smem_upper_bound = 4 * 512 * sizeof(int);  // 8KB
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_sm,
        EagleGatherHiddenStatesKernel<DataType_, VecSize>,
        thread_per_block,
        smem_upper_bound);
    int dev = 0;
    cudaGetDevice(&dev);
    int sms = 0;
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
    return blocks_per_sm * sms;
  }();
  // Cooperative kernels must not exceed the number of co-resident blocks.
  int blocks_per_batch =
      std::min(DESIRED_BLOCKS_PER_BATCH, max_grid_size / std::max(real_bsz, 1));
  blocks_per_batch = std::max(blocks_per_batch, 1);
  int grid_size = std::min(real_bsz * blocks_per_batch, max_grid_size);
  grid_size = std::max(grid_size, 1);
  const DataType_* input_ptr =
      reinterpret_cast<const DataType_*>(input.data<data_t>());
  const int* cu_seqlens_q_ptr = cu_seqlens_q.data<int>();
  const int* seq_lens_this_time_ptr = seq_lens_this_time.data<int>();
  const int* seq_lens_decoder_ptr = seq_lens_decoder.data<int>();
  const int* seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
  const int* batch_id_per_token_output_ptr =
      batch_id_per_token_output.data<int>();
  const int* cu_seqlens_q_output_ptr = cu_seqlens_q_output.data<int>();
  DataType_* output_data_ptr = reinterpret_cast<DataType_*>(out.data<data_t>());
  int* position_map_ptr = position_map.data<int>();
  int* output_token_num_ptr = output_token_num.data<int>();
  int dim_embed_int = static_cast<int>(dim_embed);
  int64_t input_token_num_int64 = input_token_num;
  void* kernel_args[] = {&output_data_ptr,
                         &position_map_ptr,
                         &output_token_num_ptr,
                         &input_ptr,
                         &cu_seqlens_q_ptr,
                         &seq_lens_this_time_ptr,
                         &seq_lens_decoder_ptr,
                         &seq_lens_encoder_ptr,
                         &batch_id_per_token_output_ptr,
                         &cu_seqlens_q_output_ptr,
                         &dim_embed_int,
                         &input_token_num_int64,
                         &real_bsz};
  // Check the launch result: cooperative launches fail without executing on
  // devices lacking cooperativeLaunch support or when grid_size /
  // dynamic_smem_size exceed what can be co-resident, which would otherwise
  // silently return an all-zero `out`.
  cudaError_t launch_status = cudaLaunchCooperativeKernel(
      (void*)EagleGatherHiddenStatesKernel<DataType_, VecSize>,
      dim3(grid_size),
      dim3(thread_per_block),
      kernel_args,
      dynamic_smem_size,
      input.stream());
  if (launch_status != cudaSuccess) {
    PD_THROW("eagle_gather_hidden_states: cooperative kernel launch failed: ",
             cudaGetErrorString(launch_status));
  }
  // Return output and output_token_num
  return {out, output_token_num};
}
// Wrapper function for PD_BUILD_STATIC_OP: dispatches to the dtype-specialized
// DispatchDtype<> implementation based on the dtype of `input`.
// Supported dtypes: bfloat16, float16, float32; anything else throws.
std::vector<paddle::Tensor> EagleGatherHiddenStates(
    const paddle::Tensor& input,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token_output,
    const paddle::Tensor& cu_seqlens_q_output,
    const paddle::Tensor& real_output_token_num) {
  const auto dtype = input.dtype();
  if (dtype == paddle::DataType::BFLOAT16) {
    return DispatchDtype<paddle::DataType::BFLOAT16>(input,
                                                     cu_seqlens_q,
                                                     seq_lens_this_time,
                                                     seq_lens_decoder,
                                                     seq_lens_encoder,
                                                     batch_id_per_token_output,
                                                     cu_seqlens_q_output,
                                                     real_output_token_num);
  }
  if (dtype == paddle::DataType::FLOAT16) {
    return DispatchDtype<paddle::DataType::FLOAT16>(input,
                                                    cu_seqlens_q,
                                                    seq_lens_this_time,
                                                    seq_lens_decoder,
                                                    seq_lens_encoder,
                                                    batch_id_per_token_output,
                                                    cu_seqlens_q_output,
                                                    real_output_token_num);
  }
  if (dtype == paddle::DataType::FLOAT32) {
    return DispatchDtype<paddle::DataType::FLOAT32>(input,
                                                    cu_seqlens_q,
                                                    seq_lens_this_time,
                                                    seq_lens_decoder,
                                                    seq_lens_encoder,
                                                    batch_id_per_token_output,
                                                    cu_seqlens_q_output,
                                                    real_output_token_num);
  }
  PD_THROW("eagle_gather_hidden_states: NOT supported data type.");
}
// Registers the `eagle_gather_hidden_states` custom op for Paddle's static
// graph. Outputs: `out` — compacted last-token hidden states, and
// `output_token_num` — number of valid rows in `out` (device int tensor).
// NOTE(review): several declared inputs (cu_seqlens_q, seq_lens_decoder,
// seq_lens_encoder, batch_id_per_token_output, cu_seqlens_q_output,
// real_output_token_num) are not read by the kernel body above — presumably
// reserved for future use; confirm before pruning.
PD_BUILD_STATIC_OP(eagle_gather_hidden_states)
    .Inputs({"input",
             "cu_seqlens_q",
             "seq_lens_this_time",
             "seq_lens_decoder",
             "seq_lens_encoder",
             "batch_id_per_token_output",
             "cu_seqlens_q_output",
             "real_output_token_num"})
    .Outputs({"out", "output_token_num"})
    .SetKernelFn(PD_KERNEL(EagleGatherHiddenStates));
@@ -12,104 +12,117 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cooperative_groups.h>
#include "paddle/extension.h"
#include "helper.h"
// #define DEBUG_EAGLE_KERNEL
namespace cg = cooperative_groups;
__global__ void ComputeOrderKernel(const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
const int bsz,
const int actual_draft_token_num,
const int input_token_num) {
int in_offset = 0; // input_offset(long)
int out_offset = 0; // output_offset(short)
if (threadIdx.x == 0) {
for (int i = 0; i < bsz; ++i) {
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i];
int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i];
int cur_seq_lens_this_time = seq_lens_this_time[i];
int accept_num = accept_nums[i];
int cur_seq_lens_encoder = seq_lens_encoder[i];
#ifdef DEBUG_EAGLE_KERNEL
printf(
"batch %d: cur_base_model_seq_lens_this_time%d. "
"cur_seq_lens_this_time%d, accept_num %d\n",
i,
cur_base_model_seq_lens_this_time,
cur_seq_lens_this_time,
accept_num);
#endif
// Fused kernel: block 0 computes position_map and output_token_num in parallel
// (one thread per batch element), then all blocks synchronize via
// cooperative_groups grid sync, and finally all threads perform the hidden
// states rebuild in parallel.
template <typename T, int VecSize>
__global__ void rebuildHiddenStatesKernel(
const T* input,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* accept_nums,
int* position_map,
int* output_token_num,
T* out,
const int bsz,
const int dim_embed,
const int input_token_num) {
cg::grid_group grid = cg::this_grid();
// Dynamic shared memory layout: [in_count|out_count|in_offsets|out_offsets]
extern __shared__ int smem[];
int* in_count = smem;
int* out_count = smem + bsz;
int* in_offsets = smem + 2 * bsz;
int* out_offsets = smem + 3 * bsz;
// Phase 1: compute position_map (parallelized across threads in block 0)
if (blockIdx.x == 0) {
// Phase 1a: each thread computes counts for its batch elements
for (int t = threadIdx.x; t < bsz; t += blockDim.x) {
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[t];
int cur_seq_lens_this_time = seq_lens_this_time[t];
int accept_num = accept_nums[t];
int cur_seq_lens_encoder = seq_lens_encoder[t];
// 1. eagle encoder. Base step=1
if (cur_seq_lens_encoder > 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: cur_seq_lens_encoder > 0 \n", i);
#endif
for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_offset++] = out_offset++;
}
in_count[t] = cur_seq_lens_encoder;
out_count[t] = cur_seq_lens_encoder;
// 2. Base model stop at last verify-step.
} else if (cur_base_model_seq_lens_this_time != 0 &&
cur_seq_lens_this_time == 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: base=0. draft !=0 \n", i);
#endif
in_offset += cur_base_model_seq_lens_this_time;
// 4. stopped
in_count[t] = cur_base_model_seq_lens_this_time;
out_count[t] = 0;
// 3. stopped
} else if (cur_base_model_seq_lens_this_time == 0 &&
cur_seq_lens_this_time == 0) /* end */ {
cur_seq_lens_this_time == 0) {
in_count[t] = 0;
out_count[t] = 0;
} else {
for (int i = 0; i < accept_num; i++) {
position_map[in_offset++] = out_offset++;
}
in_offset += cur_base_model_seq_lens_this_time - accept_num;
// (liuzichang): Temporary Reserved for debug
// if (accept_num <= actual_draft_token_num) /*Accept partial
// draft tokens*/ {
// #ifdef DEBUG_EAGLE_KERNEL
// printf("batch %d: accept_num <= actual_draft_token_num \n",
// i);
// #endif
// position_map[in_offset + accept_num - 1] = out_offset++;
// in_offset += cur_base_model_seq_lens_this_time;
// } else /*Accept all draft tokens*/ {
// #ifdef DEBUG_EAGLE_KERNEL
// printf("batch %d: accept_num > actual_draft_token_num \n",
// i);
// #endif
// position_map[in_offset + accept_num - 2] = out_offset++;
// position_map[in_offset + accept_num - 1] = out_offset++;
// in_offset += cur_base_model_seq_lens_this_time;
// }
in_count[t] = cur_base_model_seq_lens_this_time;
out_count[t] = accept_num;
}
}
output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
printf("position map output_token_num%d:\n", output_token_num[0]);
for (int i = 0; i < output_token_num[0]; i++) {
printf("%d ", position_map[i]);
}
printf("\n");
#endif
}
}
__syncthreads();
template <typename T, int VecSize>
__global__ void rebuildHiddenStatesKernel(const T* input,
const int* position_map,
T* out,
const int dim_embed,
const int elem_cnt) {
// Phase 1b: prefix sum (thread 0 computes exclusive prefix sums)
if (threadIdx.x == 0) {
int in_acc = 0, out_acc = 0;
for (int i = 0; i < bsz; i++) {
in_offsets[i] = in_acc;
out_offsets[i] = out_acc;
in_acc += in_count[i];
out_acc += out_count[i];
}
output_token_num[0] = out_acc;
}
__syncthreads();
// Phase 1c: each thread fills position_map for its batch elements
for (int t = threadIdx.x; t < bsz; t += blockDim.x) {
int in_off = in_offsets[t];
int out_off = out_offsets[t];
int cur_seq_lens_encoder = seq_lens_encoder[t];
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[t];
int cur_seq_lens_this_time = seq_lens_this_time[t];
int accept_num = accept_nums[t];
// 1. eagle encoder. Base step=1
if (cur_seq_lens_encoder > 0) {
for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_off + j] = out_off + j;
}
// 2. Base model stop at last verify-step: no writes needed
// 3. stopped: no writes needed
} else if (cur_base_model_seq_lens_this_time != 0 &&
cur_seq_lens_this_time != 0) {
// 4. normal decode: copy accepted tokens
for (int j = 0; j < accept_num; j++) {
position_map[in_off + j] = out_off + j;
}
}
// Branches 2 & 3: position_map stays -1 from memset
}
}
// Phase 2: grid-wide sync to ensure position_map is ready
grid.sync();
// Phase 3: rebuild hidden states in parallel
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int elem_cnt = input_token_num * dim_embed;
int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt;
elem_idx += blockDim.x * gridDim.x * VecSize) {
@@ -117,8 +130,6 @@ __global__ void rebuildHiddenStatesKernel(const T* input,
int token_idx = position_map[ori_token_idx];
if (token_idx >= 0) {
int offset = elem_idx % dim_embed;
if (token_idx == 0) {
}
Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
Store<T, VecSize>(src_vec, out + token_idx * dim_embed + offset);
}
@@ -141,60 +152,92 @@ std::vector<paddle::Tensor> DispatchDtype(
typedef typename traits_::data_t data_t;
auto input_token_num = input.shape()[0];
// auto output_token_num = padding_offset.shape()[0];
auto dim_embed = input.shape()[1];
int bsz = seq_lens_this_time.shape()[0];
auto position_map = paddle::empty(
{input_token_num}, seq_lens_this_time.dtype(), input.place());
cudaMemsetAsync(position_map.data<int>(),
0xFF,
input_token_num * sizeof(seq_lens_this_time.dtype()),
seq_lens_this_time.stream());
input_token_num * sizeof(int),
input.stream());
auto output_token_num = paddle::empty(
{1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
ComputeOrderKernel<<<1, 1>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_encoder.data<int>(),
accept_nums.data<int>(),
position_map.data<int>(),
output_token_num.data<int>(),
bsz,
actual_draft_token_num,
input_token_num);
auto output_token_num =
paddle::empty({1}, seq_lens_this_time.dtype(), input.place());
int output_token_num_cpu =
output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
auto out = paddle::empty(
{output_token_num_cpu, dim_embed}, input.dtype(), input.place());
// Pre-allocate output with max possible size (input_token_num)
auto out =
paddle::empty({input_token_num, dim_embed}, input.dtype(), input.place());
constexpr int packSize = VEC_16B / (sizeof(DataType_));
int elem_cnt = input_token_num * dim_embed;
assert(elem_cnt % packSize == 0);
int pack_num = elem_cnt / packSize;
int grid_size = 1;
GetNumBlocks(pack_num, &grid_size);
// Grid size linearly related to bsz for cooperative launch efficiency
// and CUDA graph capture friendliness
constexpr int thread_per_block = 128;
constexpr int DESIRED_BLOCKS_PER_BATCH = 4;
int dynamic_smem_size = 4 * bsz * static_cast<int>(sizeof(int));
rebuildHiddenStatesKernel<DataType_, packSize>
<<<grid_size, thread_per_block>>>(
reinterpret_cast<const DataType_*>(input.data<data_t>()),
position_map.data<int>(),
reinterpret_cast<DataType_*>(out.data<data_t>()),
dim_embed,
elem_cnt);
// Cooperative launch limit: use conservative smem upper bound for caching
static const int max_grid_size = [&]() {
int blocks_per_sm = 0;
constexpr int smem_upper_bound = 4 * 512 * sizeof(int); // 8KB
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&blocks_per_sm,
rebuildHiddenStatesKernel<DataType_, packSize>,
thread_per_block,
smem_upper_bound);
int dev = 0;
cudaGetDevice(&dev);
int sms = 0;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
return blocks_per_sm * sms;
}();
return {out};
int blocks_per_batch =
std::min(DESIRED_BLOCKS_PER_BATCH, max_grid_size / std::max(bsz, 1));
blocks_per_batch = std::max(blocks_per_batch, 1);
int grid_size = std::min(bsz * blocks_per_batch, max_grid_size);
grid_size = std::max(grid_size, 1);
const DataType_* input_ptr =
reinterpret_cast<const DataType_*>(input.data<data_t>());
const int* seq_lens_this_time_ptr = seq_lens_this_time.data<int>();
const int* seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
const int* base_model_seq_lens_this_time_ptr =
base_model_seq_lens_this_time.data<int>();
const int* base_model_seq_lens_encoder_ptr =
base_model_seq_lens_encoder.data<int>();
const int* accept_nums_ptr = accept_nums.data<int>();
int* position_map_ptr = position_map.data<int>();
int* output_token_num_ptr = output_token_num.data<int>();
DataType_* out_ptr = reinterpret_cast<DataType_*>(out.data<data_t>());
int dim_embed_int = static_cast<int>(dim_embed);
int input_token_num_int = static_cast<int>(input_token_num);
void* kernel_args[] = {&input_ptr,
&seq_lens_this_time_ptr,
&seq_lens_encoder_ptr,
&base_model_seq_lens_this_time_ptr,
&base_model_seq_lens_encoder_ptr,
&accept_nums_ptr,
&position_map_ptr,
&output_token_num_ptr,
&out_ptr,
&bsz,
&dim_embed_int,
&input_token_num_int};
cudaLaunchCooperativeKernel(
(void*)rebuildHiddenStatesKernel<DataType_, packSize>,
dim3(grid_size),
dim3(thread_per_block),
kernel_args,
dynamic_smem_size,
input.stream());
return {out, output_token_num};
}
std::vector<paddle::Tensor> EagleGetHiddenStates(
@@ -248,5 +291,5 @@ PD_BUILD_STATIC_OP(eagle_get_hidden_states)
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder"})
.Attrs({"actual_draft_token_num: int"})
.Outputs({"out"})
.Outputs({"out", "output_token_num"})
.SetKernelFn(PD_KERNEL(EagleGetHiddenStates));
@@ -1,4 +1,3 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,174 +12,230 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cooperative_groups.h>
#include "paddle/extension.h"
#include "helper.h"
// #define DEBUG_EAGLE_KERNEL
namespace cg = cooperative_groups;
__global__ void computeOrderKernel(const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int64_t* step_idx,
int* src_map,
int* output_token_num,
int bsz) {
int in_offset = 0;
int out_offset = 0;
if (threadIdx.x == 0) {
for (int i = 0; i < bsz; ++i) {
int cur_seq_lens_this_time = seq_lens_this_time[i];
int cur_last_seq_lens_this_time = last_seq_lens_this_time[i];
#ifdef DEBUG_EAGLE_KERNEL
printf(
"batch %d: cur_seq_lens_this_time:%d. "
"cur_last_seq_lens_this_time:%d\n",
i,
cur_seq_lens_this_time,
cur_last_seq_lens_this_time);
#endif
// Fused kernel: block 0 computes position_map and output_token_num in parallel
// (one thread per batch element), then all blocks synchronize via
// cooperative_groups grid sync, and finally all threads perform the hidden
// states rebuild in parallel.
template <typename T, int VecSize>
__global__ void rebuildSelfHiddenStatesKernel(
const T* input,
const int* last_seq_lens_this_time,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
int* position_map,
int* output_token_num,
T* out,
const int bsz,
const int dim_embed,
const int input_token_num) {
cg::grid_group grid = cg::this_grid();
// Dynamic shared memory layout: [in_count|out_count|in_offsets|out_offsets]
extern __shared__ int smem[];
int* in_count = smem;
int* out_count = smem + bsz;
int* in_offsets = smem + 2 * bsz;
int* out_offsets = smem + 3 * bsz;
// Phase 1: compute position_map (parallelized across threads in block 0)
if (blockIdx.x == 0) {
// Phase 1a: each thread computes counts for its batch elements
for (int t = threadIdx.x; t < bsz; t += blockDim.x) {
int cur_seq_lens_this_time = seq_lens_this_time[t];
int cur_last_seq_lens_this_time = last_seq_lens_this_time[t];
// 1. encoder
if (step_idx[i] == 1 && cur_seq_lens_this_time > 0) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d last_step is encoder \n", i);
#endif
in_offset += 1;
src_map[out_offset++] = in_offset - 1;
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d finish. src_map[%d]=%d \n",
i,
out_offset - 1,
in_offset - 1);
#endif
if (seq_lens_encoder[t] > 0 && cur_seq_lens_this_time > 0) {
in_count[t] = 1;
out_count[t] = 1;
// 2. decoder
} else if (cur_seq_lens_this_time > 0) /* =1 */ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d is decoder\n", i);
#endif
in_offset += cur_last_seq_lens_this_time;
src_map[out_offset++] = in_offset - 1;
} else if (cur_seq_lens_this_time > 0) {
in_count[t] = cur_last_seq_lens_this_time;
out_count[t] = 1;
// 3. stop
} else {
// first token end
if (step_idx[i] == 1) {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d finished in first token \n", i);
#endif
in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0;
// normal end
if (seq_lens_encoder[t] > 0) {
in_count[t] = cur_last_seq_lens_this_time > 0 ? 1 : 0;
} else {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d finished in non-first token \n", i);
#endif
in_offset += cur_last_seq_lens_this_time;
in_count[t] = cur_last_seq_lens_this_time;
}
out_count[t] = 0;
}
}
output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
printf("position map output_token_num%d:\n", output_token_num[0]);
for (int i = 0; i < output_token_num[0]; i++) {
printf("%d ", src_map[i]);
}
printf("\n");
#endif
}
}
__syncthreads();
template <typename T, int PackSize>
__global__ void rebuildSelfHiddenStatesKernel(
const T* input, int* src_map, T* output, int dim_embed, int elem_cnt) {
using LoadT = AlignedVector<T, PackSize>;
// Phase 1b: prefix sum (thread 0 computes exclusive prefix sums)
if (threadIdx.x == 0) {
int in_acc = 0, out_acc = 0;
for (int i = 0; i < bsz; i++) {
in_offsets[i] = in_acc;
out_offsets[i] = out_acc;
in_acc += in_count[i];
out_acc += out_count[i];
}
output_token_num[0] = out_acc;
}
__syncthreads();
// Phase 1c: each thread fills position_map for its batch elements
for (int t = threadIdx.x; t < bsz; t += blockDim.x) {
int in_off = in_offsets[t];
int out_off = out_offsets[t];
int cur_seq_lens_this_time = seq_lens_this_time[t];
int cur_last_seq_lens_this_time = last_seq_lens_this_time[t];
// 1. encoder
if (seq_lens_encoder[t] > 0 && cur_seq_lens_this_time > 0) {
position_map[in_off] = out_off;
// 2. decoder
} else if (cur_seq_lens_this_time > 0) {
position_map[in_off + cur_last_seq_lens_this_time - 1] = out_off;
}
// 3. stop: no writes needed (position_map stays -1 from memset)
}
}
// Phase 2: grid-wide sync to ensure position_map is ready
grid.sync();
// Phase 3: rebuild hidden states in parallel
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int elem_id = global_thread_idx * PackSize; elem_id < elem_cnt;
elem_id += blockDim.x * gridDim.x * PackSize) {
int output_token_idx = elem_id / dim_embed;
int input_token_idx = src_map[output_token_idx];
int offset = elem_id % dim_embed;
Load<T, PackSize>(input + input_token_idx * dim_embed + offset, &src_vec);
Store<T, PackSize>(src_vec, output + output_token_idx * dim_embed + offset);
int elem_cnt = input_token_num * dim_embed;
int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt;
elem_idx += blockDim.x * gridDim.x * VecSize) {
int ori_token_idx = elem_idx / dim_embed;
int token_idx = position_map[ori_token_idx];
if (token_idx >= 0) {
int offset = elem_idx % dim_embed;
Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
Store<T, VecSize>(src_vec, out + token_idx * dim_embed + offset);
}
}
}
template <paddle::DataType D>
std::vector<paddle::Tensor> DispatchDtype(
const paddle::Tensor input,
const paddle::Tensor last_seq_lens_this_time,
const paddle::Tensor seq_lens_this_time,
const paddle::Tensor step_idx) {
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
int input_token_num = input.shape()[0];
int dim_embed = input.shape()[1];
auto input_token_num = input.shape()[0];
auto dim_embed = input.shape()[1];
int bsz = seq_lens_this_time.shape()[0];
auto src_map = paddle::full({input_token_num},
-1,
seq_lens_this_time.dtype(),
seq_lens_this_time.place());
auto output_token_num = paddle::full(
{1}, 0, seq_lens_this_time.dtype(), seq_lens_this_time.place());
computeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(
last_seq_lens_this_time.data<int>(),
seq_lens_this_time.data<int>(),
step_idx.data<int64_t>(),
src_map.data<int>(),
output_token_num.data<int>(),
bsz);
auto position_map = paddle::empty(
{input_token_num}, seq_lens_this_time.dtype(), input.place());
cudaMemsetAsync(position_map.data<int>(),
0xFF,
input_token_num * sizeof(int),
input.stream());
int output_token_num_cpu =
output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
auto output_token_num =
paddle::empty({1}, seq_lens_this_time.dtype(), input.place());
auto out = paddle::full(
{output_token_num_cpu, dim_embed}, -1, input.type(), input.place());
// Pre-allocate output with max possible size (input_token_num)
auto out =
paddle::empty({input_token_num, dim_embed}, input.dtype(), input.place());
constexpr int packSize = VEC_16B / (sizeof(DataType_));
int elem_cnt = output_token_num_cpu * dim_embed;
// printf("output_token_num: %d, dim_embed: %d, cnt: %d. packSize: %d\n",
// output_token_num_cpu, dim_embed,elem_cnt, packSize);
int elem_cnt = input_token_num * dim_embed;
assert(elem_cnt % packSize == 0);
int pack_num = elem_cnt / packSize;
// Grid size linearly related to bsz for cooperative launch efficiency
// and CUDA graph capture friendliness
constexpr int thread_per_block = 128;
constexpr int DESIRED_BLOCKS_PER_BATCH = 4;
int dynamic_smem_size = 4 * bsz * static_cast<int>(sizeof(int));
int grid_size = 1;
// Cooperative launch limit: use conservative smem upper bound for caching
static const int max_grid_size = [&]() {
int blocks_per_sm = 0;
constexpr int smem_upper_bound = 4 * 512 * sizeof(int); // 8KB
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&blocks_per_sm,
rebuildSelfHiddenStatesKernel<DataType_, packSize>,
thread_per_block,
smem_upper_bound);
int dev = 0;
cudaGetDevice(&dev);
int sms = 0;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
return blocks_per_sm * sms;
}();
GetNumBlocks(pack_num, &grid_size);
int blocks_per_batch =
std::min(DESIRED_BLOCKS_PER_BATCH, max_grid_size / std::max(bsz, 1));
blocks_per_batch = std::max(blocks_per_batch, 1);
int grid_size = std::min(bsz * blocks_per_batch, max_grid_size);
grid_size = std::max(grid_size, 1);
constexpr int threadPerBlock = 128;
const DataType_* input_ptr =
reinterpret_cast<const DataType_*>(input.data<data_t>());
const int* last_seq_lens_this_time_ptr = last_seq_lens_this_time.data<int>();
const int* seq_lens_this_time_ptr = seq_lens_this_time.data<int>();
const int* seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
int* position_map_ptr = position_map.data<int>();
int* output_token_num_ptr = output_token_num.data<int>();
DataType_* out_ptr = reinterpret_cast<DataType_*>(out.data<data_t>());
int dim_embed_int = static_cast<int>(dim_embed);
int input_token_num_int = static_cast<int>(input_token_num);
rebuildSelfHiddenStatesKernel<DataType_, packSize>
<<<grid_size, threadPerBlock, 0, input.stream()>>>(
reinterpret_cast<const DataType_*>(input.data<data_t>()),
src_map.data<int>(),
reinterpret_cast<DataType_*>(out.data<data_t>()),
dim_embed,
elem_cnt);
void* kernel_args[] = {&input_ptr,
&last_seq_lens_this_time_ptr,
&seq_lens_this_time_ptr,
&seq_lens_encoder_ptr,
&position_map_ptr,
&output_token_num_ptr,
&out_ptr,
&bsz,
&dim_embed_int,
&input_token_num_int};
return {out};
cudaLaunchCooperativeKernel(
(void*)rebuildSelfHiddenStatesKernel<DataType_, packSize>,
dim3(grid_size),
dim3(thread_per_block),
kernel_args,
dynamic_smem_size,
input.stream());
return {out, output_token_num};
}
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& step_idx) {
const paddle::Tensor& seq_lens_encoder) {
switch (input.dtype()) {
case paddle::DataType::BFLOAT16:
return DispatchDtype<paddle::DataType::BFLOAT16>(
input, last_seq_lens_this_time, seq_lens_this_time, step_idx);
input, last_seq_lens_this_time, seq_lens_this_time, seq_lens_encoder);
case paddle::DataType::FLOAT16:
return DispatchDtype<paddle::DataType::FLOAT16>(
input, last_seq_lens_this_time, seq_lens_this_time, step_idx);
input, last_seq_lens_this_time, seq_lens_this_time, seq_lens_encoder);
default:
PD_THROW("Not support this data type");
}
}
PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
.Inputs(
{"input", "last_seq_lens_this_time", "seq_lens_this_time", "step_idx"})
.Outputs({"out"})
.Inputs({"input",
"last_seq_lens_this_time",
"seq_lens_this_time",
"seq_lens_encoder"})
.Outputs({"out", "output_token_num"})
.SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));