remove speculate_get_padding_offset op (#6308)

This commit is contained in:
周周周
2026-02-03 15:18:12 +08:00
committed by GitHub
parent 39dc4b0c2e
commit 8277b95fa6
7 changed files with 44 additions and 324 deletions
+6 -17
View File
@@ -407,9 +407,12 @@ void GetBlockShapeAndSplitKVBlock(
const int group_size,
const int block_size);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
const paddle::Tensor& seq_len,
const int64_t token_num_cpu);
std::vector<paddle::Tensor> GetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& seq_len,
const paddle::optional<paddle::Tensor>& draft_tokens,
const paddle::optional<paddle::Tensor>& seq_lens_encoder,
const int64_t token_num_cpu);
void SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all,
const paddle::Tensor& input_ids,
@@ -739,15 +742,6 @@ void free_shared_buffer(int64_t buffer);
void clear_ipc_handles(int64_t _fa);
// speculative decoding Kernel
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder,
const int64_t token_num_cpu);
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
@@ -1596,11 +1590,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&get_graph_buffer_ipc_meta,
"get_graph_buffer_ipc_meta");
// speculative decoding Kernel
m.def("speculate_get_padding_offset",
&SpeculateGetPaddingOffset,
"speculate_get_padding_offset function");
m.def("speculate_get_seq_lens_output",
&SpeculateGetSeqLensOutput,
"speculate_get_seq_lens_output function");
+34 -8
View File
@@ -25,7 +25,10 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
int *cu_seqlens_k,
const int64_t *input_data,
const int *seq_lens,
const int max_seq_len) {
const int max_seq_len,
const int64_t *draft_tokens,
const int *seq_lens_encoder,
const int max_draft_tokens_per_batch) {
const int bi = blockIdx.x;
const int tid = threadIdx.x;
#ifdef PADDLE_WITH_COREX
@@ -62,15 +65,26 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
const int tgt_seq_id = cum_seq_len - seq_lens[bi] + i;
const int src_seq_id = bi * max_seq_len + i;
ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
if (max_draft_tokens_per_batch > 0 && seq_lens_encoder[bi] <= 0) {
// speculative decoding
const int src_seq_id = bi * max_draft_tokens_per_batch + i;
ids_remove_padding[tgt_seq_id] = draft_tokens[src_seq_id];
} else {
// Non-speculative decoding
const int src_seq_id = bi * max_seq_len + i;
ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
}
batch_id_per_token[tgt_seq_id] = bi;
}
}
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &seq_len,
const int64_t cpu_token_num) {
std::vector<paddle::Tensor> GetPaddingOffset(
const paddle::Tensor &input_ids,
const paddle::Tensor &seq_len,
const paddle::optional<paddle::Tensor> &draft_tokens,
const paddle::optional<paddle::Tensor> &seq_lens_encoder,
const int64_t cpu_token_num) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(
@@ -98,6 +112,12 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
int blockSize =
min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#endif
int max_draft_tokens_per_batch = -1;
if (draft_tokens) {
max_draft_tokens_per_batch = draft_tokens.get().shape()[1];
}
PrefixSumKernel<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
batch_id_per_token.data<int>(),
@@ -105,7 +125,10 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
cu_seqlens_k.data<int>(),
input_ids.data<int64_t>(),
seq_len.data<int>(),
max_seq_len);
max_seq_len,
draft_tokens ? draft_tokens.get().data<int64_t>() : nullptr,
seq_lens_encoder ? seq_lens_encoder.get().data<int32_t>() : nullptr,
max_draft_tokens_per_batch);
return {x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k};
}
@@ -127,7 +150,10 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
}
PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "seq_len"})
.Inputs({"input_ids",
"seq_len",
paddle::Optional("draft_tokens"),
paddle::Optional("seq_lens_encoder")})
.Outputs({"x_remove_padding",
"batch_id_per_token",
"cu_seqlens_q",
@@ -1,149 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
__global__ void SpeculateRemovePadding(int64_t* output_data,
const int64_t* input_data,
const int64_t* draft_tokens,
const int* seq_lens,
const int* seq_lens_encoder,
const int* cum_offsets,
const int sequence_length,
const int max_draft_tokens) {
const int bi = blockIdx.x;
const int tid = threadIdx.x;
for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
if (seq_lens_encoder[bi] > 0) {
const int src_seq_id = bi * sequence_length + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
} else {
const int src_seq_id = bi * max_draft_tokens + i;
output_data[tgt_seq_id] = draft_tokens[src_seq_id];
}
}
}
__global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len) {
// get padding offset of each batch
const int bi = blockIdx.x;
const int ti = threadIdx.x;
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
}
if (ti == 0) {
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
cu_seqlens_k[bi + 1] = cum_seq_len;
}
}
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder,
const int64_t cpu_token_num) {
auto cu_stream = input_ids.stream();
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
const int max_draft_tokens = draft_tokens.shape()[1];
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
const int token_num_data = cpu_token_num;
auto x_remove_padding = paddle::full(
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::full(
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length);
SpeculateRemovePadding<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
draft_tokens.data<int64_t>(),
seq_len.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets_out.data<int>(),
seq_length,
max_draft_tokens);
return {x_remove_padding,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num};
}
std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
const std::vector<int64_t>& input_ids_shape,
const std::vector<int64_t>& draft_tokens_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& token_num_shape,
const std::vector<int64_t>& seq_len_shape,
const std::vector<int64_t>& seq_lens_encoder_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
const paddle::DataType& input_ids_dtype,
const paddle::DataType& draft_tokens_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& token_num_dtype,
const paddle::DataType& seq_len_dtype,
const paddle::DataType& seq_lens_encoder_dtype) {
return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
}
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
.Inputs({
"input_ids",
"draft_tokens",
"cum_offsets",
"seq_len",
"seq_lens_encoder",
})
.Outputs({"x_remove_padding",
"batch_id_per_token",
"cu_seqlens_q",
"cu_seqlens_k"})
.Attrs({"cpu_token_num: int64_t"})
.SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetInferDtype));
@@ -59,7 +59,6 @@ elif current_platform.is_maca():
save_output_topk,
set_stop_value_multi_ends,
speculate_get_output_padding_offset,
speculate_get_padding_offset,
speculate_get_seq_lens_output,
speculate_limit_thinking_content_length_v1,
speculate_limit_thinking_content_length_v2,
@@ -86,7 +85,6 @@ else:
save_output_topk,
set_stop_value_multi_ends,
speculate_get_output_padding_offset,
speculate_get_padding_offset,
speculate_get_seq_lens_output,
speculate_save_output,
speculate_save_output_topk,
@@ -226,7 +224,7 @@ def pre_process(
if specific_platform and not speculative_decoding:
# Note(ZKK): This case's code is very simple!
ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
input_ids, seq_lens_this_time, token_num_cpu
input_ids, seq_lens_this_time, None, None, token_num_cpu
)
return (
ids_remove_padding,
@@ -247,9 +245,7 @@ def pre_process(
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids, draft_tokens, cum_offsets_now, seq_lens_this_time, seq_lens_encoder, token_num_cpu
)
) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)
seq_lens_output = speculate_get_seq_lens_output(
seq_lens_this_time,
seq_lens_encoder,
+1 -1
View File
@@ -274,7 +274,7 @@ class TestAttentionPerformance(unittest.TestCase):
input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64")
token_num = np.sum(seq_lens_this_time)
ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
input_ids, seq_lens_this_time, token_num
input_ids, seq_lens_this_time, None, None, token_num
)
forward_meta = ForwardMeta(
+1 -1
View File
@@ -32,7 +32,7 @@ class TestGetPaddingOffset(unittest.TestCase):
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), token_num_cpu)
) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), None, None, token_num_cpu)
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32")
@@ -1,142 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import speculate_get_padding_offset
def ref_speculate_get_padding_offset(cum_offsets, seq_lens, max_seq_len, token_num_data):
bsz = seq_lens.shape[0]
padding_offset = np.zeros([token_num_data], dtype=np.int32)
batch_id_per_token = np.zeros([token_num_data], dtype=np.int32)
cum_offsets_out = np.zeros([bsz], dtype=np.int32)
cu_seqlens_q = np.zeros([bsz + 1], dtype=np.int32)
cu_seqlens_k = np.zeros([bsz + 1], dtype=np.int32)
modified_indices = {
"padding_offset": [],
"cum_offsets_out": [],
"cu_seqlens_q": [],
"cu_seqlens_k": [],
}
cu_seqlens_q[0] = 0
cu_seqlens_k[0] = 0
modified_indices["cu_seqlens_q"].append(0)
modified_indices["cu_seqlens_k"].append(0)
for bi in range(bsz):
cum_offset = 0 if bi == 0 else cum_offsets[bi - 1]
cum_offsets_out[bi] = cum_offset
modified_indices["cum_offsets_out"].append(bi)
for i in range(seq_lens[bi]):
idx = bi * max_seq_len - cum_offset + i
if idx >= 0 and idx < token_num_data:
if idx == 0:
print(idx, bi, cum_offset)
padding_offset[idx] = cum_offset
batch_id_per_token[idx] = bi
modified_indices["padding_offset"].append(idx)
cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]
cu_seqlens_q[bi + 1] = cum_seq_len
cu_seqlens_k[bi + 1] = cum_seq_len
modified_indices["cu_seqlens_q"].append(bi + 1)
modified_indices["cu_seqlens_k"].append(bi + 1)
return (
padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
modified_indices,
batch_id_per_token,
)
class TestSpeculateGetPaddingOffset(unittest.TestCase):
def test_speculate_get_padding_offset(self):
test_case = {
"bsz": 4,
"max_seq_len": 10,
"token_num_data": 32,
"cum_offsets": np.array([2, 5, 8, 12], dtype=np.int32),
"seq_lens": np.array([8, 5, 7, 6], dtype=np.int32),
"seq_lens_encoder": np.array([1, 0, 1, 0], dtype=np.int32),
}
max_draft_tokens = 4
input_ids = np.random.randint(0, 1000, (test_case["bsz"], test_case["max_seq_len"]), dtype=np.int64)
draft_tokens = np.random.randint(0, 1000, (test_case["bsz"], max_draft_tokens), dtype=np.int64)
token_num_cpu = np.array([test_case["token_num_data"]], dtype=np.int64).item()
input_ids_tensor = paddle.to_tensor(input_ids)
draft_tokens_tensor = paddle.to_tensor(draft_tokens)
cum_offsets_tensor = paddle.to_tensor(test_case["cum_offsets"])
seq_lens_tensor = paddle.to_tensor(test_case["seq_lens"])
seq_lens_encoder_tensor = paddle.to_tensor(test_case["seq_lens_encoder"])
(
x_remove_padding,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
input_ids_tensor,
draft_tokens_tensor,
cum_offsets_tensor,
seq_lens_tensor,
seq_lens_encoder_tensor,
token_num_cpu,
)
(
ref_padding_offset,
ref_cum_offsets_out,
ref_cu_seqlens_q,
ref_cu_seqlens_k,
modified_indices,
ref_batch_id_per_token,
) = ref_speculate_get_padding_offset(
test_case["cum_offsets"],
test_case["seq_lens"],
test_case["max_seq_len"],
test_case["token_num_data"],
)
output_arrays = {
"batch_id_per_token": batch_id_per_token.numpy(),
"cu_seqlens_q": cu_seqlens_q.numpy(),
"cu_seqlens_k": cu_seqlens_k.numpy(),
}
ref_arrays = {
"batch_id_per_token": ref_batch_id_per_token,
"cu_seqlens_q": ref_cu_seqlens_q,
"cu_seqlens_k": ref_cu_seqlens_k,
}
for key in output_arrays:
np.testing.assert_allclose(output_arrays[key], ref_arrays[key])
if __name__ == "__main__":
unittest.main()