diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 880c66f7c1..d5504cd4e3 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -428,9 +428,9 @@ paddle::Tensor RebuildPaddingFunc(
     const paddle::Tensor& seq_len_this_time,
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::optional<paddle::Tensor>& output_padding_offset,
+    const paddle::optional<paddle::Tensor>& batch_id_per_token_output,
+    const paddle::optional<paddle::Tensor>& cu_seqlens_q_output,
     const paddle::optional<paddle::Tensor>& first_token_out,
-    int max_input_length,
     bool enable_logprob);
 
 void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
@@ -747,28 +747,23 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& seq_lens_decoder);
 
-std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
-    const paddle::Tensor& output_cum_offsets_tmp,
-    const paddle::Tensor& out_token_num,
-    const paddle::Tensor& seq_lens_output,
+void SpecTokenPenaltyMultiScores(
+    const paddle::Tensor& pre_ids,
+    const paddle::Tensor& logits,
+    const paddle::Tensor& penalty_scores,
+    const paddle::Tensor& frequency_scores,
+    const paddle::Tensor& presence_scores,
+    const paddle::Tensor& temperatures,
+    const paddle::Tensor& bad_tokens,
+    const paddle::Tensor& bad_tokens_len,
+    const paddle::Tensor& cur_len,
+    const paddle::Tensor& min_len,
+    const paddle::Tensor& eos_token_id,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& batch_id_per_token_output,
+    const paddle::Tensor& cu_seqlens_q_output,
     const int max_seq_len);
 
-void SpecTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
-                                 const paddle::Tensor& logits,
-                                 const paddle::Tensor& penalty_scores,
-                                 const paddle::Tensor& frequency_scores,
-                                 const paddle::Tensor& presence_scores,
-                                 const paddle::Tensor& temperatures,
-                                 const paddle::Tensor& bad_tokens,
-                                 const paddle::Tensor& bad_tokens_len,
-                                 const paddle::Tensor& cur_len,
-                                 const paddle::Tensor& min_len,
-                                 const paddle::Tensor& eos_token_id,
-                                 const paddle::Tensor& seq_lens_this_time,
-                                 const paddle::Tensor& output_padding_offset,
-                                 const paddle::Tensor& output_cum_offsets,
-                                 const int max_seq_len);
-
 void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
                                const paddle::Tensor& accept_num,
                                const paddle::Tensor& pre_ids,
@@ -794,7 +789,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
                      const paddle::Tensor& max_dec_len,
                      const paddle::Tensor& end_tokens,
                      const paddle::Tensor& is_block_step,
-                     const paddle::Tensor& output_cum_offsets,
+                     const paddle::Tensor& cu_seqlens_q_output,
                      const paddle::Tensor& actual_candidate_len,
                      const paddle::Tensor& actual_draft_token_nums,
                      const paddle::Tensor& topp,
@@ -922,7 +917,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
                       const paddle::Tensor& seq_lens_encoder,
                       const paddle::Tensor& seq_lens_decoder,
                       const paddle::Tensor& step_idx,
-                      const paddle::Tensor& output_cum_offsets,
+                      const paddle::Tensor& cu_seqlens_q_output,
                       const paddle::Tensor& stop_flags,
                       const paddle::Tensor& not_need_stop,
                       const paddle::Tensor& max_dec_len,
@@ -1102,19 +1097,20 @@ std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
 
 std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
 
-void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
-                                   const paddle::Tensor& pre_ids,
-                                   const paddle::Tensor& stop_flags,
-                                   const paddle::Tensor& seq_lens_this_time,
-                                   const paddle::Tensor& seq_lens_encoder,
-                                   const paddle::Tensor& step_idx,
-                                   const paddle::Tensor& allowed_tokens,
-                                   const paddle::Tensor& reasoning_status,
-                                   const paddle::Tensor& output_padding_offset,
-                                   const paddle::Tensor& output_cum_offsets,
-                                   const paddle::Tensor& enable_thinking,
-                                   int64_t think_end_id,
-                                   int64_t line_break_id);
+void ReasoningPhaseTokenConstraint(
+    const paddle::Tensor& logits,
+    const paddle::Tensor& pre_ids,
+    const paddle::Tensor& stop_flags,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& step_idx,
+    const paddle::Tensor& allowed_tokens,
+    const paddle::Tensor& reasoning_status,
+    const paddle::Tensor& batch_id_per_token_output,
+    const paddle::Tensor& cu_seqlens_q_output,
+    const paddle::Tensor& enable_thinking,
+    int64_t think_end_id,
+    int64_t line_break_id);
 
 std::vector<paddle::Tensor> get_attn_mask_q(
     const paddle::Tensor& cu_seqlens_q,
@@ -1612,10 +1608,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         &SpeculateGetSeqLensOutput,
         "speculate_get_seq_lens_output function");
 
-  m.def("speculate_get_output_padding_offset",
-        &SpeculateGetOutputPaddingOffset,
-        "speculate_get_output_padding_offset function");
-
   m.def("speculate_get_token_penalty_multi_scores",
        &SpecTokenPenaltyMultiScores,
        "speculate_get_token_penalty_multi_scores function");
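The declarations above swap one packed-token index representation for another. For reference, a minimal NumPy sketch of how the retired pair (`output_padding_offset`, `output_cum_offsets`) relates to the new pair (`batch_id_per_token_output`, `cu_seqlens_q_output`) — the `seq_lens_output` and `max_seq_len` values are toy inputs, not code taken from the repo:

```python
# Sketch of the two representations; toy values only.
import numpy as np

seq_lens_output = np.array([3, 1, 2], dtype=np.int32)  # output tokens per batch
max_seq_len = 8

# Legacy pair produced by the removed speculate_get_output_padding_offset op:
# output_cum_offsets[bi] = padding slots skipped before batch bi's first token,
# output_padding_offset[token] = the same value, broadcast per output token.
pad = max_seq_len - seq_lens_output
output_cum_offsets = np.concatenate(([0], np.cumsum(pad)[:-1]))
output_padding_offset = np.repeat(output_cum_offsets, seq_lens_output)

# New pair: a plain batch id per packed output token, plus the prefix sum of
# per-batch output lengths.
batch_id_per_token_output = np.repeat(
    np.arange(len(seq_lens_output), dtype=np.int32), seq_lens_output
)
cu_seqlens_q_output = np.concatenate(([0], np.cumsum(seq_lens_output)))

# The index math the kernels switch between:
for token_idx in range(int(seq_lens_output.sum())):
    bi_old = (token_idx + output_padding_offset[token_idx]) // max_seq_len
    start_old = bi_old * max_seq_len - output_cum_offsets[bi_old]
    bi_new = batch_id_per_token_output[token_idx]
    start_new = cu_seqlens_q_output[bi_new]
    assert bi_old == bi_new and start_old == start_new
```

The new form trades a division against `max_seq_len` for a direct table lookup, which is why the `max_input_length`/`max_seq_len` attribute can be dropped from several ops below.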
diff --git a/custom_ops/gpu_ops/reasoning_phase_token_constraint.cu b/custom_ops/gpu_ops/reasoning_phase_token_constraint.cu
index cb09dd24e8..ddfd7113d4 100644
--- a/custom_ops/gpu_ops/reasoning_phase_token_constraint.cu
+++ b/custom_ops/gpu_ops/reasoning_phase_token_constraint.cu
@@ -125,8 +125,8 @@ __global__ void apply_token_enforce_generation_scores_kernel(
     T* __restrict__ logits_dst,                  // logits (output)
     const int64_t* __restrict__ allowed_tokens,  // [allowed_len]
     const int32_t* __restrict__ reasoning_status,
-    const int* output_padding_offset,
-    const int* output_cum_offsets,
+    const int* batch_id_per_token_output,
+    const int* cu_seqlens_q_output,
     const int max_bsz,
     const int max_seq_len,
     const int vocab_size,
@@ -134,10 +134,8 @@ __global__ void apply_token_enforce_generation_scores_kernel(
   int token_idx = blockIdx.x;
   int tid = threadIdx.x;
 
-  const int bs_idx =
-      (token_idx + output_padding_offset[token_idx]) / max_seq_len;
-  const int query_start_token_idx =
-      bs_idx * max_seq_len - output_cum_offsets[bs_idx];
+  const int bs_idx = batch_id_per_token_output[token_idx];
+  const int query_start_token_idx = cu_seqlens_q_output[bs_idx];
 
   bool is_batch_first_token = (token_idx == query_start_token_idx);
   if (allowed_tokens_len == 0 || !is_batch_first_token) {
@@ -177,8 +175,8 @@ void reasoning_phase_token_constraint(
     const paddle::Tensor& step_idx,
     const paddle::Tensor& allowed_tokens,
     const paddle::Tensor& reasoning_status,
-    const paddle::Tensor& output_padding_offset,
-    const paddle::Tensor& output_cum_offsets,
+    const paddle::Tensor& batch_id_per_token_output,
+    const paddle::Tensor& cu_seqlens_q_output,
     const paddle::Tensor& enable_thinking,
     int64_t think_end_id,
     int64_t line_break_id) {
@@ -233,27 +231,28 @@ void reasoning_phase_token_constraint(
       reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
       allowed_tokens.data<int64_t>(),
       reasoning_status.data<int32_t>(),
-      output_padding_offset.data<int>(),
-      output_cum_offsets.data<int>(),
+      batch_id_per_token_output.data<int>(),
+      cu_seqlens_q_output.data<int>(),
       bs,
       max_seq_len,
       vocab_size,
       allowed_tokens_len);
 }
 
-void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
-                                   const paddle::Tensor& pre_ids,
-                                   const paddle::Tensor& stop_flags,
-                                   const paddle::Tensor& seq_lens_this_time,
-                                   const paddle::Tensor& seq_lens_encoder,
-                                   const paddle::Tensor& step_idx,
-                                   const paddle::Tensor& allowed_tokens,
-                                   const paddle::Tensor& reasoning_status,
-                                   const paddle::Tensor& output_padding_offset,
-                                   const paddle::Tensor& output_cum_offsets,
-                                   const paddle::Tensor& enable_thinking,
-                                   int64_t think_end_id,
-                                   int64_t line_break_id) {
+void ReasoningPhaseTokenConstraint(
+    const paddle::Tensor& logits,
+    const paddle::Tensor& pre_ids,
+    const paddle::Tensor& stop_flags,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& step_idx,
+    const paddle::Tensor& allowed_tokens,
+    const paddle::Tensor& reasoning_status,
+    const paddle::Tensor& batch_id_per_token_output,
+    const paddle::Tensor& cu_seqlens_q_output,
+    const paddle::Tensor& enable_thinking,
+    int64_t think_end_id,
+    int64_t line_break_id) {
   switch (logits.type()) {
     case paddle::DataType::FLOAT16:
       return reasoning_phase_token_constraint<paddle::DataType::FLOAT16>(
@@ -265,8 +264,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
           step_idx,
           allowed_tokens,
           reasoning_status,
-          output_padding_offset,
-          output_cum_offsets,
+          batch_id_per_token_output,
+          cu_seqlens_q_output,
          enable_thinking,
          think_end_id,
          line_break_id);
@@ -280,8 +279,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
          step_idx,
          allowed_tokens,
          reasoning_status,
-          output_padding_offset,
-          output_cum_offsets,
+          batch_id_per_token_output,
+          cu_seqlens_q_output,
          enable_thinking,
          think_end_id,
          line_break_id);
@@ -295,8 +294,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
          step_idx,
          allowed_tokens,
          reasoning_status,
-          output_padding_offset,
-          output_cum_offsets,
+          batch_id_per_token_output,
+          cu_seqlens_q_output,
          enable_thinking,
          think_end_id,
          line_break_id);
@@ -317,8 +316,8 @@ PD_BUILD_STATIC_OP(reasoning_phase_token_constraint)
             "step_idx",
             "allowed_tokens",
             "reasoning_status",
-             "output_padding_offset",
-             "output_cum_offsets",
+             "batch_id_per_token_output",
+             "cu_seqlens_q_output",
             "enable_thinking"})
    .Outputs({"logits_out", "reasoning_status_out"})
    .Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu
index 31c8b004a0..b06b410bcb 100644
--- a/custom_ops/gpu_ops/rebuild_padding.cu
+++ b/custom_ops/gpu_ops/rebuild_padding.cu
@@ -21,7 +21,6 @@ __global__ void RebuildPaddingKernel(T *output_data,
                                      const int *seq_len_this_time,
                                      const int *seq_len_decoder,
                                      const int *seq_len_encoder,
-                                     const int max_input_length,
                                      const int dim_embed,
                                      const int elem_nums) {
   using LoadT = AlignedVector<T, VecSize>;
@@ -51,8 +50,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
                                            const int *seq_len_this_time,
                                            const int *seq_len_decoder,
                                            const int *seq_len_encoder,
-                                           const int *output_padding_offset,
-                                           const int max_input_length,
+                                           const int *batch_id_per_token_output,
+                                           const int *cu_seqlens_q_output,
                                            const int dim_embed,
                                            const int64_t output_elem_nums,
                                            const int bsz,
@@ -62,17 +61,18 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
   for (int64_t i = global_idx * VecSize; i < output_elem_nums;
        i += gridDim.x * blockDim.x * VecSize) {
     const int out_token_id = i / dim_embed;
-    const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
-
-    const int bi = ori_token_id / max_input_length;
-
-    int seq_id = 0;
+    const int bi = batch_id_per_token_output[out_token_id];
 
     if (seq_len_this_time[bi] == 0) continue;
     if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
-    if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
-    const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
-    const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
+    int seq_id = 0;
+    if (seq_len_encoder[bi] > 0) {
+      seq_id = seq_len_encoder[bi] - 1;
+    } else {
+      seq_id = out_token_id - cu_seqlens_q_output[bi];
+    }
+
+    const int input_token_id = cu_seqlens_q[bi] + seq_id;
     const int bias_idx = i % dim_embed;
 
     Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
@@ -95,9 +95,10 @@ std::vector<paddle::Tensor> rebuild_padding(
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
-    const paddle::optional<paddle::Tensor> &output_padding_offset,
+    const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
+    const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
+
     const paddle::optional<paddle::Tensor> &first_token_out,
-    int max_input_length,
     bool enable_logprob) {
   typedef PDTraits<D> traits_;
   typedef typename traits_::DataType DataType_;
@@ -117,8 +118,8 @@ std::vector<paddle::Tensor> rebuild_padding(
 
   paddle::Tensor out;
   int output_token_num;
-  if (output_padding_offset) {
-    output_token_num = output_padding_offset.get().shape()[0];
+  if (batch_id_per_token_output) {
+    output_token_num = batch_id_per_token_output.get().shape()[0];
   } else {
     output_token_num = bsz;
   }
@@ -131,7 +132,7 @@ std::vector<paddle::Tensor> rebuild_padding(
   int pack_num = elem_nums / PackSize;
   const int blocksize = 128;
   const int grid_size = (pack_num + blocksize - 1) / blocksize;
-  if (output_padding_offset) {
+  if (batch_id_per_token_output) {
     RebuildAppendPaddingKernel<DataType_, PackSize>
        <<<grid_size, blocksize, 0, cu_stream>>>(
            reinterpret_cast<DataType_ *>(out.data<data_t>()),
@@ -144,8 +145,8 @@ std::vector<paddle::Tensor> rebuild_padding(
            seq_len_this_time.data<int>(),
            seq_lens_decoder.data<int>(),
            seq_lens_encoder.data<int>(),
-            output_padding_offset.get_ptr()->data<int>(),
-            max_input_length,
+            batch_id_per_token_output.get_ptr()->data<int>(),
+            cu_seqlens_q_output.get_ptr()->data<int>(),
            dim_embed,
            elem_nums,
            bsz,
@@ -160,7 +161,6 @@ std::vector<paddle::Tensor> rebuild_padding(
            seq_len_this_time.data<int>(),
            seq_lens_decoder.data<int>(),
            seq_lens_encoder.data<int>(),
-            max_input_length,
            dim_embed,
            elem_nums);
  }
@@ -173,43 +173,46 @@ paddle::Tensor RebuildPaddingFunc(
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
-    const paddle::optional<paddle::Tensor> &output_padding_offset,
+    const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
+    const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
     const paddle::optional<paddle::Tensor> &first_token_out,
-    int max_input_length,
     bool enable_logprob) {
   switch (tmp_out.type()) {
     case paddle::DataType::BFLOAT16: {
-      return rebuild_padding<paddle::DataType::BFLOAT16>(tmp_out,
-                                                         cu_seqlens_q,
-                                                         seq_len_this_time,
-                                                         seq_lens_decoder,
-                                                         seq_lens_encoder,
-                                                         output_padding_offset,
-                                                         first_token_out,
-                                                         max_input_length,
-                                                         enable_logprob)[0];
+      return rebuild_padding<paddle::DataType::BFLOAT16>(
+          tmp_out,
+          cu_seqlens_q,
+          seq_len_this_time,
+          seq_lens_decoder,
+          seq_lens_encoder,
+          batch_id_per_token_output,
+          cu_seqlens_q_output,
+          first_token_out,
+          enable_logprob)[0];
     }
     case paddle::DataType::FLOAT16: {
-      return rebuild_padding<paddle::DataType::FLOAT16>(tmp_out,
-                                                        cu_seqlens_q,
-                                                        seq_len_this_time,
-                                                        seq_lens_decoder,
-                                                        seq_lens_encoder,
-                                                        output_padding_offset,
-                                                        first_token_out,
-                                                        max_input_length,
-                                                        enable_logprob)[0];
+      return rebuild_padding<paddle::DataType::FLOAT16>(
+          tmp_out,
+          cu_seqlens_q,
+          seq_len_this_time,
+          seq_lens_decoder,
+          seq_lens_encoder,
+          batch_id_per_token_output,
+          cu_seqlens_q_output,
+          first_token_out,
+          enable_logprob)[0];
    }
    case paddle::DataType::FLOAT32: {
-      return rebuild_padding<paddle::DataType::FLOAT32>(tmp_out,
-                                                        cu_seqlens_q,
-                                                        seq_len_this_time,
-                                                        seq_lens_decoder,
-                                                        seq_lens_encoder,
-                                                        output_padding_offset,
-                                                        first_token_out,
-                                                        max_input_length,
-                                                        enable_logprob)[0];
+      return rebuild_padding<paddle::DataType::FLOAT32>(
+          tmp_out,
+          cu_seqlens_q,
+          seq_len_this_time,
+          seq_lens_decoder,
+          seq_lens_encoder,
+          batch_id_per_token_output,
+          cu_seqlens_q_output,
+          first_token_out,
+          enable_logprob)[0];
    }
    default: {
      PD_THROW(
@@ -226,18 +229,18 @@ std::vector<paddle::Tensor> RebuildPadding(
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
-    const paddle::optional<paddle::Tensor> &output_padding_offset,
+    const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
+    const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
     const paddle::optional<paddle::Tensor> &first_token_out,
-    int max_input_length,
     bool enable_logprob) {
   return {RebuildPaddingFunc(tmp_out,
                              cu_seqlens_q,
                              seq_len_this_time,
                              seq_lens_decoder,
                              seq_lens_encoder,
-                             output_padding_offset,
+                             batch_id_per_token_output,
+                             cu_seqlens_q_output,
                              first_token_out,
-                             max_input_length,
                              enable_logprob)};
 }
 
@@ -247,10 +250,12 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
     const std::vector<int64_t> &seq_len_this_time_shape,
     const std::vector<int64_t> &seq_lens_decoder_shape,
     const std::vector<int64_t> &seq_lens_encoder_shape,
-    const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
+    const paddle::optional<std::vector<int64_t>>
+        &batch_id_per_token_output_shape,
+    const paddle::optional<std::vector<int64_t>> &cu_seqlens_q_output_shape) {
   int64_t dim_embed = tmp_out_shape[1];
   // whether speculative decoding
-  if (output_padding_offset_shape) {
+  if (batch_id_per_token_output_shape) {
     return {{-1, dim_embed}};
   } else {
     int64_t bsz = cu_seqlens_q_shape[0] - 1;
@@ -264,7 +269,8 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
     const paddle::DataType &seq_len_this_time_dtype,
     const paddle::DataType &seq_lens_decoder_dtype,
     const paddle::DataType &seq_lens_encoder_dtype,
-    const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
+    const paddle::optional<paddle::DataType> &batch_id_per_token_output_dtype,
+    const paddle::optional<paddle::DataType> &cu_seqlens_q_output_dtype) {
   return {tmp_out_dtype};
 }
 
@@ -274,10 +280,11 @@ PD_BUILD_STATIC_OP(rebuild_padding)
             "seq_len_this_time",
             "seq_lens_decoder",
             "seq_lens_encoder",
-             paddle::Optional("output_padding_offset"),
+             paddle::Optional("batch_id_per_token_output"),
+             paddle::Optional("cu_seqlens_q_output"),
             paddle::Optional("first_token_out")})
    .Outputs({"out"})
-    .Attrs({"max_input_length: int", "enable_logprob: bool"})
+    .Attrs({"enable_logprob: bool"})
    .SetKernelFn(PD_KERNEL(RebuildPadding))
    .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));
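The speculative branch of `RebuildAppendPaddingKernel` now gathers each output token directly from the packed input. A plain-Python reference of the new indexing — a sketch under toy-array assumptions, in the same spirit as the unit test at the end of this patch, not the shipped implementation:

```python
# Reference of RebuildAppendPaddingKernel's new indexing (sketch).
# cu_seqlens_q indexes packed *input* tokens; cu_seqlens_q_output indexes
# packed *output* tokens. All arguments are toy NumPy arrays.
import numpy as np

def rebuild_append_padding_ref(tmp_out, cu_seqlens_q, seq_len_this_time,
                               seq_lens_decoder, seq_lens_encoder,
                               batch_id_per_token_output, cu_seqlens_q_output):
    out = np.zeros((len(batch_id_per_token_output), tmp_out.shape[1]), tmp_out.dtype)
    for out_token_id, bi in enumerate(batch_id_per_token_output):
        if seq_len_this_time[bi] == 0:
            continue
        if seq_lens_decoder[bi] == 0 and seq_lens_encoder[bi] == 0:
            continue
        if seq_lens_encoder[bi] > 0:
            # prefill: only the prompt's last hidden state is kept
            seq_id = seq_lens_encoder[bi] - 1
        else:
            # decode: this token's position inside its batch's output span
            seq_id = out_token_id - cu_seqlens_q_output[bi]
        out[out_token_id] = tmp_out[cu_seqlens_q[bi] + seq_id]
    return out
```

With `batch_id_per_token_output` available, the batch lookup is O(1) per token, which is why the kernel and op no longer need `max_input_length` at all.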
diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu
index 390ee831af..f217879362 100644
--- a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu
@@ -23,7 +23,7 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
                                           int* seq_lens_encoder,
                                           int* seq_lens_decoder,
                                           int64_t* step_idx,
-                                          const int* output_cum_offsets,
+                                          const int* cu_seqlens_q_output,
                                           bool* stop_flags,
                                           bool* not_need_stop,
                                           const int64_t* max_dec_len,
@@ -47,8 +47,7 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
     auto* pre_ids_now = pre_ids + tid * pre_id_length;
     auto* base_model_draft_tokens_now =
         base_model_draft_tokens + tid * max_base_model_draft_token;
-    const int next_tokens_start_id =
-        tid * max_seq_len - output_cum_offsets[tid];
+    const int next_tokens_start_id = cu_seqlens_q_output[tid];
     auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
     auto seq_len_this_time = seq_lens_this_time[tid];
     auto seq_len_encoder = seq_lens_encoder[tid];
@@ -66,7 +65,6 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
 
       step_idx[tid] += seq_len_this_time;
       pre_ids_now[step_idx[tid]] = token_this_time;
-
     } else {
       token_this_time = next_tokens_start[0];
 
@@ -80,7 +78,8 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
     }
 
     // multi_end
-    if (is_in_end(token_this_time, end_ids, end_ids_len) || prefill_one_step_stop) {
+    if (is_in_end(token_this_time, end_ids, end_ids_len) ||
+        prefill_one_step_stop) {
       stop_flags[tid] = true;
       stop_flag_now_int = 1;
       // max_dec_len
@@ -112,7 +111,6 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
   }
 }
 
-
 void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
                       const paddle::Tensor& draft_tokens,
                       const paddle::Tensor& pre_ids,
@@ -120,7 +118,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
                       const paddle::Tensor& seq_lens_encoder,
                       const paddle::Tensor& seq_lens_decoder,
                       const paddle::Tensor& step_idx,
-                      const paddle::Tensor& output_cum_offsets,
+                      const paddle::Tensor& cu_seqlens_q_output,
                       const paddle::Tensor& stop_flags,
                       const paddle::Tensor& not_need_stop,
                       const paddle::Tensor& max_dec_len,
@@ -140,11 +138,11 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
   constexpr int BlockSize = 512;
 
   bool prefill_one_step_stop = false;
-  if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
-    // std::cout << "Your PATH is: " << env_p << '\n';
-    if (env_p[0] == '1') {
-      prefill_one_step_stop = true;
-    }
+  if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
+    // std::cout << "Your PATH is: " << env_p << '\n';
+    if (env_p[0] == '1') {
+      prefill_one_step_stop = true;
+    }
   }
 
   draft_model_update_kernel<<<1, BlockSize, 0, cu_stream>>>(
@@ -155,7 +153,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
       const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
-      output_cum_offsets.data<int>(),
+      cu_seqlens_q_output.data<int>(),
      const_cast<bool*>(stop_flags.data<bool>()),
      not_need_stop_gpu.data<bool>(),
      max_dec_len.data<int64_t>(),
@@ -170,14 +168,12 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
       substep,
       prefill_one_step_stop);
 
-
   auto not_need_stop_cpu =
       not_need_stop_gpu.copy_to(not_need_stop.place(), false);
   bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
   not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
 }
 
-
 PD_BUILD_STATIC_OP(draft_model_update)
     .Inputs({"inter_next_tokens",
              "draft_tokens",
@@ -186,7 +182,7 @@ PD_BUILD_STATIC_OP(draft_model_update)
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
-             "output_cum_offsets",
+             "cu_seqlens_q_output",
             "stop_flags",
             "not_need_stop",
             "max_dec_len",
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu
deleted file mode 100644
index aea33fff0f..0000000000
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu
+++ /dev/null
@@ -1,88 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/extension.h"
-
-#ifndef PD_BUILD_STATIC_OP
-#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
-#endif
-
-__global__ void SpeculateGetOutputPaddingOffsetKernel(
-    int* output_padding_offset,
-    int* output_cum_offsets,
-    const int* output_cum_offsets_tmp,
-    const int* seq_lens_output,
-    const int max_seq_len) {
-  // get padding offset of each batch
-  const int bi = blockIdx.x;
-  const int ti = threadIdx.x;
-  int cum_offset = bi == 0 ? 0 : output_cum_offsets_tmp[bi - 1];
-  for (int i = ti; i < seq_lens_output[bi]; i += blockDim.x) {
-    output_padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
-  }
-  if (ti == 0) {
-    output_cum_offsets[bi] = cum_offset;
-  }
-}
-
-std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
-    const paddle::Tensor& output_cum_offsets_tmp,
-    const paddle::Tensor& out_token_num,
-    const paddle::Tensor& seq_lens_output,
-    const int max_seq_len) {
-  auto cu_stream = output_cum_offsets_tmp.stream();
-  std::vector<int64_t> output_cum_offsets_tmp_shape =
-      output_cum_offsets_tmp.shape();
-  const int bsz = output_cum_offsets_tmp_shape[0];
-  auto cpu_out_token_num = out_token_num.copy_to(paddle::CPUPlace(), false);
-
-  auto output_padding_offset = paddle::full({cpu_out_token_num},
-                                            0,
-                                            paddle::DataType::INT32,
-                                            output_cum_offsets_tmp.place());
-  auto output_cum_offsets =
-      output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false);
-
-  SpeculateGetOutputPaddingOffsetKernel<<>>(
-      output_padding_offset.data<int>(),
-      output_cum_offsets.data<int>(),
-      output_cum_offsets_tmp.data<int>(),
-      seq_lens_output.data<int>(),
-      max_seq_len);
-
-  return {output_padding_offset, output_cum_offsets};
-}
-
-std::vector<std::vector<int64_t>> SpeculateGetOutputPaddingOffsetInferShape(
-    const std::vector<int64_t>& output_cum_offsets_tmp_shape,
-    const std::vector<int64_t>& out_token_num_shape,
-    const std::vector<int64_t>& seq_lens_output_shape) {
-  int64_t bsz = output_cum_offsets_tmp_shape[0];
-  return {{-1}, {bsz}};
-}
-
-std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
-    const paddle::DataType& output_cum_offsets_tmp_dtype,
-    const paddle::DataType& out_token_num_dtype,
-    const paddle::DataType& seq_lens_output_dtype) {
-  return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype};
-}
-
-PD_BUILD_STATIC_OP(speculate_get_output_padding_offset)
-    .Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
-    .Outputs({"output_padding_offset", "output_cum_offsets"})
-    .Attrs({"max_seq_len: int"})
-    .SetKernelFn(PD_KERNEL(SpeculateGetOutputPaddingOffset))
-    .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetOutputPaddingOffsetInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetOutputPaddingOffsetInferDtype));
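With this op deleted, callers derive the equivalent arrays from `get_padding_offset` run over `seq_lens_output` (see `pre_and_post_process.py` below). For backends that still carry the legacy names (see the XPU note in `mtp.py` further down), the old arrays remain recoverable from the new pair — a sketch of that identity, assuming the definitions used by the kernel removed above:

```python
# Recovering the legacy arrays from the new pair (sketch; assumes the
# definitions of the deleted SpeculateGetOutputPaddingOffsetKernel).
import numpy as np

def legacy_from_new(batch_id_per_token_output, cu_seqlens_q_output, max_seq_len):
    bs = len(cu_seqlens_q_output) - 1
    # padding slots skipped before each batch's first output token:
    # bi * max_seq_len - cu_seqlens_q_output[bi]
    output_cum_offsets = (
        np.arange(bs, dtype=np.int32) * max_seq_len - cu_seqlens_q_output[:bs]
    )
    # per-token copy of the owning batch's cumulative offset
    output_padding_offset = output_cum_offsets[batch_id_per_token_output]
    return output_padding_offset, output_cum_offsets
```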
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu
index 8e5f63a5d2..615462ab93 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu
@@ -20,8 +20,8 @@ __global__ inline void min_length_logits_process(
     const int64_t *cur_len,
     const int64_t *min_len,
     const int64_t *eos_token_id,
-    const int *output_padding_offset,
-    const int *output_cum_offsets,
+    const int *batch_id_per_token_output,
+    const int *cu_seqlens_q_output,
     const int64_t token_num,
     const int64_t bs,
     const int64_t length,
@@ -29,9 +29,9 @@
     const int max_seq_len) {
   const int token_idx = threadIdx.x;
   if (token_idx >= token_num) return;
-  const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
+  const int bi = batch_id_per_token_output[token_idx];
   if (bi >= bs) return;
-  const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
+  const int query_start_token_idx = cu_seqlens_q_output[bi];
 
   if (cur_len[bi] < 0) {
     return;
@@ -49,8 +49,8 @@ __global__ inline void min_length_logits_process(
     const int64_t *cur_len,
     const int64_t *min_len,
     const int64_t *eos_token_id,
-    const int *output_padding_offset,
-    const int *output_cum_offsets,
+    const int *batch_id_per_token_output,
+    const int *cu_seqlens_q_output,
     const int64_t token_num,
     const int64_t bs,
     const int64_t length,
@@ -58,9 +58,9 @@
     const int max_seq_len) {
   const int token_idx = threadIdx.x;
   if (token_idx >= token_num) return;
-  const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
+  const int bi = batch_id_per_token_output[token_idx];
   if (bi >= bs) return;
-  const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
+  const int query_start_token_idx = cu_seqlens_q_output[bi];
 
   if (cur_len[bi] < 0) {
     return;
@@ -75,7 +75,7 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
                                     const int64_t *cur_len,
                                     int *repeat_times,
-                                    const int *output_padding_offset,
+                                    const int *batch_id_per_token_output,
                                     const int64_t token_num,
                                     const int64_t bs,
                                     const int64_t length,
@@ -83,7 +83,7 @@
                                     const int max_seq_len) {
   const int token_idx = blockIdx.x;
   if (token_idx >= token_num) return;
-  const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
+  const int bi = batch_id_per_token_output[token_idx];
   if (bi >= bs) return;
   if (cur_len[bi] < 0) {
     return;
@@ -99,20 +99,21 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
   }
 }
 
 template <typename T>
-__global__ void update_value_by_repeat_times(const int *repeat_times,
-                                             const T *penalty_scores,
-                                             const T *frequency_score,
-                                             const T *presence_score,
-                                             const float *temperatures,
-                                             T *logits,
-                                             const int *output_padding_offset,
-                                             const int64_t token_num,
-                                             const int64_t bs,
-                                             const int64_t length,
-                                             const int max_seq_len) {
+__global__ void update_value_by_repeat_times(
+    const int *repeat_times,
+    const T *penalty_scores,
+    const T *frequency_score,
+    const T *presence_score,
+    const float *temperatures,
+    T *logits,
+    const int *batch_id_per_token_output,
+    const int64_t token_num,
+    const int64_t bs,
+    const int64_t length,
+    const int max_seq_len) {
   const int token_idx = blockIdx.x;
   if (token_idx >= token_num) return;
-  const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
+  const int bi = batch_id_per_token_output[token_idx];
   if (bi >= bs) return;
   int tid = threadIdx.x;
   T *logits_now = logits + token_idx * length;
@@ -135,7 +136,7 @@ template <typename T>
 __global__ void ban_bad_words(T *logits,
                               const int64_t *bad_tokens,
                               const int64_t *bad_tokens_len,
-                              const int *output_padding_offset,
+                              const int *batch_id_per_token_output,
                               const int64_t token_num,
                               const int64_t bs,
                               const int64_t length,
@@ -143,7 +144,7 @@ __global__ void ban_bad_words(T *logits,
                               const int max_seq_len) {
   const int token_idx = blockIdx.x;
   if (token_idx >= token_num) return;
-  const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
+  const int bi = batch_id_per_token_output[token_idx];
   if (bi >= bs) return;
 
   int tid = threadIdx.x;
@@ -172,8 +173,8 @@ void token_penalty_multi_scores_kernel(
     const paddle::Tensor &min_len,
     const paddle::Tensor &eos_token_id,
     const paddle::Tensor &seq_lens_this_time,
-    const paddle::Tensor &output_padding_offset,
-    const paddle::Tensor &output_cum_offsets,
+    const paddle::Tensor &batch_id_per_token_output,
+    const paddle::Tensor &cu_seqlens_q_output,
     const int max_seq_len) {
   typedef PDTraits<D> traits_;
   typedef typename traits_::DataType DataType_;
@@ -196,8 +197,8 @@ void token_penalty_multi_scores_kernel(
       cur_len.data<int64_t>(),
       min_len.data<int64_t>(),
       eos_token_id.data<int64_t>(),
-      output_padding_offset.data<int>(),
-      output_cum_offsets.data<int>(),
+      batch_id_per_token_output.data<int>(),
+      cu_seqlens_q_output.data<int>(),
      token_num,
      bs,
      length,
@@ -210,7 +211,7 @@ void token_penalty_multi_scores_kernel(
       pre_ids.data<int64_t>(),
      cur_len.data<int64_t>(),
      repeat_times.data<int>(),
-      output_padding_offset.data<int>(),
+      batch_id_per_token_output.data<int>(),
      token_num,
      bs,
      length,
@@ -231,7 +232,7 @@ void token_penalty_multi_scores_kernel(
       temperatures.data<float>(),
      reinterpret_cast<DataType_ *>(
          const_cast<data_t *>(logits.data<data_t>())),
-      output_padding_offset.data<int>(),
+      batch_id_per_token_output.data<int>(),
      token_num,
      bs,
      length,
@@ -243,7 +244,7 @@ void token_penalty_multi_scores_kernel(
           const_cast<data_t *>(logits.data<data_t>())),
      bad_tokens.data<int64_t>(),
      bad_tokens_len.data<int64_t>(),
-      output_padding_offset.data<int>(),
+      batch_id_per_token_output.data<int>(),
      token_num,
      bs,
      length,
@@ -251,21 +252,22 @@ void token_penalty_multi_scores_kernel(
       max_seq_len);
 }
 
-void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
-                                 const paddle::Tensor &logits,
-                                 const paddle::Tensor &penalty_scores,
-                                 const paddle::Tensor &frequency_scores,
-                                 const paddle::Tensor &presence_scores,
-                                 const paddle::Tensor &temperatures,
-                                 const paddle::Tensor &bad_tokens,
-                                 const paddle::Tensor &bad_tokens_len,
-                                 const paddle::Tensor &cur_len,
-                                 const paddle::Tensor &min_len,
-                                 const paddle::Tensor &eos_token_id,
-                                 const paddle::Tensor &seq_lens_this_time,
-                                 const paddle::Tensor &output_padding_offset,
-                                 const paddle::Tensor &output_cum_offsets,
-                                 const int max_seq_len) {
+void SpecTokenPenaltyMultiScores(
+    const paddle::Tensor &pre_ids,
+    const paddle::Tensor &logits,
+    const paddle::Tensor &penalty_scores,
+    const paddle::Tensor &frequency_scores,
+    const paddle::Tensor &presence_scores,
+    const paddle::Tensor &temperatures,
+    const paddle::Tensor &bad_tokens,
+    const paddle::Tensor &bad_tokens_len,
+    const paddle::Tensor &cur_len,
+    const paddle::Tensor &min_len,
+    const paddle::Tensor &eos_token_id,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &batch_id_per_token_output,
+    const paddle::Tensor &cu_seqlens_q_output,
+    const int max_seq_len) {
   switch (logits.type()) {
     case paddle::DataType::BFLOAT16: {
       return token_penalty_multi_scores_kernel<paddle::DataType::BFLOAT16>(
@@ -281,8 +283,8 @@ void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
           min_len,
           eos_token_id,
           seq_lens_this_time,
-          output_padding_offset,
-          output_cum_offsets,
+          batch_id_per_token_output,
+          cu_seqlens_q_output,
          max_seq_len);
    }
    case paddle::DataType::FLOAT16: {
@@ -299,8 +301,8 @@ void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
          min_len,
          eos_token_id,
          seq_lens_this_time,
-          output_padding_offset,
-          output_cum_offsets,
+          batch_id_per_token_output,
+          cu_seqlens_q_output,
          max_seq_len);
    }
    case paddle::DataType::FLOAT32: {
@@ -317,8 +319,8 @@ void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
          min_len,
          eos_token_id,
          seq_lens_this_time,
-          output_padding_offset,
-          output_cum_offsets,
+          batch_id_per_token_output,
+          cu_seqlens_q_output,
          max_seq_len);
    }
    default: {
@@ -343,8 +345,8 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
             "min_len",
             "eos_token_id",
             "seq_lens_this_time",
-             "output_padding_offset",
-             "output_cum_offsets"})
+             "batch_id_per_token_output",
+             "cu_seqlens_q_output"})
    .Outputs({"logits_out"})
    .Attrs({"max_seq_len: int"})
    .SetInplaceMap({{"logits", "logits_out"}})
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu
index 36d2f4484b..9e1807c349 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu
@@ -82,7 +82,7 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids,
                                  const int64_t *max_dec_len,
                                  const int64_t *end_tokens,
                                  const bool *is_block_step,
-                                 const int *output_cum_offsets,
+                                 const int *cu_seqlens_q_output,
                                  const int *actual_candidate_len,
                                  const int *reasoning_status,
                                  const int real_bsz,
@@ -101,7 +101,7 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids,
   int stop_flag_now_int = 0;
 
   if (!(is_block_step[bid] || bid >= real_bsz)) {
-    const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
+    const int start_token_id = cu_seqlens_q_output[bid];
 
     if (stop_flags[bid]) {
       stop_flag_now_int = 1;
@@ -312,7 +312,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
                      const paddle::Tensor &max_dec_len,
                      const paddle::Tensor &end_tokens,
                      const paddle::Tensor &is_block_step,
-                     const paddle::Tensor &output_cum_offsets,
+                     const paddle::Tensor &cu_seqlens_q_output,
                      const paddle::Tensor &actual_candidate_len,
                      const paddle::Tensor &actual_draft_token_nums,
                      const paddle::Tensor &topp,
@@ -374,7 +374,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
           max_dec_len.data<int64_t>(),
           end_tokens.data<int64_t>(),
           is_block_step.data<bool>(),
-          output_cum_offsets.data<int>(),
+          cu_seqlens_q_output.data<int>(),
          actual_candidate_len.data<int>(),
          reasoning_status.data<int>(),
          real_bsz,
@@ -407,7 +407,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
           max_dec_len.data<int64_t>(),
          end_tokens.data<int64_t>(),
          is_block_step.data<bool>(),
-          output_cum_offsets.data<int>(),
+          cu_seqlens_q_output.data<int>(),
          actual_candidate_len.data<int>(),
          reasoning_status.data<int>(),
          real_bsz,
@@ -442,7 +442,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
           max_dec_len.data<int64_t>(),
          end_tokens.data<int64_t>(),
          is_block_step.data<bool>(),
-          output_cum_offsets.data<int>(),
+          cu_seqlens_q_output.data<int>(),
          actual_candidate_len.data<int>(),
          reasoning_status.data<int>(),
          real_bsz,
@@ -475,7 +475,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
           max_dec_len.data<int64_t>(),
          end_tokens.data<int64_t>(),
          is_block_step.data<bool>(),
-          output_cum_offsets.data<int>(),
+          cu_seqlens_q_output.data<int>(),
          actual_candidate_len.data<int>(),
          reasoning_status.data<int>(),
          real_bsz,
@@ -509,7 +509,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
             "max_dec_len",
             "end_tokens",
             "is_block_step",
-             "output_cum_offsets",
+             "cu_seqlens_q_output",
             "actual_candidate_len",
             "actual_draft_token_nums",
             "topp",
diff --git a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu
index e5ba39a1f3..1a3bd72e46 100644
--- a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu
@@ -407,7 +407,7 @@ template <typename T, int TopPBeamTopK>
 __global__ void KeMatrixTopPBeamTopKFt(
     const T* src,
     const T* top_ps,
-    const int* output_padding_offset,
+    const int* batch_id_per_token_output,
     int64_t* out_id,  // [max_cadidate_len, 1]
     T* out_val,       // [max_cadidate_len, 1]
     int* actual_candidates_lens,
@@ -418,8 +418,7 @@ __global__ void KeMatrixTopPBeamTopKFt(
   const int wid = tid / 32;
   const int lane = tid % 32;
   const int token_id = blockIdx.x;
-  const int ori_token_id = token_id + output_padding_offset[token_id];
-  const int bid = ori_token_id / max_seq_len;
+  const int bid = batch_id_per_token_output[token_id];
   int top_num = TopPBeamTopK;
   float top_p_value = static_cast<float>(top_ps[bid]);
@@ -479,7 +478,7 @@
 template <typename T>
 void DispatchTopK(const T* src,
                   const T* top_ps,
-                  const int* output_padding_offset,
+                  const int* batch_id_per_token_output,
                   int64_t* out_id,  // topk id
                   T* out_val,       // topk val
                   int* actual_candidates_lens_data,
@@ -495,7 +494,7 @@ void DispatchTopK(const T* src,
       KeMatrixTopPBeamTopKFt <<>>(src,
                                   top_ps,
-                                  output_padding_offset,
+                                  batch_id_per_token_output,
                                   out_id,
                                   out_val,
                                   actual_candidates_lens_data,
@@ -516,7 +515,7 @@ template <paddle::DataType D>
 std::vector<paddle::Tensor> LaunchTopPCandidates(
     const paddle::Tensor& probs,  // [token_num, vocab_size]
     const paddle::Tensor& top_p,  // [token_num]
-    const paddle::Tensor& output_padding_offset,
+    const paddle::Tensor& batch_id_per_token_output,
     const int candidates_len,
     const int max_seq_len) {
   typedef PDTraits<D> traits_;
@@ -540,7 +539,7 @@ std::vector<paddle::Tensor> LaunchTopPCandidates(
   DispatchTopK(
       reinterpret_cast<const DataType_*>(probs.data<data_t>()),
      reinterpret_cast<const DataType_*>(top_p.data<data_t>()),
-      output_padding_offset.data<int>(),
+      batch_id_per_token_output.data<int>(),
      verify_tokens.data<int64_t>(),
      reinterpret_cast<DataType_*>(verify_scores.data<data_t>()),
      actual_candidate_lens.data<int>(),
@@ -556,21 +555,21 @@
 std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
     const paddle::Tensor& probs,
     const paddle::Tensor& top_p,
-    const paddle::Tensor& output_padding_offset,
+    const paddle::Tensor& batch_id_per_token_output,
     int candidates_len,
     int max_seq_len) {
   switch (probs.type()) {
     case paddle::DataType::BFLOAT16:
       return LaunchTopPCandidates<paddle::DataType::BFLOAT16>(
-          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
+          probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
      break;
    case paddle::DataType::FLOAT16:
      return LaunchTopPCandidates<paddle::DataType::FLOAT16>(
-          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
+          probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
      break;
    case paddle::DataType::FLOAT32:
      return LaunchTopPCandidates<paddle::DataType::FLOAT32>(
-          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
+          probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
      break;
    default:
      PD_THROW(
@@ -583,17 +582,17 @@
 std::vector<paddle::Tensor> TopPCandidates(
     const paddle::Tensor& probs,
     const paddle::Tensor& top_p,
-    const paddle::Tensor& output_padding_offset,
+    const paddle::Tensor& batch_id_per_token_output,
     int candidates_len,
     int max_seq_len) {
   return DispatchTopPCandidatesWithDtype(
-      probs, top_p, output_padding_offset, candidates_len, max_seq_len);
+      probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
 }
 
 std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
     const std::vector<int64_t>& probs_shape,
     const std::vector<int64_t>& top_p_shape,
-    const std::vector<int64_t>& output_padding_offset_shape,
+    const std::vector<int64_t>& batch_id_per_token_output_shape,
     int max_candidates_len) {
   int token_num = probs_shape[0];
   return {{token_num, max_candidates_len},
@@ -604,12 +603,12 @@ std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
 std::vector<paddle::DataType> TopPCandidatesInferDtype(
     const paddle::DataType& probs_dtype,
     const paddle::DataType& top_p_dtype,
-    const paddle::DataType& output_padding_offset_dtype) {
+    const paddle::DataType& batch_id_per_token_output_dtype) {
   return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
 }
 
 PD_BUILD_STATIC_OP(top_p_candidates)
-    .Inputs({"probs", "top_p", "output_padding_offset"})
+    .Inputs({"probs", "top_p", "batch_id_per_token_output"})
     .Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
     .Attrs({"candidates_len: int", "max_seq_len: int"})
     .SetKernelFn(PD_KERNEL(TopPCandidates))
diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index 357a7a6a78..86e7a2a1ae 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -261,9 +261,6 @@ class FlashAttentionBackend(AttentionBackend):
         )
         # Note(ZKK): here must be consistent with append_attn_backend.py
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024))
-        self.zero_seq_enc_lens_for_decode = paddle.zeros(
-            shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
-        )
 
     def get_attention_meta(self):
         """get_attention_meta"""
diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py
index 4c35556ffe..6dbcf3eed6 100644
--- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py
@@ -123,9 +123,6 @@ class FlashMaskAttentionBackend(AttentionBackend):
         if fd_config.speculative_config.model_type != "main":
             self.rope_3d = False
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
-        self.zero_seq_enc_lens_for_decode = paddle.zeros(
-            shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
-        )
 
     def get_kv_cache_shape(
         self,
diff --git a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py
index bc0f0692ec..4f090297d3 100644
--- a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py
+++ b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py
@@ -175,8 +175,8 @@ def apply_speculative_penalty_multi_scores(
     min_dec_lens: paddle.Tensor,
     eos_token_ids: paddle.Tensor,
     seq_lens_this_time: paddle.Tensor,
-    output_padding_offset: paddle.Tensor,
-    output_cum_offsets: paddle.Tensor,
+    batch_id_per_token_output: paddle.Tensor,
+    cu_seqlens_q_output: paddle.Tensor,
     max_len: int,
 ):
     """
@@ -200,8 +200,8 @@ def apply_speculative_penalty_multi_scores(
             min_dec_lens,
             eos_token_ids,
             seq_lens_this_time,
-            output_padding_offset,
-            output_cum_offsets,
+            batch_id_per_token_output,
+            cu_seqlens_q_output,
             max_len,
         )
     elif current_platform.is_xpu():
@@ -221,8 +221,8 @@ def apply_speculative_penalty_multi_scores(
             min_dec_lens,
             eos_token_ids,
             seq_lens_this_time,
-            output_padding_offset,
-            output_cum_offsets,
+            batch_id_per_token_output,
+            cu_seqlens_q_output,
             max_len,
         )
@@ -242,8 +242,8 @@ def reasoning_phase_token_constraint(
     step_idx: paddle.Tensor,
     reasoning_allowed_tokens: paddle.Tensor,
     reasoning_status: paddle.Tensor,
-    output_padding_offset: paddle.Tensor,
-    output_cum_offsets: paddle.Tensor,
+    batch_id_per_token_output: paddle.Tensor,
+    cu_seqlens_q_output: paddle.Tensor,
     enable_thinking: paddle.Tensor,
     think_end_id: int,
     line_break_id: int,
@@ -263,8 +263,8 @@ def reasoning_phase_token_constraint(
         step_idx,
         reasoning_allowed_tokens,
         reasoning_status,
-        output_padding_offset,
-        output_cum_offsets,
+        batch_id_per_token_output,
+        cu_seqlens_q_output,
        enable_thinking,
        think_end_id,
        line_break_id,
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 19b1fcdfb1..4e85647807 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -758,8 +758,8 @@ class SpeculativeSampler(nn.Layer):
                 sampling_metadata.min_dec_lens,
                 sampling_metadata.eos_token_ids,
                 share_inputs["seq_lens_this_time"],
-                share_inputs["output_padding_offset"],
-                share_inputs["output_cum_offsets"],
+                share_inputs["batch_id_per_token_output"],
+                share_inputs["cu_seqlens_q_output"],
                 max_model_len,
             )
 
@@ -773,8 +773,8 @@ class SpeculativeSampler(nn.Layer):
                 share_inputs["step_idx"],
                 share_inputs["reasoning_allowed_tokens"],
                 share_inputs["reasoning_status"],
-                share_inputs["output_padding_offset"],
-                share_inputs["output_cum_offsets"],
+                share_inputs["batch_id_per_token_output"],
+                share_inputs["cu_seqlens_q_output"],
                share_inputs["enable_thinking"],
                self.think_end_id,
                self.line_break_id,
@@ -794,7 +794,7 @@ class SpeculativeSampler(nn.Layer):
         verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
             probs,
             sampling_metadata.top_p,
-            share_inputs["output_padding_offset"],
+            share_inputs["batch_id_per_token_output"],
             self.speculative_max_candidate_len,
             max_model_len,
         )
@@ -816,7 +816,7 @@ class SpeculativeSampler(nn.Layer):
             share_inputs["max_dec_len"],
             sampling_metadata.eos_token_ids,
             share_inputs["is_block_step"],
-            share_inputs["output_cum_offsets"],
+            share_inputs["cu_seqlens_q_output"],
             actual_candidate_len,
             share_inputs["actual_draft_token_num"],
             sampling_metadata.top_p,
@@ -1135,8 +1135,8 @@ class MTPSampler(nn.Layer):
             sampling_metadata.min_dec_lens,
             sampling_metadata.eos_token_ids,
             share_inputs["seq_lens_this_time"],
-            share_inputs["output_padding_offset"],
-            share_inputs["output_cum_offsets"],
+            share_inputs["batch_id_per_token_output"],
+            share_inputs["cu_seqlens_q_output"],
             max_model_len,
         )
         probs = F.softmax(logits)
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index af269987ce..382f55d30e 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -63,7 +63,6 @@ elif current_platform.is_maca():
         save_output,
         save_output_topk,
         set_stop_value_multi_ends,
-        speculate_get_output_padding_offset,
         speculate_get_seq_lens_output,
         speculate_limit_thinking_content_length_v1,
         speculate_limit_thinking_content_length_v2,
@@ -89,7 +88,6 @@
         save_output,
         save_output_topk,
         set_stop_value_multi_ends,
-        speculate_get_output_padding_offset,
         speculate_get_seq_lens_output,
         speculate_save_output,
         speculate_save_output_topk,
@@ -240,10 +238,6 @@ def pre_process(
         None,
     )
     # Remove padding
-    max_len = input_ids.shape[1]
-    cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
-    output_padding_offset = None
-    output_cum_offsets = None
     if speculative_decoding:
         (
             ids_remove_padding,
@@ -251,6 +245,8 @@ def pre_process(
             cu_seqlens_q,
             cu_seqlens_k,
         ) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)
+
+        # compute each batch's output token num
         seq_lens_output = speculate_get_seq_lens_output(
             seq_lens_this_time,
             seq_lens_encoder,
@@ -259,28 +255,23 @@ def pre_process(
         if isinstance(seq_lens_output, list):
             seq_lens_output = seq_lens_output[0]
         output_token_num = paddle.sum(seq_lens_output)
-        output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32")
-        output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
-            output_cum_offsets_tmp,
-            output_token_num,
+
+        useless_input_ids = input_ids
+        _, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
+            useless_input_ids,
             seq_lens_output,
-            max_len,
+            None,
+            None,
+            output_token_num.item(),
         )
-    else:
-        token_num = paddle.sum(seq_lens_this_time)
-        (
-            ids_remove_padding,
-            batch_id_per_token,
-            cu_seqlens_q,
-            cu_seqlens_k,
-        ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
+
     return (
         ids_remove_padding,
         batch_id_per_token,
         cu_seqlens_q,
         cu_seqlens_k,
-        output_cum_offsets,
-        output_padding_offset,
+        cu_seqlens_q_output,
+        batch_id_per_token_output,
     )
@@ -841,8 +832,8 @@ def rebuild_padding(
     seq_len_this_time: paddle.Tensor,
     seq_lens_decoder: paddle.Tensor,
     seq_lens_encoder: paddle.Tensor,
-    output_padding_offset: Optional[paddle.Tensor] = None,
-    max_input_length: Optional[int] = None,
+    batch_id_per_token_output: Optional[paddle.Tensor] = None,
+    cu_seqlens_q_output: Optional[paddle.Tensor] = None,
     first_token_out: Optional[paddle.Tensor] = None,
     enable_logprob: Optional[bool] = False,
 ):
@@ -859,9 +850,9 @@ def rebuild_padding(
             seq_len_this_time,
             seq_lens_decoder,
             seq_lens_encoder,
-            output_padding_offset,
+            batch_id_per_token_output,
+            cu_seqlens_q_output,
             first_token_out,
-            max_input_length,
             enable_logprob,
         )
     elif current_platform.is_dcu():
@@ -873,8 +864,7 @@ def rebuild_padding(
             seq_len_this_time,
             seq_lens_decoder,
             seq_lens_encoder,
-            output_padding_offset,
-            max_input_length,
+            batch_id_per_token_output,
         )
     elif current_platform.is_iluvatar():
         from fastdeploy.model_executor.ops.iluvatar import rebuild_padding
@@ -885,9 +875,8 @@ def rebuild_padding(
             seq_len_this_time,
             seq_lens_decoder,
             seq_lens_encoder,
-            output_padding_offset,
+            batch_id_per_token_output,
             first_token_out,
-            max_input_length,
             enable_logprob,
         )
     elif current_platform.is_gcu():
@@ -899,8 +888,7 @@ def rebuild_padding(
             seq_len_this_time,
             seq_lens_decoder,
             seq_lens_encoder,
-            output_padding_offset,
-            max_input_length,
+            batch_id_per_token_output,
         )
     elif current_platform.is_cpu():
         from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
@@ -911,8 +899,7 @@ def rebuild_padding(
             seq_len_this_time,
             seq_lens_decoder,
             seq_lens_encoder,
-            output_padding_offset,
-            max_input_length,
+            batch_id_per_token_output,
         )
     elif current_platform.is_maca():
         from fastdeploy.model_executor.ops.gpu import rebuild_padding
@@ -923,9 +910,8 @@ def rebuild_padding(
             seq_len_this_time,
             seq_lens_decoder,
             seq_lens_encoder,
-            output_padding_offset,
+            batch_id_per_token_output,
             first_token_out,
-            max_input_length,
             enable_logprob,
         )
     else:
self.model_inputs["output_cum_offsets"], + # Note(ZKK): + # I strongly advise xpu student delete the fuck `output_cum_offsets` name in XPU backend + # like my pr https://github.com/PaddlePaddle/FastDeploy/pull/6358 + ( + self.model_inputs["cu_seqlens_q_output"] + if current_platform.is_cuda() + else self.model_inputs["output_cum_offsets"] + ), self.model_inputs["stop_flags"], self.model_inputs["not_need_stop"], self.model_inputs["max_dec_len"], @@ -805,8 +812,8 @@ class MTPProposer(Proposer): batch_id_per_token, cu_seqlens_q, cu_seqlens_k, - output_cum_offsets, - output_padding_offset, + cu_seqlens_q_output, + batch_id_per_token_output, ) = pre_process( token_num_cpu, self.model_inputs["input_ids"], @@ -841,8 +848,8 @@ class MTPProposer(Proposer): self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) # For speculative decoding - self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) - self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False) + self.model_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False) + self.model_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False) # Initialize forward meta data self._initialize_forward_meta( @@ -891,8 +898,8 @@ class MTPProposer(Proposer): self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_decoder"], self.model_inputs["seq_lens_encoder"], - self.model_inputs["output_padding_offset"], - self.model_config.max_model_len, + self.model_inputs["batch_id_per_token_output"], + self.model_inputs["cu_seqlens_q_output"], self.model_inputs["first_token_hidden_states"], self.enable_logprob if substep == 0 else False, ) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 94813313a3..7b138d2abb 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1301,8 +1301,8 @@ class GPUModelRunner(ModelRunnerBase): batch_id_per_token, cu_seqlens_q, cu_seqlens_k, - output_cum_offsets, - output_padding_offset, + cu_seqlens_q_output, + batch_id_per_token_output, ) = pre_process( token_num, self.share_inputs["input_ids"], @@ -1321,8 +1321,8 @@ class GPUModelRunner(ModelRunnerBase): # For speculative decoding if self.speculative_decoding: - self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) - self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False) + self.share_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False) + self.share_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False) # Initialize forward meta data self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run) @@ -1939,10 +1939,8 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_encoder"], - ( - self.share_inputs["output_padding_offset"] if self.speculative_decoding else None - ), # speculative decoding requires - self.model_config.max_model_len, + (self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None), + (self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None), ) self._dummy_sampler_run(hidden_states, model_output, accept_all_drafts, reject_all_drafts) @@ -2401,8 +2399,8 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_encoder"], - (self.share_inputs["output_padding_offset"] if 
self.speculative_decoding else None), - self.model_config.max_model_len, + (self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None), + (self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None), ) # 4. Compute logits, Sample diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index a1ef260b81..82397be8f9 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -272,12 +272,20 @@ class InputBatch: fill_value=max_draft_token_num, dtype="int32", ) - self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - self.output_padding_offset = paddle.full( - shape=[max_num_seqs * (max_draft_token_num + 1)], - fill_value=0, - dtype="int32", - ) + if current_platform.is_cuda(): + self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32") + self.batch_id_per_token_output = paddle.full( + shape=[max_num_seqs * (max_draft_token_num + 1)], + fill_value=0, + dtype="int32", + ) + else: + self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + self.output_padding_offset = paddle.full( + shape=[max_num_seqs * (max_draft_token_num + 1)], + fill_value=0, + dtype="int32", + ) # For V1_KVCACHE_SCHEDULER self.step_draft_tokens = paddle.full( shape=[max_num_seqs, max_draft_token_num + 1], @@ -404,7 +412,10 @@ class InputBatch: swap_data(self.accept_num, i1, i2) swap_data(self.draft_tokens, i1, i2) swap_data(self.actual_draft_token_num, i1, i2) - swap_data(self.output_cum_offsets, i1, i2) + if current_platform.is_cuda(): + swap_data(self.cu_seqlens_q_output, i1, i2) + else: + swap_data(self.output_cum_offsets, i1, i2) swap_data(self.step_draft_tokens, i1, i2) swap_data(self.step_seq_lens_this_time, i1, i2) swap_data(self.draft_logits, i1, i2) @@ -512,8 +523,12 @@ class ProposerInputBatch(InputBatch): self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"]) self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu") self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"]) - self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"]) - self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"]) + if current_platform.is_cuda(): + self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"]) + self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"]) + else: + self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"]) + self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"]) self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"]) self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"]) self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"]) @@ -659,8 +674,12 @@ class ProposerInputBatch(InputBatch): swap_data(self.stop_flags, i1, i2) swap_data(self.not_need_stop, i1, i2) swap_data(self.pre_ids, i1, i2) - swap_data(self.output_cum_offsets, i1, i2) - swap_data(self.output_padding_offset, i1, i2) + if current_platform.is_cuda(): + swap_data(self.cu_seqlens_q_output, i1, i2) + swap_data(self.batch_id_per_token_output, i1, i2) + else: + swap_data(self.output_cum_offsets, i1, i2) + swap_data(self.output_padding_offset, i1, i2) 
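The buffer shapes in `input_batch.py` (next file) follow from the representation change: a prefix-sum array needs a leading zero plus one entry per sequence, hence `max_num_seqs + 1` rows. A sketch of the allocation difference, with toy sizes and only the CUDA path shown:

```python
# Sketch: why cu_seqlens_q_output gets max_num_seqs + 1 rows (toy sizes).
import paddle

max_num_seqs, max_draft_token_num = 4, 3
# legacy: one cumulative padding offset per sequence
output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
# new: a prefix sum over per-sequence output lengths, so a leading 0 plus
# one entry per sequence
cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
# the per-token batch-id buffer keeps the same worst-case token count
batch_id_per_token_output = paddle.full(
    shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0, dtype="int32"
)
```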
swap_data(self.ids_remove_padding, i1, i2) swap_data(self.batch_id_per_token, i1, i2) swap_data(self.cu_seqlens_q, i1, i2) diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py index 841e1099e3..47de1aad06 100644 --- a/tests/layers/test_speculative_sampler.py +++ b/tests/layers/test_speculative_sampler.py @@ -107,10 +107,11 @@ def _create_fd_config(max_model_len): def _create_share_inputs(max_num_seqs, max_draft_token_num, max_model_len, vocab_size): share_inputs = {} share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 2, dtype="int32") - share_inputs["output_cum_offsets"] = paddle.concat( - [(max_model_len - share_inputs["seq_lens_this_time"][i]) * i for i in range(max_num_seqs)] - ) - share_inputs["output_padding_offset"] = paddle.repeat_interleave(share_inputs["output_cum_offsets"], 2) + + cu_seqlens_q_output = [0] + paddle.cumsum(share_inputs["seq_lens_this_time"]).numpy().tolist() + share_inputs["cu_seqlens_q_output"] = paddle.to_tensor(cu_seqlens_q_output).cast("int32") + share_inputs["batch_id_per_token_output"] = paddle.arange(max_num_seqs, dtype="int32") * 2 + share_inputs["accept_tokens"] = paddle.full( shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64" ) diff --git a/tests/operators/test_draft_model_update.py b/tests/operators/test_draft_model_update.py index e69f201f93..e2e2f53522 100644 --- a/tests/operators/test_draft_model_update.py +++ b/tests/operators/test_draft_model_update.py @@ -43,7 +43,7 @@ def draft_model_update_kernel( seq_lens_encoder, seq_lens_decoder, step_idx, - output_cum_offsets, + cu_seqlens_q_output, stop_flags, not_need_stop, max_dec_len, @@ -64,7 +64,7 @@ def draft_model_update_kernel( draft_token_now = draft_tokens[tid] pre_ids_now = pre_ids[tid] base_model_draft_tokens_now = base_model_draft_tokens[tid] - next_tokens_start_id = tid * max_seq_len - output_cum_offsets[tid] + next_tokens_start_id = cu_seqlens_q_output[tid] # next_tokens_start = seq_len_this_time = seq_lens_this_time[tid] seq_len_encoder = seq_lens_encoder[tid] @@ -130,7 +130,7 @@ def draft_model_update_ref( seq_lens_encoder, seq_lens_decoder, step_idx, - output_cum_offsets, + cu_seqlens_q_output, stop_flags, not_need_stop, max_dec_len, @@ -161,7 +161,7 @@ def draft_model_update_ref( seq_lens_encoder, seq_lens_decoder, step_idx, - output_cum_offsets, + cu_seqlens_q_output, stop_flags, not_need_stop, max_dec_len, @@ -200,8 +200,8 @@ class TestDraftModelUpdate(unittest.TestCase): seq_lens_encoder = paddle.randint(1, 10, shape=(max_bsz,), dtype="int32") seq_lens_decoder = paddle.randint(1, 10, shape=(max_bsz,), dtype="int32") step_idx = paddle.randint(1, 10, shape=(max_bsz,), dtype="int64") - output_cum_offsets = paddle.randint(0, 2, shape=(max_bsz,), dtype="int32") - output_cum_offsets[0] = 0 + cu_seqlens_q_output = paddle.randint(0, 2, shape=(max_bsz,), dtype="int32") + cu_seqlens_q_output[0] = 0 stop_flags = paddle.zeros([max_bsz], dtype="bool") not_need_stop = paddle.zeros([1], dtype="bool").to(device=paddle.CPUPlace()) max_dec_len = paddle.randint(100, 102, shape=(max_bsz,), dtype="int64") @@ -216,7 +216,7 @@ class TestDraftModelUpdate(unittest.TestCase): seq_lens_encoder, seq_lens_decoder, step_idx, - output_cum_offsets, + cu_seqlens_q_output, stop_flags, not_need_stop, max_dec_len, diff --git a/tests/operators/test_reasoning_phase_token_constraint.py b/tests/operators/test_reasoning_phase_token_constraint.py index 486a6341aa..58b32545d1 100644 --- a/tests/operators/test_reasoning_phase_token_constraint.py 
diff --git a/tests/operators/test_reasoning_phase_token_constraint.py b/tests/operators/test_reasoning_phase_token_constraint.py
index 486a6341aa..58b32545d1 100644
--- a/tests/operators/test_reasoning_phase_token_constraint.py
+++ b/tests/operators/test_reasoning_phase_token_constraint.py
@@ -4,8 +4,8 @@
 import numpy as np
 import paddle
 
 from fastdeploy.model_executor.ops.gpu import (
+    get_padding_offset,
     reasoning_phase_token_constraint,
-    speculate_get_output_padding_offset,
 )
@@ -76,13 +76,14 @@ class TestReasoningPhaseTokenConstraint(unittest.TestCase):
         seq_lens_output = paddle.to_tensor([2, 2], dtype="int32")
         output_token_num = paddle.sum(seq_lens_output)
 
-        output_cum_offsets_tmp = paddle.cumsum(self.max_seq_len - seq_lens_output, dtype="int32")
-        self.output_padding_offset, self.output_cum_offsets = speculate_get_output_padding_offset(
-            output_cum_offsets_tmp,
-            output_token_num,
+        useless_inputs = paddle.zeros([self.bs, self.max_seq_len], dtype="int64")
+        _, self.output_padding_offset, self.output_cum_offsets, _ = get_padding_offset(
+            useless_inputs,
             seq_lens_output,
-            self.max_seq_len,
+            None,
+            None,
+            output_token_num.item(),
         )
 
         # self.output_padding_offset = paddle.zeros([self.token_num], dtype="int32")
@@ -445,13 +446,14 @@ class TestReasoningPhaseTokenConstraint(unittest.TestCase):
         seq_lens_output = paddle.full(bs, 2, dtype="int32")
         output_token_num = paddle.sum(seq_lens_output)
 
-        output_cum_offsets_tmp = paddle.cumsum(max_seq_len - seq_lens_output, dtype="int32")
-        output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
-            output_cum_offsets_tmp,
-            output_token_num,
+        useless_inputs = paddle.zeros([bs, max_seq_len], dtype="int64")
+        _, output_padding_offset, output_cum_offsets, _ = get_padding_offset(
+            useless_inputs,
             seq_lens_output,
-            max_seq_len,
+            None,
+            None,
+            output_token_num.item(),
         )
 
         # ------------------------
diff --git a/tests/operators/test_rebuild_padding.py b/tests/operators/test_rebuild_padding.py
index 6b8db57016..ad80b587f1 100644
--- a/tests/operators/test_rebuild_padding.py
+++ b/tests/operators/test_rebuild_padding.py
@@ -49,21 +49,21 @@ def RebuildAppendPaddingKernel(
     seq_len_this_time,
     seq_lens_decoder,
     seq_lens_encoder,
-    output_padding_offset,
-    max_input_length,
+    batch_id_per_token_output,
+    cu_seqlens_q_output,
     token_num,
     need_delete_token_num,
 ):
     for token_id in range(token_num - need_delete_token_num):
-        bi = int(token_id / max_input_length)
+        bi = batch_id_per_token_output[token_id]
         if seq_len_this_time[bi] == 0 or (seq_lens_decoder[bi] == 0 and seq_lens_encoder[bi] == 0):
             continue
-        ori_token_id = token_id + output_padding_offset[token_id]
         seq_id = 0
         if seq_lens_encoder[bi] > 0:
             seq_id = seq_lens_encoder[bi] - 1
-        cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi]
-        input_token_id = ori_token_id - cum_offset_bi + seq_id
+        else:
+            seq_id = token_id - cu_seqlens_q_output[bi]
+        input_token_id = cu_seqlens_q[bi] + seq_id
         out[token_id] = tmp_out[input_token_id][:]
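With `batch_id_per_token_output` and `cu_seqlens_q_output` available, the reference kernel's gather rule reduces to two lookups. A small self-contained sketch of that rule on made-up data (one decode batch, one prefill batch; the values are illustrative, not from the patch):

```python
import numpy as np

# Made-up mini-batch: batch 0 decodes (1 input token), batch 1 prefills (3).
cu_seqlens_q = np.array([0, 1, 4], dtype=np.int32)       # input token starts
seq_lens_encoder = np.array([0, 3], dtype=np.int32)
seq_lens_output = np.array([1, 1], dtype=np.int32)       # one output token each
cu_seqlens_q_output = np.concatenate([[0], np.cumsum(seq_lens_output)])
batch_id_per_token_output = np.repeat(np.arange(2), seq_lens_output)

for token_id in range(int(seq_lens_output.sum())):
    bi = batch_id_per_token_output[token_id]
    if seq_lens_encoder[bi] > 0:
        seq_id = seq_lens_encoder[bi] - 1                # prefill: last prompt token
    else:
        seq_id = token_id - cu_seqlens_q_output[bi]      # decode: position within batch
    input_token_id = cu_seqlens_q[bi] + seq_id
    print(token_id, "<-", input_token_id)                # prints: 0 <- 0, then 1 <- 3
```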
@@ -73,8 +73,8 @@ def rebuild_padding_ref(
     seq_len_this_time,
     seq_lens_decoder,
     seq_lens_encoder,
-    output_padding_offset,
-    max_input_length,
+    batch_id_per_token_output,
+    cu_seqlens_q_output,
 ):
     tmp_out_shape = tmp_out.shape
@@ -83,7 +83,7 @@ def rebuild_padding_ref(
     bsz = cu_seqlens_q.shape[0] - 1
     out = np.zeros([bsz, dim_embed])
 
-    if output_padding_offset is not None:
+    if batch_id_per_token_output is not None:
         need_delete_token_num = 0
         for i in range(bsz):
             if seq_lens_encoder[i] > 0:
@@ -92,16 +92,16 @@ def rebuild_padding_ref(
     else:
         out = np.zeros([bsz, dim_embed])
 
-    if output_padding_offset is not None:
+    if batch_id_per_token_output is not None:
         RebuildAppendPaddingKernel(
             out,
             tmp_out,
             cu_seqlens_q,
             seq_len_this_time,
             seq_lens_decoder,
             seq_lens_encoder,
-            output_padding_offset,
-            max_input_length,
+            batch_id_per_token_output,
+            cu_seqlens_q_output,
             token_num,
             need_delete_token_num,
         )
@@ -124,7 +124,6 @@ class TestRebuildPadding(unittest.TestCase):
         token_num = 100
         dim_embed = 256
         # bsz = 4
-        max_input_length = 512
         # tmp_out: [token_num, dim_embed]
         tmp_out = np.random.randn(token_num, dim_embed).astype(np.float32)
         # cu_seqlens_q: [bsz + 1],accumulate the number of tokens for each batch.
@@ -142,8 +141,8 @@ class TestRebuildPadding(unittest.TestCase):
             seq_len_this_time=seq_len_this_time,
             seq_lens_decoder=seq_lens_decoder,
             seq_lens_encoder=seq_lens_encoder,
-            output_padding_offset=None,
-            max_input_length=max_input_length,
+            batch_id_per_token_output=None,
+            cu_seqlens_q_output=None,
         )
 
         tmp_out = paddle.to_tensor(tmp_out)
@@ -160,7 +159,7 @@ class TestRebuildPadding(unittest.TestCase):
            seq_lens_encoder,
            None,
            None,
-           max_input_length,
+           None,
            False,
        )
        np.testing.assert_allclose(out_no_offset.numpy(), out_no_offset_ref)
@@ -171,7 +170,6 @@ class TestRebuildPadding(unittest.TestCase):
         token_num = 84
         dim_embed = 256
         # bsz = 4
-        max_input_length = 512
         # tmp_out: [token_num, dim_embed]
         tmp_out = np.random.randn(token_num, dim_embed).astype(np.float32)
         # cu_seqlens_q: [bsz + 1],accumulate the number of tokens for each batch.
@@ -184,16 +182,19 @@ class TestRebuildPadding(unittest.TestCase):
         seq_lens_encoder = np.array([0, 20, 0, 20, 0, 20, 0, 20], dtype=np.int32)
         seq_lens_decoder = np.array([21, 0, 21, 0, 21, 0, 21, 0], dtype=np.int32)
 
-        num_output_tokens = 8
-        output_padding_offset = np.random.randint(0, 10, [num_output_tokens], dtype=np.int32)
+        batch_id_per_token_output = np.array([0, 1, 2, 3, 4, 5, 6, 7], dtype=np.int32)
+        batch_id_per_token_output = paddle.to_tensor(batch_id_per_token_output)
+        cu_seqlens_q_output = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int32)
+        cu_seqlens_q_output = paddle.to_tensor(cu_seqlens_q_output)
+
         out_with_offset_ref = rebuild_padding_ref(
             tmp_out=tmp_out,
             cu_seqlens_q=cu_seqlens_q,
             seq_len_this_time=seq_len_this_time,
             seq_lens_decoder=seq_lens_decoder,
             seq_lens_encoder=seq_lens_encoder,
-            output_padding_offset=output_padding_offset,
-            max_input_length=max_input_length,
+            batch_id_per_token_output=batch_id_per_token_output,
+            cu_seqlens_q_output=cu_seqlens_q_output,
         )
 
         tmp_out = paddle.to_tensor(tmp_out)
@@ -201,16 +202,16 @@ class TestRebuildPadding(unittest.TestCase):
         seq_len_this_time = paddle.to_tensor(seq_len_this_time)
         seq_lens_decoder = paddle.to_tensor(seq_lens_decoder)
         seq_lens_encoder = paddle.to_tensor(seq_lens_encoder)
-        output_padding_offset = paddle.to_tensor(output_padding_offset)
+        batch_id_per_token_output = paddle.to_tensor(batch_id_per_token_output)
         out_with_offset = rebuild_padding(
             tmp_out,
             cu_seqlens_q,
             seq_len_this_time,
             seq_lens_decoder,
             seq_lens_encoder,
-            output_padding_offset,
+            batch_id_per_token_output,
+            cu_seqlens_q_output,
             None,
-            max_input_length,
             False,
         )
         np.testing.assert_allclose(out_with_offset.numpy(), out_with_offset_ref)
diff --git a/tests/operators/test_speculate_get_output_padding_offset.py b/tests/operators/test_speculate_get_output_padding_offset.py
deleted file mode 100644
index fc973b75e3..0000000000
--- a/tests/operators/test_speculate_get_output_padding_offset.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import paddle
-
-from fastdeploy.model_executor.ops.gpu import speculate_get_output_padding_offset
-
-
-class TestSpeculateGetOutputPaddingOffset(unittest.TestCase):
-    def test_speculate_get_output_padding_offset(self):
-        bsz = 256
-        max_seq_len = 8192
-
-        seq_lens_output = np.random.randint(0, 4, size=bsz)
-        output_token_num = np.sum(seq_lens_output)
-
-        seq_lens_output = paddle.to_tensor(seq_lens_output, dtype="int32")
-        out_token_num = paddle.sum(seq_lens_output).astype("int32")
-        output_cum_offsets_tmp = paddle.cumsum(max_seq_len - seq_lens_output).astype("int32")
-
-        output_padding_offset_gpu, output_cum_offsets_gpu = speculate_get_output_padding_offset(
-            output_cum_offsets_tmp, out_token_num, seq_lens_output, max_seq_len
-        )
-
-        output_padding_offset_ref = [-1] * output_token_num
-        output_cum_offsets_ref = [-1] * bsz
-
-        for bi in range(bsz):
-            cum_offset = 0 if bi == 0 else output_cum_offsets_tmp[bi - 1]
-            output_cum_offsets_ref[bi] = cum_offset
-            for token_i in range(seq_lens_output[bi]):
-                output_padding_offset_ref[bi * max_seq_len - cum_offset + token_i] = cum_offset
-
-        np.testing.assert_allclose(output_padding_offset_gpu, output_padding_offset_ref)
-        np.testing.assert_allclose(output_cum_offsets_gpu, output_cum_offsets_ref)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/operators/test_speculate_get_token_penalty_multi_scores.py b/tests/operators/test_speculate_get_token_penalty_multi_scores.py
index e638da1d3c..83aca759db 100644
--- a/tests/operators/test_speculate_get_token_penalty_multi_scores.py
+++ b/tests/operators/test_speculate_get_token_penalty_multi_scores.py
@@ -25,8 +25,8 @@ def min_length_logits_process(
     cur_len,
     min_len,
     eos_token_id,
-    output_padding_offset,
-    output_cum_offsets,
+    batch_id_per_token_output,
+    cu_seqlens_q_output,
     token_num,
     bs,
     length,
@@ -34,11 +34,11 @@ def min_length_logits_process(
     max_seq_len,
 ):
     for token_idx in range(token_num):
-        bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len
+        bi = batch_id_per_token_output[token_idx]
         bi = bi.astype(paddle.int32)
         if bi >= bs:
             continue
-        query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi]
+        query_start_token_idx = cu_seqlens_q_output[bi]
 
         if cur_len[bi] < 0:
             continue
@@ -48,10 +48,10 @@ def min_length_logits_process(
 
 
 def update_repeat_times(
-    pre_ids, cur_len, repeat_times, output_padding_offset, token_num, bs, length, length_id, max_seq_len
+    pre_ids, cur_len, repeat_times, batch_id_per_token_output, token_num, bs, length, length_id, max_seq_len
 ):
     for token_idx in range(token_num):
-        bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len
+        bi = batch_id_per_token_output[token_idx]
         bi = bi.astype(paddle.int32)
         if bi >= bs:
             continue
@@ -75,14 +75,14 @@ def update_value_by_repeat_times(
     presence_score,
     temperatures,
     logits,
-    output_padding_offset,
+    batch_id_per_token_output,
     token_num,
     bs,
     length,
     max_seq_len,
 ):
     for token_idx in range(token_num):
-        bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len
+        bi = batch_id_per_token_output[token_idx]
         bi = bi.astype(paddle.int32)
         if bi >= bs:
             continue
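Every reference helper above now resolves the owning batch with one lookup, `bi = batch_id_per_token_output[token_idx]`, instead of dividing a padded offset by `max_seq_len`. Such a map is just the batch index repeated once per output token; a short sketch with assumed token counts (illustrative only):

```python
import paddle

# Assumed per-batch output token counts (e.g. accepted draft tokens).
seq_lens_output = paddle.to_tensor([2, 1, 3], dtype="int32")

# -> [0, 0, 1, 2, 2, 2]: token_idx -> owning batch id
batch_id_per_token_output = paddle.repeat_interleave(
    paddle.arange(seq_lens_output.shape[0], dtype="int32"), seq_lens_output
)
```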
@@ -102,10 +102,18 @@ def update_value_by_repeat_times(
 
 
 def ban_bad_words(
-    logits, bad_words_list, bad_words_len, output_padding_offset, token_num, bs, length, bad_words_length, max_seq_len
+    logits,
+    bad_words_list,
+    bad_words_len,
+    batch_id_per_token_output,
+    token_num,
+    bs,
+    length,
+    bad_words_length,
+    max_seq_len,
 ):
     for token_idx in range(token_num):
-        bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len
+        bi = batch_id_per_token_output[token_idx]
         bi = bi.astype(paddle.int32)
         if bi >= bs:
             continue
@@ -133,8 +141,8 @@ def speculate_get_token_penalty_multi_scores_ref(
     min_len,
     eos_token_id,
     seq_lens_this_time,
-    output_padding_offset,
-    output_cum_offsets,
+    batch_id_per_token_output,
+    cu_seqlens_q_output,
     max_seq_len,
 ):
     shape = logits.shape
@@ -152,8 +160,8 @@ def speculate_get_token_penalty_multi_scores_ref(
     cur_len,
     min_len,
     eos_token_id,
-    output_padding_offset,
-    output_cum_offsets,
+    batch_id_per_token_output,
+    cu_seqlens_q_output,
     token_num,
     bs,
     length,
@@ -162,7 +170,7 @@ def speculate_get_token_penalty_multi_scores_ref(
     )
 
     update_repeat_times(
-        pre_ids, cur_len, repeat_times, output_padding_offset, token_num, bs, length, length_id, max_seq_len
+        pre_ids, cur_len, repeat_times, batch_id_per_token_output, token_num, bs, length, length_id, max_seq_len
     )
 
     update_value_by_repeat_times(
@@ -172,7 +180,7 @@ def speculate_get_token_penalty_multi_scores_ref(
     presence_score,
     temperatures,
     logits,
-    output_padding_offset,
+    batch_id_per_token_output,
     token_num,
     bs,
     length,
@@ -180,7 +188,15 @@ def speculate_get_token_penalty_multi_scores_ref(
     )
 
     ban_bad_words(
-        logits, bad_tokens, bad_tokens_len, output_padding_offset, token_num, bs, length, length_bad_words, max_seq_len
+        logits,
+        bad_tokens,
+        bad_tokens_len,
+        batch_id_per_token_output,
+        token_num,
+        bs,
+        length,
+        length_bad_words,
+        max_seq_len,
     )
@@ -193,21 +209,21 @@ class TestSpeculateGetTokenPenaltyMultiScores(unittest.TestCase):
         max_seq_len = 1024  # 1024 #2048 #8192
         data_type = "float32"
 
-        # prepare output_padding_offset and output_cum_offsets
+        # prepare batch_id_per_token_output and cu_seqlens_q_output
         tokens = [1] * bs
         token_num = np.sum(tokens)
-        output_padding_offset = []
-        output_cum_offsets = [0]
+        batch_id_per_token_output = []
+        cu_seqlens_q_output = [0]
         opo_offset = 0
         for bid in range(bs):
             ts = tokens[bid]
             for i in range(ts):
-                output_padding_offset.append(opo_offset)
-            opo_offset += max_seq_len - ts
-            output_cum_offsets.append(opo_offset)
-        output_cum_offsets = output_cum_offsets[:-1]
-        output_padding_offset = paddle.to_tensor(output_padding_offset, "int32")
-        output_cum_offsets = paddle.to_tensor(output_cum_offsets, "int32")
+                batch_id_per_token_output.append(bid)
+            opo_offset += ts
+            cu_seqlens_q_output.append(opo_offset)
+        cu_seqlens_q_output = cu_seqlens_q_output[:-1]
+        batch_id_per_token_output = paddle.to_tensor(batch_id_per_token_output, "int32")
+        cu_seqlens_q_output = paddle.to_tensor(cu_seqlens_q_output, "int32")
 
         # prepare pre_ids and logits
         pre_ids_len = 122
@@ -245,8 +261,8 @@ class TestSpeculateGetTokenPenaltyMultiScores(unittest.TestCase):
             min_len,
             eos_token_id,
             seq_len_this_time,
-            output_padding_offset,
-            output_cum_offsets,
+            batch_id_per_token_output,
+            cu_seqlens_q_output,
             max_seq_len,
         )
         # inplace modify, not return data
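The setup hunk above, as corrected here (batch ids appended per token, the offset accumulated by `ts` rather than by `max_seq_len - ts`), is the loop form of a prefix sum. A vectorized equivalent, for illustration only:

```python
import numpy as np

tokens = np.array([1] * 4, dtype=np.int32)  # output token count per batch
batch_id_per_token_output = np.repeat(np.arange(len(tokens), dtype=np.int32), tokens)
# Start offset per batch, matching the trimmed cu_seqlens_q_output in the loop above.
cu_seqlens_q_output = np.concatenate([[0], np.cumsum(tokens)[:-1]]).astype(np.int32)
```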
diff --git a/tests/operators/test_speculate_verify.py b/tests/operators/test_speculate_verify.py
index 26e96f3085..87bf82579f 100644
--- a/tests/operators/test_speculate_verify.py
+++ b/tests/operators/test_speculate_verify.py
@@ -62,7 +62,7 @@ def speculate_verify_ref(
     max_dec_len,
     end_tokens,
     is_block_step,
-    output_cum_offsets,
+    cu_seqlens_q_output,
     actual_candidate_len,
     actual_draft_token_nums,
     topp,
@@ -106,7 +106,7 @@ def speculate_verify_ref(
     verify_tokens_flat = verify_tokens.reshape(-1)
     verify_scores_flat = verify_scores.reshape(-1)
     for bid in range(real_bsz):
-        start_token_id = bid * max_seq_len - output_cum_offsets[bid]
+        start_token_id = cu_seqlens_q_output[bid]
         accept_num_now = 1
         stop_flag_now_int = 0
 
@@ -274,11 +274,10 @@ def gen_speculate_verify_inputs(
     end_tokens = rng.integers(1, 1000, size=end_length, dtype=np.int64)
     is_block_step = rng.integers(0, 2, size=real_bsz, dtype=bool)
 
-    # output_cum_offsets = np.zeros_like(seq_lens_this_time)
-    # output_cum_offsets[1:] = np.cumsum(seq_lens_this_time[:-1])
-    blank_lengths = max_seq_len - seq_lens_this_time
-    output_cum_offsets = np.concatenate([[0], np.cumsum(blank_lengths[:-1])])
-    output_cum_offsets = output_cum_offsets.astype("int32")
+    # Flat start index of each batch's tokens: an exclusive prefix sum of
+    # seq_lens_this_time, not of the padding ("blank") lengths.
+    cu_seqlens_q_output = np.concatenate([[0], np.cumsum(seq_lens_this_time[:-1])])
+    cu_seqlens_q_output = cu_seqlens_q_output.astype("int32")
 
     actual_candidate_len = rng.integers(1, max_candidate_len + 1, size=sum_seq_this_time, dtype=np.int32)
     topp = (
@@ -309,7 +309,7 @@ def gen_speculate_verify_inputs(
         "max_dec_len": max_dec_len,
         "end_tokens": end_tokens,
         "is_block_step": is_block_step,
-        "output_cum_offsets": output_cum_offsets,
+        "cu_seqlens_q_output": cu_seqlens_q_output,
         "actual_candidate_len": actual_candidate_len,
        "actual_draft_token_nums": actual_draft_token_nums,
        "topp": topp,
diff --git a/tests/spec_decode/test_mtp_proposer.py b/tests/spec_decode/test_mtp_proposer.py
index 4d1bd51bed..6679ce4286 100644
--- a/tests/spec_decode/test_mtp_proposer.py
+++ b/tests/spec_decode/test_mtp_proposer.py
@@ -71,8 +71,8 @@ class TestMTPProposer(unittest.TestCase):
            "step_idx": paddle.zeros([2, 1], dtype="int64"),
            "stop_flags": paddle.zeros([2, 1], dtype="bool"),
            "pre_ids": paddle.zeros([2, 2048], dtype="int64"),
-            "output_cum_offsets": paddle.zeros([2], dtype="int32"),
-            "output_padding_offset": paddle.zeros([2], dtype="int32"),
+            "cu_seqlens_q_output": paddle.zeros([3], dtype="int32"),
+            "batch_id_per_token_output": paddle.zeros([2], dtype="int32"),
            "ids_remove_padding": paddle.zeros([2], dtype="int64"),
            "batch_id_per_token": paddle.zeros([2], dtype="int32"),
            "cu_seqlens_q": paddle.zeros([3], dtype="int32"),
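One convention worth noting from the test_mtp_proposer hunk: `cu_seqlens_q_output` now carries `bs + 1` entries, matching `cu_seqlens_q`, while `batch_id_per_token_output` carries one entry per output token, matching `batch_id_per_token`. A sketch of the resulting share_inputs shapes for a two-sequence decode step, mirroring the zero-initialized buffers above (illustrative values only):

```python
import paddle

bs = 2          # sequences in the batch
token_num = 2   # one output token per sequence in this example

share_inputs = {
    # one owning batch id per output token
    "batch_id_per_token_output": paddle.zeros([token_num], dtype="int32"),
    # prefix sums carry a trailing total, hence bs + 1 entries
    "cu_seqlens_q_output": paddle.zeros([bs + 1], dtype="int32"),
}
```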