diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index f6eaf97764..58b346c87e 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -407,9 +407,12 @@ void GetBlockShapeAndSplitKVBlock( const int group_size, const int block_size); -std::vector GetPaddingOffset(const paddle::Tensor& input_ids, - const paddle::Tensor& seq_len, - const int64_t token_num_cpu); +std::vector GetPaddingOffset( + const paddle::Tensor& input_ids, + const paddle::Tensor& seq_len, + const paddle::optional& draft_tokens, + const paddle::optional& seq_lens_encoder, + const int64_t token_num_cpu); void SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all, const paddle::Tensor& input_ids, @@ -739,15 +742,6 @@ void free_shared_buffer(int64_t buffer); void clear_ipc_handles(int64_t _fa); -// speculative decoding Kernel -std::vector SpeculateGetPaddingOffset( - const paddle::Tensor& input_ids, - const paddle::Tensor& draft_tokens, - const paddle::Tensor& cum_offsets, - const paddle::Tensor& seq_len, - const paddle::Tensor& seq_lens_encoder, - const int64_t token_num_cpu); - std::vector SpeculateGetSeqLensOutput( const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder, @@ -1596,11 +1590,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta"); - // speculative decoding Kernel - m.def("speculate_get_padding_offset", - &SpeculateGetPaddingOffset, - "speculate_get_padding_offset function"); - m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput, "speculate_get_seq_lens_output function"); diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu index c5d676365e..dff137032c 100644 --- a/custom_ops/gpu_ops/get_padding_offset.cu +++ b/custom_ops/gpu_ops/get_padding_offset.cu @@ -25,7 +25,10 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding, int *cu_seqlens_k, const int64_t *input_data, const int *seq_lens, - const int max_seq_len) { + const int max_seq_len, + const int64_t *draft_tokens, + const int *seq_lens_encoder, + const int max_draft_tokens_per_batch) { const int bi = blockIdx.x; const int tid = threadIdx.x; #ifdef PADDLE_WITH_COREX @@ -62,15 +65,26 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding, for (int i = tid; i < seq_lens[bi]; i += blockDim.x) { const int tgt_seq_id = cum_seq_len - seq_lens[bi] + i; - const int src_seq_id = bi * max_seq_len + i; - ids_remove_padding[tgt_seq_id] = input_data[src_seq_id]; + if (max_draft_tokens_per_batch > 0 && seq_lens_encoder[bi] <= 0) { + // speculative decoding + const int src_seq_id = bi * max_draft_tokens_per_batch + i; + ids_remove_padding[tgt_seq_id] = draft_tokens[src_seq_id]; + } else { + // Non-speculative decoding + const int src_seq_id = bi * max_seq_len + i; + ids_remove_padding[tgt_seq_id] = input_data[src_seq_id]; + } + batch_id_per_token[tgt_seq_id] = bi; } } -std::vector GetPaddingOffset(const paddle::Tensor &input_ids, - const paddle::Tensor &seq_len, - const int64_t cpu_token_num) { +std::vector GetPaddingOffset( + const paddle::Tensor &input_ids, + const paddle::Tensor &seq_len, + const paddle::optional &draft_tokens, + const paddle::optional &seq_lens_encoder, + const int64_t cpu_token_num) { #ifdef PADDLE_WITH_CUSTOM_DEVICE auto dev_ctx = static_cast( paddle::experimental::DeviceContextPool::Instance().Get( @@ -98,6 +112,12 @@ std::vector GetPaddingOffset(const paddle::Tensor &input_ids, int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128); #endif + + int max_draft_tokens_per_batch = -1; + if (draft_tokens) { + max_draft_tokens_per_batch = draft_tokens.get().shape()[1]; + } + PrefixSumKernel<<>>( x_remove_padding.data(), batch_id_per_token.data(), @@ -105,7 +125,10 @@ std::vector GetPaddingOffset(const paddle::Tensor &input_ids, cu_seqlens_k.data(), input_ids.data(), seq_len.data(), - max_seq_len); + max_seq_len, + draft_tokens ? draft_tokens.get().data() : nullptr, + seq_lens_encoder ? seq_lens_encoder.get().data() : nullptr, + max_draft_tokens_per_batch); return {x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k}; } @@ -127,7 +150,10 @@ std::vector GetPaddingOffsetInferDtype( } PD_BUILD_STATIC_OP(get_padding_offset) - .Inputs({"input_ids", "seq_len"}) + .Inputs({"input_ids", + "seq_len", + paddle::Optional("draft_tokens"), + paddle::Optional("seq_lens_encoder")}) .Outputs({"x_remove_padding", "batch_id_per_token", "cu_seqlens_q", diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu deleted file mode 100644 index d644a4fa32..0000000000 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" - -#ifndef PD_BUILD_STATIC_OP -#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) -#endif - -__global__ void SpeculateRemovePadding(int64_t* output_data, - const int64_t* input_data, - const int64_t* draft_tokens, - const int* seq_lens, - const int* seq_lens_encoder, - const int* cum_offsets, - const int sequence_length, - const int max_draft_tokens) { - const int bi = blockIdx.x; - const int tid = threadIdx.x; - - for (int i = tid; i < seq_lens[bi]; i += blockDim.x) { - const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i; - if (seq_lens_encoder[bi] > 0) { - const int src_seq_id = bi * sequence_length + i; - output_data[tgt_seq_id] = input_data[src_seq_id]; - } else { - const int src_seq_id = bi * max_draft_tokens + i; - output_data[tgt_seq_id] = draft_tokens[src_seq_id]; - } - } -} - -__global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len) { - // get padding offset of each batch - const int bi = blockIdx.x; - const int ti = threadIdx.x; - int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; - for (int i = ti; i < seq_lens[bi]; i += blockDim.x) { - batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi; - } - if (ti == 0) { - cum_offsets_out[bi] = cum_offset; - int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]; - cu_seqlens_q[bi + 1] = cum_seq_len; - cu_seqlens_k[bi + 1] = cum_seq_len; - } -} - -std::vector SpeculateGetPaddingOffset( - const paddle::Tensor& input_ids, - const paddle::Tensor& draft_tokens, - const paddle::Tensor& cum_offsets, - const paddle::Tensor& seq_len, - const paddle::Tensor& seq_lens_encoder, - const int64_t cpu_token_num) { - auto cu_stream = input_ids.stream(); - std::vector input_ids_shape = input_ids.shape(); - const int bsz = seq_len.shape()[0]; - const int seq_length = input_ids_shape[1]; - const int max_draft_tokens = draft_tokens.shape()[1]; - auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false); - const int token_num_data = cpu_token_num; - auto x_remove_padding = paddle::full( - {token_num_data}, 0, paddle::DataType::INT64, input_ids.place()); - auto batch_id_per_token = paddle::full( - {token_num_data}, 0, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_q = - paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_k = - paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128); - SpeculateGetPaddingOffsetKernel<<>>( - batch_id_per_token.data(), - cum_offsets_out.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - cum_offsets.data(), - seq_len.data(), - seq_length); - SpeculateRemovePadding<<>>( - x_remove_padding.data(), - input_ids.data(), - draft_tokens.data(), - seq_len.data(), - seq_lens_encoder.data(), - cum_offsets_out.data(), - seq_length, - max_draft_tokens); - return {x_remove_padding, - batch_id_per_token, - cu_seqlens_q, - cu_seqlens_k}; // , enc_token_num, dec_token_num}; -} - -std::vector> SpeculateGetPaddingOffsetInferShape( - const std::vector& input_ids_shape, - const std::vector& draft_tokens_shape, - const std::vector& cum_offsets_shape, - const std::vector& token_num_shape, - const std::vector& seq_len_shape, - const std::vector& seq_lens_encoder_shape) { - int64_t bsz = seq_len_shape[0]; - int64_t seq_len = input_ids_shape[1]; - return {{-1}, {-1}, {bsz + 1}, {bsz + 1}}; -} - -std::vector SpeculateGetPaddingOffsetInferDtype( - const paddle::DataType& input_ids_dtype, - const paddle::DataType& draft_tokens_dtype, - const paddle::DataType& cum_offsets_dtype, - const paddle::DataType& token_num_dtype, - const paddle::DataType& seq_len_dtype, - const paddle::DataType& seq_lens_encoder_dtype) { - return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype}; -} - -PD_BUILD_STATIC_OP(speculate_get_padding_offset) - .Inputs({ - "input_ids", - "draft_tokens", - "cum_offsets", - "seq_len", - "seq_lens_encoder", - }) - .Outputs({"x_remove_padding", - "batch_id_per_token", - "cu_seqlens_q", - "cu_seqlens_k"}) - .Attrs({"cpu_token_num: int64_t"}) - .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset)) - .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetInferDtype)); diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 7c3ed751fe..43b27ce532 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -59,7 +59,6 @@ elif current_platform.is_maca(): save_output_topk, set_stop_value_multi_ends, speculate_get_output_padding_offset, - speculate_get_padding_offset, speculate_get_seq_lens_output, speculate_limit_thinking_content_length_v1, speculate_limit_thinking_content_length_v2, @@ -86,7 +85,6 @@ else: save_output_topk, set_stop_value_multi_ends, speculate_get_output_padding_offset, - speculate_get_padding_offset, speculate_get_seq_lens_output, speculate_save_output, speculate_save_output_topk, @@ -226,7 +224,7 @@ def pre_process( if specific_platform and not speculative_decoding: # Note(ZKK): This case's code is very simple! ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset( - input_ids, seq_lens_this_time, token_num_cpu + input_ids, seq_lens_this_time, None, None, token_num_cpu ) return ( ids_remove_padding, @@ -247,9 +245,7 @@ def pre_process( batch_id_per_token, cu_seqlens_q, cu_seqlens_k, - ) = speculate_get_padding_offset( - input_ids, draft_tokens, cum_offsets_now, seq_lens_this_time, seq_lens_encoder, token_num_cpu - ) + ) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu) seq_lens_output = speculate_get_seq_lens_output( seq_lens_this_time, seq_lens_encoder, diff --git a/tests/layers/test_attention_layer.py b/tests/layers/test_attention_layer.py index f37248a83a..507ca0556c 100644 --- a/tests/layers/test_attention_layer.py +++ b/tests/layers/test_attention_layer.py @@ -274,7 +274,7 @@ class TestAttentionPerformance(unittest.TestCase): input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64") token_num = np.sum(seq_lens_this_time) ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset( - input_ids, seq_lens_this_time, token_num + input_ids, seq_lens_this_time, None, None, token_num ) forward_meta = ForwardMeta( diff --git a/tests/operators/test_get_padding_offset.py b/tests/operators/test_get_padding_offset.py index cfa7760d84..13c86c422d 100644 --- a/tests/operators/test_get_padding_offset.py +++ b/tests/operators/test_get_padding_offset.py @@ -32,7 +32,7 @@ class TestGetPaddingOffset(unittest.TestCase): batch_id_per_token, cu_seqlens_q, cu_seqlens_k, - ) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), token_num_cpu) + ) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), None, None, token_num_cpu) ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64") ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32") diff --git a/tests/operators/test_speculate_get_padding_offset.py b/tests/operators/test_speculate_get_padding_offset.py deleted file mode 100644 index a8aac690bf..0000000000 --- a/tests/operators/test_speculate_get_padding_offset.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -import paddle - -from fastdeploy.model_executor.ops.gpu import speculate_get_padding_offset - - -def ref_speculate_get_padding_offset(cum_offsets, seq_lens, max_seq_len, token_num_data): - bsz = seq_lens.shape[0] - - padding_offset = np.zeros([token_num_data], dtype=np.int32) - batch_id_per_token = np.zeros([token_num_data], dtype=np.int32) - cum_offsets_out = np.zeros([bsz], dtype=np.int32) - cu_seqlens_q = np.zeros([bsz + 1], dtype=np.int32) - cu_seqlens_k = np.zeros([bsz + 1], dtype=np.int32) - - modified_indices = { - "padding_offset": [], - "cum_offsets_out": [], - "cu_seqlens_q": [], - "cu_seqlens_k": [], - } - - cu_seqlens_q[0] = 0 - cu_seqlens_k[0] = 0 - modified_indices["cu_seqlens_q"].append(0) - modified_indices["cu_seqlens_k"].append(0) - - for bi in range(bsz): - cum_offset = 0 if bi == 0 else cum_offsets[bi - 1] - cum_offsets_out[bi] = cum_offset - modified_indices["cum_offsets_out"].append(bi) - - for i in range(seq_lens[bi]): - idx = bi * max_seq_len - cum_offset + i - if idx >= 0 and idx < token_num_data: - if idx == 0: - print(idx, bi, cum_offset) - padding_offset[idx] = cum_offset - batch_id_per_token[idx] = bi - modified_indices["padding_offset"].append(idx) - - cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi] - cu_seqlens_q[bi + 1] = cum_seq_len - cu_seqlens_k[bi + 1] = cum_seq_len - modified_indices["cu_seqlens_q"].append(bi + 1) - modified_indices["cu_seqlens_k"].append(bi + 1) - - return ( - padding_offset, - cum_offsets_out, - cu_seqlens_q, - cu_seqlens_k, - modified_indices, - batch_id_per_token, - ) - - -class TestSpeculateGetPaddingOffset(unittest.TestCase): - def test_speculate_get_padding_offset(self): - test_case = { - "bsz": 4, - "max_seq_len": 10, - "token_num_data": 32, - "cum_offsets": np.array([2, 5, 8, 12], dtype=np.int32), - "seq_lens": np.array([8, 5, 7, 6], dtype=np.int32), - "seq_lens_encoder": np.array([1, 0, 1, 0], dtype=np.int32), - } - - max_draft_tokens = 4 - - input_ids = np.random.randint(0, 1000, (test_case["bsz"], test_case["max_seq_len"]), dtype=np.int64) - draft_tokens = np.random.randint(0, 1000, (test_case["bsz"], max_draft_tokens), dtype=np.int64) - token_num_cpu = np.array([test_case["token_num_data"]], dtype=np.int64).item() - - input_ids_tensor = paddle.to_tensor(input_ids) - draft_tokens_tensor = paddle.to_tensor(draft_tokens) - cum_offsets_tensor = paddle.to_tensor(test_case["cum_offsets"]) - seq_lens_tensor = paddle.to_tensor(test_case["seq_lens"]) - seq_lens_encoder_tensor = paddle.to_tensor(test_case["seq_lens_encoder"]) - - ( - x_remove_padding, - batch_id_per_token, - cu_seqlens_q, - cu_seqlens_k, - ) = speculate_get_padding_offset( - input_ids_tensor, - draft_tokens_tensor, - cum_offsets_tensor, - seq_lens_tensor, - seq_lens_encoder_tensor, - token_num_cpu, - ) - - ( - ref_padding_offset, - ref_cum_offsets_out, - ref_cu_seqlens_q, - ref_cu_seqlens_k, - modified_indices, - ref_batch_id_per_token, - ) = ref_speculate_get_padding_offset( - test_case["cum_offsets"], - test_case["seq_lens"], - test_case["max_seq_len"], - test_case["token_num_data"], - ) - - output_arrays = { - "batch_id_per_token": batch_id_per_token.numpy(), - "cu_seqlens_q": cu_seqlens_q.numpy(), - "cu_seqlens_k": cu_seqlens_k.numpy(), - } - - ref_arrays = { - "batch_id_per_token": ref_batch_id_per_token, - "cu_seqlens_q": ref_cu_seqlens_q, - "cu_seqlens_k": ref_cu_seqlens_k, - } - - for key in output_arrays: - np.testing.assert_allclose(output_arrays[key], ref_arrays[key]) - - -if __name__ == "__main__": - unittest.main()