mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Refactor Eagle MTP hidden states copy (#6812)
* reformat eagle_get_hidden_states & eagle_get_self_hidden_states * readibility * fix xpu bug * fix coverage failure * change luanch params & parallelize position_map compute * Fix MTP-related bugs in FastDeploy centralized inference * fix * refactor mtp hidden_states process * fix * add unittest & optimize kernel * remove useless code * fix
This commit is contained in:
@@ -1016,7 +1016,17 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& last_seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& step_idx);
|
||||
const paddle::Tensor& seq_lens_encoder);
|
||||
|
||||
std::vector<paddle::Tensor> EagleGatherHiddenStates(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& real_output_token_num);
|
||||
|
||||
void MTPStepPaddle(
|
||||
const paddle::Tensor& base_model_stop_flags,
|
||||
@@ -1820,6 +1830,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
&EagleGetSelfHiddenStates,
|
||||
"eagle_get_self_hidden_states function");
|
||||
|
||||
m.def("eagle_gather_hidden_states",
|
||||
&EagleGatherHiddenStates,
|
||||
"eagle_gather_hidden_states function");
|
||||
|
||||
m.def("mtp_step_paddle", &MTPStepPaddle, "mtp_step_paddle function");
|
||||
|
||||
m.def("speculate_step_paddle",
|
||||
|
||||
@@ -42,16 +42,21 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
int64_t stop_flag_now_int = 0;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
if (tid < bsz) {
|
||||
auto* draft_token_now = draft_tokens + tid * max_draft_token;
|
||||
auto* pre_ids_now = pre_ids + tid * pre_id_length;
|
||||
auto* base_model_draft_tokens_now =
|
||||
if (tid < bsz && seq_lens_this_time[tid] > 0) {
|
||||
int seq_len_this_time = seq_lens_this_time[tid];
|
||||
int seq_len_encoder = seq_lens_encoder[tid];
|
||||
int seq_len_decoder = seq_lens_decoder[tid];
|
||||
|
||||
int next_tokens_start_id = 0;
|
||||
for (int i = 0; i < tid; i++) {
|
||||
next_tokens_start_id += seq_lens_this_time[i] > 0 ? 1 : 0;
|
||||
}
|
||||
|
||||
int64_t* draft_token_now = draft_tokens + tid * max_draft_token;
|
||||
int64_t* pre_ids_now = pre_ids + tid * pre_id_length;
|
||||
int64_t* base_model_draft_tokens_now =
|
||||
base_model_draft_tokens + tid * max_base_model_draft_token;
|
||||
const int next_tokens_start_id = cu_seqlens_q_output[tid];
|
||||
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
|
||||
auto seq_len_this_time = seq_lens_this_time[tid];
|
||||
auto seq_len_encoder = seq_lens_encoder[tid];
|
||||
auto seq_len_decoder = seq_lens_decoder[tid];
|
||||
const int64_t* next_tokens_start = inter_next_tokens + next_tokens_start_id;
|
||||
|
||||
// 1. update step_idx && seq_lens_dec
|
||||
if (!stop_flags[tid]) {
|
||||
@@ -72,8 +77,8 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
base_model_draft_tokens_now[substep + 1] = -1;
|
||||
} else {
|
||||
seq_lens_decoder[tid] += seq_len_this_time;
|
||||
token_this_time = next_tokens_start[seq_len_this_time - 1];
|
||||
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
|
||||
token_this_time = next_tokens_start[0];
|
||||
draft_token_now[0] = token_this_time;
|
||||
base_model_draft_tokens_now[substep + 1] = token_this_time;
|
||||
step_idx[tid] += seq_len_this_time;
|
||||
pre_ids_now[step_idx[tid]] = token_this_time;
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// Fused kernel: block 0 computes position_map and output_token_num in parallel
|
||||
// (one thread per batch element), then all blocks synchronize via
|
||||
// cooperative_groups grid sync, and finally all threads perform the hidden
|
||||
// states gather in parallel.
|
||||
template <typename T, int VecSize>
|
||||
__global__ void EagleGatherHiddenStatesKernel(
|
||||
T* output_data,
|
||||
int* position_map,
|
||||
int* output_token_num,
|
||||
const T* input,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens_this_time,
|
||||
const int* seq_lens_decoder,
|
||||
const int* seq_lens_encoder,
|
||||
const int* batch_id_per_token_output,
|
||||
const int* cu_seqlens_q_output,
|
||||
const int dim_embed,
|
||||
const int64_t input_token_num,
|
||||
const int real_bsz) {
|
||||
cg::grid_group grid = cg::this_grid();
|
||||
|
||||
// Dynamic shared memory layout: [in_count|out_count|in_offsets|out_offsets]
|
||||
extern __shared__ int smem[];
|
||||
int* in_count = smem;
|
||||
int* out_count = smem + real_bsz;
|
||||
int* in_offsets = smem + 2 * real_bsz;
|
||||
int* out_offsets = smem + 3 * real_bsz;
|
||||
|
||||
// Phase 1: compute position_map (parallelized across threads in block 0)
|
||||
if (blockIdx.x == 0) {
|
||||
// Phase 1a: each thread computes counts for its batch elements
|
||||
for (int t = threadIdx.x; t < real_bsz; t += blockDim.x) {
|
||||
int cur_seq_len = seq_lens_this_time[t];
|
||||
// has seq in curent batch
|
||||
if (cur_seq_len > 0) {
|
||||
in_count[t] = cur_seq_len;
|
||||
out_count[t] = 1;
|
||||
} else {
|
||||
in_count[t] = 0;
|
||||
out_count[t] = 0;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Phase 1b: prefix sum (thread 0 computes exclusive prefix sums)
|
||||
if (threadIdx.x == 0) {
|
||||
int in_acc = 0, out_acc = 0;
|
||||
for (int i = 0; i < real_bsz; i++) {
|
||||
in_offsets[i] = in_acc;
|
||||
out_offsets[i] = out_acc;
|
||||
in_acc += in_count[i];
|
||||
out_acc += out_count[i];
|
||||
}
|
||||
output_token_num[0] = out_acc;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Phase 1c: each thread fills position_map for its batch elements
|
||||
for (int t = threadIdx.x; t < real_bsz; t += blockDim.x) {
|
||||
int in_off = in_offsets[t];
|
||||
int out_off = out_offsets[t];
|
||||
if (seq_lens_this_time[t] > 0) {
|
||||
// For gather: map input token to output position
|
||||
// Use last token of each sequence
|
||||
int last_token_idx = in_off + in_count[t] - 1;
|
||||
position_map[last_token_idx] = out_off;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: grid-wide sync to ensure position_map is ready
|
||||
grid.sync();
|
||||
|
||||
// Phase 3: gather hidden states in parallel
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
|
||||
int elem_cnt = input_token_num * dim_embed;
|
||||
int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt;
|
||||
elem_idx += blockDim.x * gridDim.x * VecSize) {
|
||||
int ori_token_idx = elem_idx / dim_embed;
|
||||
int token_idx = position_map[ori_token_idx];
|
||||
if (token_idx >= 0) {
|
||||
int offset = elem_idx % dim_embed;
|
||||
Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
|
||||
Store<T, VecSize>(src_vec, output_data + token_idx * dim_embed + offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor> DispatchDtype(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& real_output_token_num) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto input_token_num = input.shape()[0];
|
||||
auto dim_embed = input.shape()[1];
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
auto position_map = paddle::empty(
|
||||
{input_token_num}, seq_lens_this_time.dtype(), input.place());
|
||||
cudaMemsetAsync(position_map.data<int>(),
|
||||
0xFF,
|
||||
input_token_num * sizeof(int),
|
||||
input.stream());
|
||||
|
||||
// TODO(yaohuicong): not need this params in future
|
||||
auto output_token_num =
|
||||
paddle::empty({1}, seq_lens_this_time.dtype(), input.place());
|
||||
|
||||
// Pre-allocate output with max possible size (real_bsz)
|
||||
auto out = paddle::zeros({real_bsz, dim_embed}, input.dtype(), input.place());
|
||||
|
||||
constexpr int VecSize = 4;
|
||||
int elem_cnt = input_token_num * dim_embed;
|
||||
assert(elem_cnt % VecSize == 0);
|
||||
|
||||
// Grid size linearly related to real_bsz for cooperative launch efficiency
|
||||
// and CUDA graph capture friendliness
|
||||
constexpr int thread_per_block = 128;
|
||||
constexpr int DESIRED_BLOCKS_PER_BATCH = 4;
|
||||
int dynamic_smem_size = 4 * real_bsz * static_cast<int>(sizeof(int));
|
||||
|
||||
// Cooperative launch limit: use conservative smem upper bound for caching
|
||||
static const int max_grid_size = [&]() {
|
||||
int blocks_per_sm = 0;
|
||||
constexpr int smem_upper_bound = 4 * 512 * sizeof(int); // 8KB
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&blocks_per_sm,
|
||||
EagleGatherHiddenStatesKernel<DataType_, VecSize>,
|
||||
thread_per_block,
|
||||
smem_upper_bound);
|
||||
int dev = 0;
|
||||
cudaGetDevice(&dev);
|
||||
int sms = 0;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
||||
return blocks_per_sm * sms;
|
||||
}();
|
||||
|
||||
int blocks_per_batch =
|
||||
std::min(DESIRED_BLOCKS_PER_BATCH, max_grid_size / std::max(real_bsz, 1));
|
||||
blocks_per_batch = std::max(blocks_per_batch, 1);
|
||||
int grid_size = std::min(real_bsz * blocks_per_batch, max_grid_size);
|
||||
grid_size = std::max(grid_size, 1);
|
||||
|
||||
const DataType_* input_ptr =
|
||||
reinterpret_cast<const DataType_*>(input.data<data_t>());
|
||||
const int* cu_seqlens_q_ptr = cu_seqlens_q.data<int>();
|
||||
const int* seq_lens_this_time_ptr = seq_lens_this_time.data<int>();
|
||||
const int* seq_lens_decoder_ptr = seq_lens_decoder.data<int>();
|
||||
const int* seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
|
||||
const int* batch_id_per_token_output_ptr =
|
||||
batch_id_per_token_output.data<int>();
|
||||
const int* cu_seqlens_q_output_ptr = cu_seqlens_q_output.data<int>();
|
||||
DataType_* output_data_ptr = reinterpret_cast<DataType_*>(out.data<data_t>());
|
||||
int* position_map_ptr = position_map.data<int>();
|
||||
int* output_token_num_ptr = output_token_num.data<int>();
|
||||
int dim_embed_int = static_cast<int>(dim_embed);
|
||||
int64_t input_token_num_int64 = input_token_num;
|
||||
|
||||
void* kernel_args[] = {&output_data_ptr,
|
||||
&position_map_ptr,
|
||||
&output_token_num_ptr,
|
||||
&input_ptr,
|
||||
&cu_seqlens_q_ptr,
|
||||
&seq_lens_this_time_ptr,
|
||||
&seq_lens_decoder_ptr,
|
||||
&seq_lens_encoder_ptr,
|
||||
&batch_id_per_token_output_ptr,
|
||||
&cu_seqlens_q_output_ptr,
|
||||
&dim_embed_int,
|
||||
&input_token_num_int64,
|
||||
&real_bsz};
|
||||
|
||||
cudaLaunchCooperativeKernel(
|
||||
(void*)EagleGatherHiddenStatesKernel<DataType_, VecSize>,
|
||||
dim3(grid_size),
|
||||
dim3(thread_per_block),
|
||||
kernel_args,
|
||||
dynamic_smem_size,
|
||||
input.stream());
|
||||
|
||||
// Return output and output_token_num
|
||||
return {out, output_token_num};
|
||||
}
|
||||
|
||||
// Wrapper function for PD_BUILD_STATIC_OP
|
||||
std::vector<paddle::Tensor> EagleGatherHiddenStates(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& real_output_token_num) {
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
return DispatchDtype<paddle::DataType::BFLOAT16>(
|
||||
input,
|
||||
cu_seqlens_q,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
real_output_token_num);
|
||||
case paddle::DataType::FLOAT16:
|
||||
return DispatchDtype<paddle::DataType::FLOAT16>(input,
|
||||
cu_seqlens_q,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
real_output_token_num);
|
||||
case paddle::DataType::FLOAT32:
|
||||
return DispatchDtype<paddle::DataType::FLOAT32>(input,
|
||||
cu_seqlens_q,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
real_output_token_num);
|
||||
default:
|
||||
PD_THROW("eagle_gather_hidden_states: NOT supported data type.");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(eagle_gather_hidden_states)
|
||||
.Inputs({"input",
|
||||
"cu_seqlens_q",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_encoder",
|
||||
"batch_id_per_token_output",
|
||||
"cu_seqlens_q_output",
|
||||
"real_output_token_num"})
|
||||
.Outputs({"out", "output_token_num"})
|
||||
.SetKernelFn(PD_KERNEL(EagleGatherHiddenStates));
|
||||
@@ -12,104 +12,117 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// #define DEBUG_EAGLE_KERNEL
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
__global__ void ComputeOrderKernel(const int* seq_lens_this_time,
|
||||
const int* seq_lens_encoder,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* accept_nums,
|
||||
int* position_map,
|
||||
int* output_token_num,
|
||||
const int bsz,
|
||||
const int actual_draft_token_num,
|
||||
const int input_token_num) {
|
||||
int in_offset = 0; // input_offset(long)
|
||||
int out_offset = 0; // output_offset(short)
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < bsz; ++i) {
|
||||
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i];
|
||||
int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i];
|
||||
int cur_seq_lens_this_time = seq_lens_this_time[i];
|
||||
int accept_num = accept_nums[i];
|
||||
int cur_seq_lens_encoder = seq_lens_encoder[i];
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf(
|
||||
"batch %d: cur_base_model_seq_lens_this_time%d. "
|
||||
"cur_seq_lens_this_time%d, accept_num %d\n",
|
||||
i,
|
||||
cur_base_model_seq_lens_this_time,
|
||||
cur_seq_lens_this_time,
|
||||
accept_num);
|
||||
#endif
|
||||
// Fused kernel: block 0 computes position_map and output_token_num in parallel
|
||||
// (one thread per batch element), then all blocks synchronize via
|
||||
// cooperative_groups grid sync, and finally all threads perform the hidden
|
||||
// states rebuild in parallel.
|
||||
template <typename T, int VecSize>
|
||||
__global__ void rebuildHiddenStatesKernel(
|
||||
const T* input,
|
||||
const int* seq_lens_this_time,
|
||||
const int* seq_lens_encoder,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* accept_nums,
|
||||
int* position_map,
|
||||
int* output_token_num,
|
||||
T* out,
|
||||
const int bsz,
|
||||
const int dim_embed,
|
||||
const int input_token_num) {
|
||||
cg::grid_group grid = cg::this_grid();
|
||||
|
||||
// Dynamic shared memory layout: [in_count|out_count|in_offsets|out_offsets]
|
||||
extern __shared__ int smem[];
|
||||
int* in_count = smem;
|
||||
int* out_count = smem + bsz;
|
||||
int* in_offsets = smem + 2 * bsz;
|
||||
int* out_offsets = smem + 3 * bsz;
|
||||
|
||||
// Phase 1: compute position_map (parallelized across threads in block 0)
|
||||
if (blockIdx.x == 0) {
|
||||
// Phase 1a: each thread computes counts for its batch elements
|
||||
for (int t = threadIdx.x; t < bsz; t += blockDim.x) {
|
||||
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[t];
|
||||
int cur_seq_lens_this_time = seq_lens_this_time[t];
|
||||
int accept_num = accept_nums[t];
|
||||
int cur_seq_lens_encoder = seq_lens_encoder[t];
|
||||
// 1. eagle encoder. Base step=1
|
||||
if (cur_seq_lens_encoder > 0) {
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d: cur_seq_lens_encoder > 0 \n", i);
|
||||
#endif
|
||||
for (int j = 0; j < cur_seq_lens_encoder; j++) {
|
||||
position_map[in_offset++] = out_offset++;
|
||||
}
|
||||
in_count[t] = cur_seq_lens_encoder;
|
||||
out_count[t] = cur_seq_lens_encoder;
|
||||
// 2. Base model stop at last verify-step.
|
||||
} else if (cur_base_model_seq_lens_this_time != 0 &&
|
||||
cur_seq_lens_this_time == 0) {
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d: base=0. draft !=0 \n", i);
|
||||
#endif
|
||||
|
||||
in_offset += cur_base_model_seq_lens_this_time;
|
||||
// 4. stopped
|
||||
in_count[t] = cur_base_model_seq_lens_this_time;
|
||||
out_count[t] = 0;
|
||||
// 3. stopped
|
||||
} else if (cur_base_model_seq_lens_this_time == 0 &&
|
||||
cur_seq_lens_this_time == 0) /* end */ {
|
||||
cur_seq_lens_this_time == 0) {
|
||||
in_count[t] = 0;
|
||||
out_count[t] = 0;
|
||||
} else {
|
||||
for (int i = 0; i < accept_num; i++) {
|
||||
position_map[in_offset++] = out_offset++;
|
||||
}
|
||||
in_offset += cur_base_model_seq_lens_this_time - accept_num;
|
||||
// (liuzichang): Temporary Reserved for debug
|
||||
// if (accept_num <= actual_draft_token_num) /*Accept partial
|
||||
// draft tokens*/ {
|
||||
// #ifdef DEBUG_EAGLE_KERNEL
|
||||
// printf("batch %d: accept_num <= actual_draft_token_num \n",
|
||||
// i);
|
||||
// #endif
|
||||
// position_map[in_offset + accept_num - 1] = out_offset++;
|
||||
// in_offset += cur_base_model_seq_lens_this_time;
|
||||
// } else /*Accept all draft tokens*/ {
|
||||
// #ifdef DEBUG_EAGLE_KERNEL
|
||||
// printf("batch %d: accept_num > actual_draft_token_num \n",
|
||||
// i);
|
||||
// #endif
|
||||
// position_map[in_offset + accept_num - 2] = out_offset++;
|
||||
// position_map[in_offset + accept_num - 1] = out_offset++;
|
||||
// in_offset += cur_base_model_seq_lens_this_time;
|
||||
// }
|
||||
in_count[t] = cur_base_model_seq_lens_this_time;
|
||||
out_count[t] = accept_num;
|
||||
}
|
||||
}
|
||||
output_token_num[0] = out_offset;
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("position map output_token_num%d:\n", output_token_num[0]);
|
||||
for (int i = 0; i < output_token_num[0]; i++) {
|
||||
printf("%d ", position_map[i]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
template <typename T, int VecSize>
|
||||
__global__ void rebuildHiddenStatesKernel(const T* input,
|
||||
const int* position_map,
|
||||
T* out,
|
||||
const int dim_embed,
|
||||
const int elem_cnt) {
|
||||
// Phase 1b: prefix sum (thread 0 computes exclusive prefix sums)
|
||||
if (threadIdx.x == 0) {
|
||||
int in_acc = 0, out_acc = 0;
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
in_offsets[i] = in_acc;
|
||||
out_offsets[i] = out_acc;
|
||||
in_acc += in_count[i];
|
||||
out_acc += out_count[i];
|
||||
}
|
||||
output_token_num[0] = out_acc;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Phase 1c: each thread fills position_map for its batch elements
|
||||
for (int t = threadIdx.x; t < bsz; t += blockDim.x) {
|
||||
int in_off = in_offsets[t];
|
||||
int out_off = out_offsets[t];
|
||||
int cur_seq_lens_encoder = seq_lens_encoder[t];
|
||||
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[t];
|
||||
int cur_seq_lens_this_time = seq_lens_this_time[t];
|
||||
int accept_num = accept_nums[t];
|
||||
// 1. eagle encoder. Base step=1
|
||||
if (cur_seq_lens_encoder > 0) {
|
||||
for (int j = 0; j < cur_seq_lens_encoder; j++) {
|
||||
position_map[in_off + j] = out_off + j;
|
||||
}
|
||||
// 2. Base model stop at last verify-step: no writes needed
|
||||
// 3. stopped: no writes needed
|
||||
} else if (cur_base_model_seq_lens_this_time != 0 &&
|
||||
cur_seq_lens_this_time != 0) {
|
||||
// 4. normal decode: copy accepted tokens
|
||||
for (int j = 0; j < accept_num; j++) {
|
||||
position_map[in_off + j] = out_off + j;
|
||||
}
|
||||
}
|
||||
// Branches 2 & 3: position_map stays -1 from memset
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: grid-wide sync to ensure position_map is ready
|
||||
grid.sync();
|
||||
|
||||
// Phase 3: rebuild hidden states in parallel
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
|
||||
int elem_cnt = input_token_num * dim_embed;
|
||||
int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt;
|
||||
elem_idx += blockDim.x * gridDim.x * VecSize) {
|
||||
@@ -117,8 +130,6 @@ __global__ void rebuildHiddenStatesKernel(const T* input,
|
||||
int token_idx = position_map[ori_token_idx];
|
||||
if (token_idx >= 0) {
|
||||
int offset = elem_idx % dim_embed;
|
||||
if (token_idx == 0) {
|
||||
}
|
||||
Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
|
||||
Store<T, VecSize>(src_vec, out + token_idx * dim_embed + offset);
|
||||
}
|
||||
@@ -141,60 +152,92 @@ std::vector<paddle::Tensor> DispatchDtype(
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto input_token_num = input.shape()[0];
|
||||
|
||||
// auto output_token_num = padding_offset.shape()[0];
|
||||
auto dim_embed = input.shape()[1];
|
||||
|
||||
int bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
auto position_map = paddle::empty(
|
||||
{input_token_num}, seq_lens_this_time.dtype(), input.place());
|
||||
cudaMemsetAsync(position_map.data<int>(),
|
||||
0xFF,
|
||||
input_token_num * sizeof(seq_lens_this_time.dtype()),
|
||||
seq_lens_this_time.stream());
|
||||
input_token_num * sizeof(int),
|
||||
input.stream());
|
||||
|
||||
auto output_token_num = paddle::empty(
|
||||
{1}, seq_lens_this_time.dtype(), seq_lens_this_time.place());
|
||||
ComputeOrderKernel<<<1, 1>>>(seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
accept_nums.data<int>(),
|
||||
position_map.data<int>(),
|
||||
output_token_num.data<int>(),
|
||||
bsz,
|
||||
actual_draft_token_num,
|
||||
input_token_num);
|
||||
auto output_token_num =
|
||||
paddle::empty({1}, seq_lens_this_time.dtype(), input.place());
|
||||
|
||||
int output_token_num_cpu =
|
||||
output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
|
||||
|
||||
auto out = paddle::empty(
|
||||
{output_token_num_cpu, dim_embed}, input.dtype(), input.place());
|
||||
// Pre-allocate output with max possible size (input_token_num)
|
||||
auto out =
|
||||
paddle::empty({input_token_num, dim_embed}, input.dtype(), input.place());
|
||||
|
||||
constexpr int packSize = VEC_16B / (sizeof(DataType_));
|
||||
int elem_cnt = input_token_num * dim_embed;
|
||||
|
||||
assert(elem_cnt % packSize == 0);
|
||||
|
||||
int pack_num = elem_cnt / packSize;
|
||||
|
||||
int grid_size = 1;
|
||||
|
||||
GetNumBlocks(pack_num, &grid_size);
|
||||
|
||||
// Grid size linearly related to bsz for cooperative launch efficiency
|
||||
// and CUDA graph capture friendliness
|
||||
constexpr int thread_per_block = 128;
|
||||
constexpr int DESIRED_BLOCKS_PER_BATCH = 4;
|
||||
int dynamic_smem_size = 4 * bsz * static_cast<int>(sizeof(int));
|
||||
|
||||
rebuildHiddenStatesKernel<DataType_, packSize>
|
||||
<<<grid_size, thread_per_block>>>(
|
||||
reinterpret_cast<const DataType_*>(input.data<data_t>()),
|
||||
position_map.data<int>(),
|
||||
reinterpret_cast<DataType_*>(out.data<data_t>()),
|
||||
dim_embed,
|
||||
elem_cnt);
|
||||
// Cooperative launch limit: use conservative smem upper bound for caching
|
||||
static const int max_grid_size = [&]() {
|
||||
int blocks_per_sm = 0;
|
||||
constexpr int smem_upper_bound = 4 * 512 * sizeof(int); // 8KB
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&blocks_per_sm,
|
||||
rebuildHiddenStatesKernel<DataType_, packSize>,
|
||||
thread_per_block,
|
||||
smem_upper_bound);
|
||||
int dev = 0;
|
||||
cudaGetDevice(&dev);
|
||||
int sms = 0;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
||||
return blocks_per_sm * sms;
|
||||
}();
|
||||
|
||||
return {out};
|
||||
int blocks_per_batch =
|
||||
std::min(DESIRED_BLOCKS_PER_BATCH, max_grid_size / std::max(bsz, 1));
|
||||
blocks_per_batch = std::max(blocks_per_batch, 1);
|
||||
int grid_size = std::min(bsz * blocks_per_batch, max_grid_size);
|
||||
grid_size = std::max(grid_size, 1);
|
||||
|
||||
const DataType_* input_ptr =
|
||||
reinterpret_cast<const DataType_*>(input.data<data_t>());
|
||||
const int* seq_lens_this_time_ptr = seq_lens_this_time.data<int>();
|
||||
const int* seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
|
||||
const int* base_model_seq_lens_this_time_ptr =
|
||||
base_model_seq_lens_this_time.data<int>();
|
||||
const int* base_model_seq_lens_encoder_ptr =
|
||||
base_model_seq_lens_encoder.data<int>();
|
||||
const int* accept_nums_ptr = accept_nums.data<int>();
|
||||
int* position_map_ptr = position_map.data<int>();
|
||||
int* output_token_num_ptr = output_token_num.data<int>();
|
||||
DataType_* out_ptr = reinterpret_cast<DataType_*>(out.data<data_t>());
|
||||
int dim_embed_int = static_cast<int>(dim_embed);
|
||||
int input_token_num_int = static_cast<int>(input_token_num);
|
||||
|
||||
void* kernel_args[] = {&input_ptr,
|
||||
&seq_lens_this_time_ptr,
|
||||
&seq_lens_encoder_ptr,
|
||||
&base_model_seq_lens_this_time_ptr,
|
||||
&base_model_seq_lens_encoder_ptr,
|
||||
&accept_nums_ptr,
|
||||
&position_map_ptr,
|
||||
&output_token_num_ptr,
|
||||
&out_ptr,
|
||||
&bsz,
|
||||
&dim_embed_int,
|
||||
&input_token_num_int};
|
||||
|
||||
cudaLaunchCooperativeKernel(
|
||||
(void*)rebuildHiddenStatesKernel<DataType_, packSize>,
|
||||
dim3(grid_size),
|
||||
dim3(thread_per_block),
|
||||
kernel_args,
|
||||
dynamic_smem_size,
|
||||
input.stream());
|
||||
|
||||
return {out, output_token_num};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> EagleGetHiddenStates(
|
||||
@@ -248,5 +291,5 @@ PD_BUILD_STATIC_OP(eagle_get_hidden_states)
|
||||
"base_model_seq_lens_this_time",
|
||||
"base_model_seq_lens_encoder"})
|
||||
.Attrs({"actual_draft_token_num: int"})
|
||||
.Outputs({"out"})
|
||||
.Outputs({"out", "output_token_num"})
|
||||
.SetKernelFn(PD_KERNEL(EagleGetHiddenStates));
|
||||
|
||||
+172
-117
@@ -1,4 +1,3 @@
|
||||
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,174 +12,230 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// #define DEBUG_EAGLE_KERNEL
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
__global__ void computeOrderKernel(const int* last_seq_lens_this_time,
|
||||
const int* seq_lens_this_time,
|
||||
const int64_t* step_idx,
|
||||
int* src_map,
|
||||
int* output_token_num,
|
||||
int bsz) {
|
||||
int in_offset = 0;
|
||||
int out_offset = 0;
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < bsz; ++i) {
|
||||
int cur_seq_lens_this_time = seq_lens_this_time[i];
|
||||
int cur_last_seq_lens_this_time = last_seq_lens_this_time[i];
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf(
|
||||
"batch %d: cur_seq_lens_this_time:%d. "
|
||||
"cur_last_seq_lens_this_time:%d\n",
|
||||
i,
|
||||
cur_seq_lens_this_time,
|
||||
cur_last_seq_lens_this_time);
|
||||
#endif
|
||||
// Fused kernel: block 0 computes position_map and output_token_num in parallel
|
||||
// (one thread per batch element), then all blocks synchronize via
|
||||
// cooperative_groups grid sync, and finally all threads perform the hidden
|
||||
// states rebuild in parallel.
|
||||
template <typename T, int VecSize>
|
||||
__global__ void rebuildSelfHiddenStatesKernel(
|
||||
const T* input,
|
||||
const int* last_seq_lens_this_time,
|
||||
const int* seq_lens_this_time,
|
||||
const int* seq_lens_encoder,
|
||||
int* position_map,
|
||||
int* output_token_num,
|
||||
T* out,
|
||||
const int bsz,
|
||||
const int dim_embed,
|
||||
const int input_token_num) {
|
||||
cg::grid_group grid = cg::this_grid();
|
||||
|
||||
// Dynamic shared memory layout: [in_count|out_count|in_offsets|out_offsets]
|
||||
extern __shared__ int smem[];
|
||||
int* in_count = smem;
|
||||
int* out_count = smem + bsz;
|
||||
int* in_offsets = smem + 2 * bsz;
|
||||
int* out_offsets = smem + 3 * bsz;
|
||||
|
||||
// Phase 1: compute position_map (parallelized across threads in block 0)
|
||||
if (blockIdx.x == 0) {
|
||||
// Phase 1a: each thread computes counts for its batch elements
|
||||
for (int t = threadIdx.x; t < bsz; t += blockDim.x) {
|
||||
int cur_seq_lens_this_time = seq_lens_this_time[t];
|
||||
int cur_last_seq_lens_this_time = last_seq_lens_this_time[t];
|
||||
// 1. encoder
|
||||
if (step_idx[i] == 1 && cur_seq_lens_this_time > 0) {
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d last_step is encoder \n", i);
|
||||
#endif
|
||||
in_offset += 1;
|
||||
src_map[out_offset++] = in_offset - 1;
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d finish. src_map[%d]=%d \n",
|
||||
i,
|
||||
out_offset - 1,
|
||||
in_offset - 1);
|
||||
#endif
|
||||
if (seq_lens_encoder[t] > 0 && cur_seq_lens_this_time > 0) {
|
||||
in_count[t] = 1;
|
||||
out_count[t] = 1;
|
||||
// 2. decoder
|
||||
} else if (cur_seq_lens_this_time > 0) /* =1 */ {
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d is decoder\n", i);
|
||||
#endif
|
||||
in_offset += cur_last_seq_lens_this_time;
|
||||
src_map[out_offset++] = in_offset - 1;
|
||||
} else if (cur_seq_lens_this_time > 0) {
|
||||
in_count[t] = cur_last_seq_lens_this_time;
|
||||
out_count[t] = 1;
|
||||
// 3. stop
|
||||
} else {
|
||||
// first token end
|
||||
if (step_idx[i] == 1) {
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d finished in first token \n", i);
|
||||
#endif
|
||||
in_offset += cur_last_seq_lens_this_time > 0 ? 1 : 0;
|
||||
// normal end
|
||||
if (seq_lens_encoder[t] > 0) {
|
||||
in_count[t] = cur_last_seq_lens_this_time > 0 ? 1 : 0;
|
||||
} else {
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d finished in non-first token \n", i);
|
||||
#endif
|
||||
in_offset += cur_last_seq_lens_this_time;
|
||||
in_count[t] = cur_last_seq_lens_this_time;
|
||||
}
|
||||
out_count[t] = 0;
|
||||
}
|
||||
}
|
||||
output_token_num[0] = out_offset;
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("position map output_token_num%d:\n", output_token_num[0]);
|
||||
for (int i = 0; i < output_token_num[0]; i++) {
|
||||
printf("%d ", src_map[i]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
template <typename T, int PackSize>
|
||||
__global__ void rebuildSelfHiddenStatesKernel(
|
||||
const T* input, int* src_map, T* output, int dim_embed, int elem_cnt) {
|
||||
using LoadT = AlignedVector<T, PackSize>;
|
||||
// Phase 1b: prefix sum (thread 0 computes exclusive prefix sums)
|
||||
if (threadIdx.x == 0) {
|
||||
int in_acc = 0, out_acc = 0;
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
in_offsets[i] = in_acc;
|
||||
out_offsets[i] = out_acc;
|
||||
in_acc += in_count[i];
|
||||
out_acc += out_count[i];
|
||||
}
|
||||
output_token_num[0] = out_acc;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Phase 1c: each thread fills position_map for its batch elements
|
||||
for (int t = threadIdx.x; t < bsz; t += blockDim.x) {
|
||||
int in_off = in_offsets[t];
|
||||
int out_off = out_offsets[t];
|
||||
int cur_seq_lens_this_time = seq_lens_this_time[t];
|
||||
int cur_last_seq_lens_this_time = last_seq_lens_this_time[t];
|
||||
// 1. encoder
|
||||
if (seq_lens_encoder[t] > 0 && cur_seq_lens_this_time > 0) {
|
||||
position_map[in_off] = out_off;
|
||||
// 2. decoder
|
||||
} else if (cur_seq_lens_this_time > 0) {
|
||||
position_map[in_off + cur_last_seq_lens_this_time - 1] = out_off;
|
||||
}
|
||||
// 3. stop: no writes needed (position_map stays -1 from memset)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: grid-wide sync to ensure position_map is ready
|
||||
grid.sync();
|
||||
|
||||
// Phase 3: rebuild hidden states in parallel
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
|
||||
int global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
for (int elem_id = global_thread_idx * PackSize; elem_id < elem_cnt;
|
||||
elem_id += blockDim.x * gridDim.x * PackSize) {
|
||||
int output_token_idx = elem_id / dim_embed;
|
||||
int input_token_idx = src_map[output_token_idx];
|
||||
int offset = elem_id % dim_embed;
|
||||
Load<T, PackSize>(input + input_token_idx * dim_embed + offset, &src_vec);
|
||||
Store<T, PackSize>(src_vec, output + output_token_idx * dim_embed + offset);
|
||||
int elem_cnt = input_token_num * dim_embed;
|
||||
int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int elem_idx = global_thread_idx * VecSize; elem_idx < elem_cnt;
|
||||
elem_idx += blockDim.x * gridDim.x * VecSize) {
|
||||
int ori_token_idx = elem_idx / dim_embed;
|
||||
int token_idx = position_map[ori_token_idx];
|
||||
if (token_idx >= 0) {
|
||||
int offset = elem_idx % dim_embed;
|
||||
Load<T, VecSize>(input + ori_token_idx * dim_embed + offset, &src_vec);
|
||||
Store<T, VecSize>(src_vec, out + token_idx * dim_embed + offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor> DispatchDtype(
|
||||
const paddle::Tensor input,
|
||||
const paddle::Tensor last_seq_lens_this_time,
|
||||
const paddle::Tensor seq_lens_this_time,
|
||||
const paddle::Tensor step_idx) {
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& last_seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
int input_token_num = input.shape()[0];
|
||||
int dim_embed = input.shape()[1];
|
||||
auto input_token_num = input.shape()[0];
|
||||
auto dim_embed = input.shape()[1];
|
||||
int bsz = seq_lens_this_time.shape()[0];
|
||||
auto src_map = paddle::full({input_token_num},
|
||||
-1,
|
||||
seq_lens_this_time.dtype(),
|
||||
seq_lens_this_time.place());
|
||||
auto output_token_num = paddle::full(
|
||||
{1}, 0, seq_lens_this_time.dtype(), seq_lens_this_time.place());
|
||||
|
||||
computeOrderKernel<<<1, 1, 0, seq_lens_this_time.stream()>>>(
|
||||
last_seq_lens_this_time.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
step_idx.data<int64_t>(),
|
||||
src_map.data<int>(),
|
||||
output_token_num.data<int>(),
|
||||
bsz);
|
||||
auto position_map = paddle::empty(
|
||||
{input_token_num}, seq_lens_this_time.dtype(), input.place());
|
||||
cudaMemsetAsync(position_map.data<int>(),
|
||||
0xFF,
|
||||
input_token_num * sizeof(int),
|
||||
input.stream());
|
||||
|
||||
int output_token_num_cpu =
|
||||
output_token_num.copy_to(paddle::CPUPlace(), false).data<int>()[0];
|
||||
auto output_token_num =
|
||||
paddle::empty({1}, seq_lens_this_time.dtype(), input.place());
|
||||
|
||||
auto out = paddle::full(
|
||||
{output_token_num_cpu, dim_embed}, -1, input.type(), input.place());
|
||||
// Pre-allocate output with max possible size (input_token_num)
|
||||
auto out =
|
||||
paddle::empty({input_token_num, dim_embed}, input.dtype(), input.place());
|
||||
|
||||
constexpr int packSize = VEC_16B / (sizeof(DataType_));
|
||||
int elem_cnt = output_token_num_cpu * dim_embed;
|
||||
// printf("output_token_num: %d, dim_embed: %d, cnt: %d. packSize: %d\n",
|
||||
// output_token_num_cpu, dim_embed,elem_cnt, packSize);
|
||||
int elem_cnt = input_token_num * dim_embed;
|
||||
assert(elem_cnt % packSize == 0);
|
||||
|
||||
int pack_num = elem_cnt / packSize;
|
||||
// Grid size linearly related to bsz for cooperative launch efficiency
|
||||
// and CUDA graph capture friendliness
|
||||
constexpr int thread_per_block = 128;
|
||||
constexpr int DESIRED_BLOCKS_PER_BATCH = 4;
|
||||
int dynamic_smem_size = 4 * bsz * static_cast<int>(sizeof(int));
|
||||
|
||||
int grid_size = 1;
|
||||
// Cooperative launch limit: use conservative smem upper bound for caching
|
||||
static const int max_grid_size = [&]() {
|
||||
int blocks_per_sm = 0;
|
||||
constexpr int smem_upper_bound = 4 * 512 * sizeof(int); // 8KB
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&blocks_per_sm,
|
||||
rebuildSelfHiddenStatesKernel<DataType_, packSize>,
|
||||
thread_per_block,
|
||||
smem_upper_bound);
|
||||
int dev = 0;
|
||||
cudaGetDevice(&dev);
|
||||
int sms = 0;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
||||
return blocks_per_sm * sms;
|
||||
}();
|
||||
|
||||
GetNumBlocks(pack_num, &grid_size);
|
||||
int blocks_per_batch =
|
||||
std::min(DESIRED_BLOCKS_PER_BATCH, max_grid_size / std::max(bsz, 1));
|
||||
blocks_per_batch = std::max(blocks_per_batch, 1);
|
||||
int grid_size = std::min(bsz * blocks_per_batch, max_grid_size);
|
||||
grid_size = std::max(grid_size, 1);
|
||||
|
||||
constexpr int threadPerBlock = 128;
|
||||
const DataType_* input_ptr =
|
||||
reinterpret_cast<const DataType_*>(input.data<data_t>());
|
||||
const int* last_seq_lens_this_time_ptr = last_seq_lens_this_time.data<int>();
|
||||
const int* seq_lens_this_time_ptr = seq_lens_this_time.data<int>();
|
||||
const int* seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
|
||||
int* position_map_ptr = position_map.data<int>();
|
||||
int* output_token_num_ptr = output_token_num.data<int>();
|
||||
DataType_* out_ptr = reinterpret_cast<DataType_*>(out.data<data_t>());
|
||||
int dim_embed_int = static_cast<int>(dim_embed);
|
||||
int input_token_num_int = static_cast<int>(input_token_num);
|
||||
|
||||
rebuildSelfHiddenStatesKernel<DataType_, packSize>
|
||||
<<<grid_size, threadPerBlock, 0, input.stream()>>>(
|
||||
reinterpret_cast<const DataType_*>(input.data<data_t>()),
|
||||
src_map.data<int>(),
|
||||
reinterpret_cast<DataType_*>(out.data<data_t>()),
|
||||
dim_embed,
|
||||
elem_cnt);
|
||||
void* kernel_args[] = {&input_ptr,
|
||||
&last_seq_lens_this_time_ptr,
|
||||
&seq_lens_this_time_ptr,
|
||||
&seq_lens_encoder_ptr,
|
||||
&position_map_ptr,
|
||||
&output_token_num_ptr,
|
||||
&out_ptr,
|
||||
&bsz,
|
||||
&dim_embed_int,
|
||||
&input_token_num_int};
|
||||
|
||||
return {out};
|
||||
cudaLaunchCooperativeKernel(
|
||||
(void*)rebuildSelfHiddenStatesKernel<DataType_, packSize>,
|
||||
dim3(grid_size),
|
||||
dim3(thread_per_block),
|
||||
kernel_args,
|
||||
dynamic_smem_size,
|
||||
input.stream());
|
||||
|
||||
return {out, output_token_num};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& last_seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& step_idx) {
|
||||
const paddle::Tensor& seq_lens_encoder) {
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
return DispatchDtype<paddle::DataType::BFLOAT16>(
|
||||
input, last_seq_lens_this_time, seq_lens_this_time, step_idx);
|
||||
input, last_seq_lens_this_time, seq_lens_this_time, seq_lens_encoder);
|
||||
case paddle::DataType::FLOAT16:
|
||||
return DispatchDtype<paddle::DataType::FLOAT16>(
|
||||
input, last_seq_lens_this_time, seq_lens_this_time, step_idx);
|
||||
input, last_seq_lens_this_time, seq_lens_this_time, seq_lens_encoder);
|
||||
default:
|
||||
PD_THROW("Not support this data type");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
|
||||
.Inputs(
|
||||
{"input", "last_seq_lens_this_time", "seq_lens_this_time", "step_idx"})
|
||||
.Outputs({"out"})
|
||||
.Inputs({"input",
|
||||
"last_seq_lens_this_time",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_encoder"})
|
||||
.Outputs({"out", "output_token_num"})
|
||||
.SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));
|
||||
|
||||
Reference in New Issue
Block a user