From 2e63d88f7a20cfce9ca73752b0d166aad1c63c32 Mon Sep 17 00:00:00 2001 From: huicongyao Date: Thu, 12 Mar 2026 20:05:15 +0800 Subject: [PATCH] [Optimization][Speculative Decoding]Fuse padding sampling params (#6765) * optimize speculate pre process unit test * Add CUDA kernel for building sampling params in speculative decoding * init infer seed in device * format code * add unittest & fix * fix * format-code * format-code * fix rebase * . * fix unitest --- custom_ops/gpu_ops/cpp_extensions.cc | 13 + .../build_sampling_params.cu | 95 +++++++ .../model_executor/layers/sample/sampler.py | 15 +- fastdeploy/worker/gpu_model_runner.py | 18 +- fastdeploy/worker/input_batch.py | 2 +- tests/layers/test_speculative_sampler.py | 10 +- tests/operators/test_build_sampling_params.py | 237 ++++++++++++++++++ tests/operators/test_speculate_pre_process.py | 10 + 8 files changed, 389 insertions(+), 11 deletions(-) create mode 100644 custom_ops/gpu_ops/speculate_decoding/build_sampling_params.cu create mode 100644 tests/operators/test_build_sampling_params.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index ecc13c2c9b..cfac8abfd2 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -772,6 +772,15 @@ std::vector SpeculatePreProcess( const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder); +std::vector BuildSamplingParams( + const paddle::Tensor& top_p, + const paddle::Tensor& top_k, + paddle::Tensor& infer_seed, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cu_seqlens_q_output, + const int64_t token_num_output_cpu, + const int64_t increment_value); + void SpecTokenPenaltyMultiScores( const paddle::Tensor& token_ids_all, const paddle::Tensor& prompt_lens, @@ -1727,6 +1736,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &SpeculatePreProcess, "speculate_pre_process function"); + m.def("build_sampling_params", + &BuildSamplingParams, + "build_sampling_params function"); + m.def("speculate_get_token_penalty_multi_scores", &SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function"); diff --git a/custom_ops/gpu_ops/speculate_decoding/build_sampling_params.cu b/custom_ops/gpu_ops/speculate_decoding/build_sampling_params.cu new file mode 100644 index 0000000000..5541abb371 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/build_sampling_params.cu @@ -0,0 +1,95 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +constexpr int64_t MAX_INFER_SEED = 9223372036854775806; + +__global__ void BuildSamplingParamsKernel(float *top_p_padding, + int64_t *top_k_padding, + int64_t *topp_seed, + const float *top_p, + const int64_t *top_k, + int64_t *infer_seed, + const int *cu_seqlens_q_output, + const int64_t increment_value) { + const int tid = threadIdx.x; + const int bi = blockIdx.x; + int cur_seq_len_q_output_start = cu_seqlens_q_output[bi]; + int cur_seq_len_q_output_end = cu_seqlens_q_output[bi + 1]; + const float bi_top_p = top_p[bi]; + const int64_t bi_top_k = top_k[bi]; + int64_t bi_infer_seed = (infer_seed[bi] + tid * 4) % MAX_INFER_SEED; + + for (int i = tid; i < cur_seq_len_q_output_end - cur_seq_len_q_output_start; + i += blockDim.x) { + int pad_idx = cur_seq_len_q_output_start + i; + top_p_padding[pad_idx] = bi_top_p; + top_k_padding[pad_idx] = bi_top_k; + topp_seed[pad_idx] = bi_infer_seed; + bi_infer_seed = (bi_infer_seed + blockDim.x * 4) % MAX_INFER_SEED; + } + + if (tid == 0) { + infer_seed[bi] = (infer_seed[bi] + increment_value) % MAX_INFER_SEED; + } +} + +std::vector BuildSamplingParams( + const paddle::Tensor &top_p, + const paddle::Tensor &top_k, + paddle::Tensor &infer_seed, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &cu_seqlens_q_output, + const int64_t token_num_output_cpu, + const int64_t increment_value) { + auto cu_stream = seq_lens_this_time.stream(); + int real_bsz = seq_lens_this_time.shape()[0]; + paddle::Tensor top_p_padding = paddle::empty({token_num_output_cpu, 1}, + paddle::DataType::FLOAT32, + seq_lens_this_time.place()); + paddle::Tensor top_k_padding = paddle::empty({token_num_output_cpu, 1}, + paddle::DataType::INT64, + seq_lens_this_time.place()); + paddle::Tensor topp_seed = paddle::empty({token_num_output_cpu, 1}, + paddle::DataType::INT64, + seq_lens_this_time.place()); + + BuildSamplingParamsKernel<<>>( + top_p_padding.data(), + top_k_padding.data(), + topp_seed.data(), + top_p.data(), + top_k.data(), + infer_seed.data(), + cu_seqlens_q_output.data(), + increment_value); + + return {top_p_padding, top_k_padding, topp_seed}; +} + +PD_BUILD_STATIC_OP(build_sampling_params) + .Inputs({"top_p", + "top_k", + "infer_seed", + "seq_lens_this_time", + "cu_seqlens_q_output"}) + .Outputs({"top_p_padding", "top_k_padding", "topp_seed"}) + .Attrs({"token_num_output_cpu: int64_t", "increment_value: int64_t"}) + .SetKernelFn(PD_KERNEL(BuildSamplingParams)); diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 9997d8c50f..8a45f9d260 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -50,6 +50,9 @@ from fastdeploy.reasoning import ReasoningParser from fastdeploy.spec_decode import SpecMethod, VerifyStrategy from fastdeploy.worker.output import LogprobsTensors, SamplerOutput +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import build_sampling_params + def top_p_normalize_probs_paddle( probs: paddle.Tensor, @@ -772,6 +775,8 @@ class SpeculativeSampler(nn.Layer): sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], + token_num_output_cpu: int, + increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, ) -> SamplerOutput: @@ -806,12 +811,14 @@ class SpeculativeSampler(nn.Layer): if self.verify_strategy == VerifyStrategy.TARGET_MATCH: # Only TARGET_MATCH needs stochastic sampling - top_p, top_k, topp_seed = padding_sampling_params( + top_p, top_k, topp_seed = build_sampling_params( sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.seed, share_inputs["seq_lens_this_time"], - share_inputs["seq_lens_encoder"], + share_inputs["cu_seqlens_q_output"], + token_num_output_cpu, + increment_value, ) _, target_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed) elif self.verify_strategy == VerifyStrategy.GREEDY: @@ -922,6 +929,8 @@ class SpeculativeSampler(nn.Layer): sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], + token_num_output_cpu: int, + increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, ) -> SamplerOutput: @@ -1001,6 +1010,8 @@ class SpeculativeSampler(nn.Layer): sampling_metadata, max_model_len, share_inputs, + token_num_output_cpu, + increment_value, accept_all_drafts, reject_all_drafts, ) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index d0275fc3e6..554e464bf5 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -198,11 +198,11 @@ class GPUModelRunner(ModelRunnerBase): # Initialize input batch self.share_inputs = InputBatch(self.fd_config) self.share_inputs.init_share_inputs() - increment_value = ( + self.increment_value = ( 4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4 ) self.infer_seed_increment = paddle.full( - shape=[self.scheduler_config.max_num_seqs, 1], fill_value=increment_value, dtype="int64", device="cpu" + shape=[self.scheduler_config.max_num_seqs, 1], fill_value=self.increment_value, dtype="int64", device="cpu" ) self.restore_chunked_prefill_request = dict() @@ -1667,6 +1667,8 @@ class GPUModelRunner(ModelRunnerBase): self.sampling_metadata, self.model_config.max_model_len, self.share_inputs, + int(self._real_output_token_num_host), + self.increment_value, accept_all_drafts, reject_all_drafts, ) @@ -1836,8 +1838,9 @@ class GPUModelRunner(ModelRunnerBase): self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts) # 7. Updata 'infer_seed' and step_cuda() - self.share_inputs["infer_seed"].add_(self.infer_seed_increment) - self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + if not self.speculative_decoding: + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0: break @@ -2270,6 +2273,8 @@ class GPUModelRunner(ModelRunnerBase): self.sampling_metadata, self.model_config.max_model_len, self.share_inputs, + int(self._real_output_token_num_host), + self.increment_value, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -2369,8 +2374,9 @@ class GPUModelRunner(ModelRunnerBase): self.proposer.run(share_inputs=self.share_inputs) # 7. Update 'infer_seed' and step_cuda() - self.share_inputs["infer_seed"].add_(self.infer_seed_increment) - self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + if not self.speculative_decoding: + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED if self.speculative_decoding: speculate_schedule_cache( self.share_inputs["draft_tokens"], diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index f2b3e9a88c..4c6c646d27 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -171,7 +171,7 @@ class InputBatch: self.need_block_list = paddle.full([max_num_seqs], -1, dtype="int32") self.need_block_len = paddle.full([1], 0, dtype="int32") self.used_list_len = paddle.full([max_num_seqs], 0, dtype="int32") - self.infer_seed = paddle.full([max_num_seqs, 1], 0, dtype="int64", device="cpu") + self.infer_seed = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.first_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64") self.ori_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.system_lens = paddle.full([max_num_seqs, 1], 0, dtype="int32") diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py index 7abd94286d..ef75fe5d4e 100644 --- a/tests/layers/test_speculative_sampler.py +++ b/tests/layers/test_speculative_sampler.py @@ -192,8 +192,11 @@ def test_speculative_sampler(): logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size) share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) + token_num_output_cpu = int(share_inputs["cu_seqlens_q_output"][-1]) + increment_value = (max_draft_token_num + 1) * 4 + sampler = SpeculativeSampler(fd_config) - sampler(logits, sampling_metadata, max_model_len, share_inputs) + sampler(logits, sampling_metadata, max_model_len, share_inputs, token_num_output_cpu, increment_value) def test_speculative_sampler_logprobs(): @@ -211,11 +214,14 @@ def test_speculative_sampler_logprobs(): sampling_metadata.share_inputs = share_inputs logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size) + token_num_output_cpu = int(share_inputs["cu_seqlens_q_output"][-1]) + increment_value = (max_draft_token_num + 1) * 4 + logprobs_mode_list = ["raw_logprobs", "raw_logits"] for logprobs_mode in logprobs_mode_list: fd_config.model_config.logprobs_mode = logprobs_mode sampler = SpeculativeSampler(fd_config) - sampler(logits, sampling_metadata, max_model_len, share_inputs) + sampler(logits, sampling_metadata, max_model_len, share_inputs, token_num_output_cpu, increment_value) def test_mtp_sampler(): diff --git a/tests/operators/test_build_sampling_params.py b/tests/operators/test_build_sampling_params.py new file mode 100644 index 0000000000..31b2aa745f --- /dev/null +++ b/tests/operators/test_build_sampling_params.py @@ -0,0 +1,237 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import build_sampling_params + +MAX_INFER_SEED = 9223372036854775806 +BLOCK_DIM = 64 + + +def build_sampling_params_ref( + top_p, + top_k, + infer_seed, + cu_seq_lens_q_output, + token_num_output_cpu, + increment_value, +): + """ + Python reference implementation for BuildSamplingParamsKernel. + + Returns: + top_p_padding: float32[token_num_output_cpu, 1] + top_k_padding: int64[token_num_output_cpu, 1] + topp_seed: int64[token_num_output_cpu, 1] + infer_seed: int64[real_bsz] (updated in-place) + """ + real_bsz = len(top_p) + top_p_padding = np.zeros((token_num_output_cpu, 1), dtype=np.float32) + top_k_padding = np.zeros((token_num_output_cpu, 1), dtype=np.int64) + topp_seed = np.zeros((token_num_output_cpu, 1), dtype=np.int64) + infer_seed = infer_seed.copy() + + for bi in range(real_bsz): + cur_start = cu_seq_lens_q_output[bi] + cur_end = cu_seq_lens_q_output[bi + 1] + bi_top_p = top_p[bi] + bi_top_k = top_k[bi] + + for tid in range(BLOCK_DIM): + bi_infer_seed = (infer_seed[bi] + tid * 4) % MAX_INFER_SEED + i = tid + while i < cur_end - cur_start: + pad_idx = cur_start + i + top_p_padding[pad_idx, 0] = bi_top_p + top_k_padding[pad_idx, 0] = bi_top_k + topp_seed[pad_idx, 0] = bi_infer_seed + bi_infer_seed = (bi_infer_seed + BLOCK_DIM * 4) % MAX_INFER_SEED + i += BLOCK_DIM + + infer_seed[bi] = (infer_seed[bi] + increment_value) % MAX_INFER_SEED + + return top_p_padding, top_k_padding, topp_seed, infer_seed + + +def build_inputs(real_bsz, seq_lens_this_time_list, seq_lens_encoder_list, seed=42): + """ + Helper to build test inputs. + + For prefill requests (seq_lens_encoder > 0), the output length is 1. + For decode requests (seq_lens_encoder == 0), the output length equals seq_lens_this_time. + seq_lens_this_time == 0 means the slot is empty, output length is 0. + """ + rng = np.random.default_rng(seed) + + top_p = rng.uniform(0.0, 1.0, size=(real_bsz,)).astype(np.float32) + top_k = rng.integers(1, 100, size=(real_bsz,)).astype(np.int64) + infer_seed = rng.integers(0, MAX_INFER_SEED, size=(real_bsz,)).astype(np.int64) + + seq_lens_this_time = np.array(seq_lens_this_time_list, dtype=np.int32) + seq_lens_encoder = np.array(seq_lens_encoder_list, dtype=np.int32) + + seq_lens_output = np.zeros(real_bsz, dtype=np.int32) + for bid in range(real_bsz): + if seq_lens_this_time[bid] == 0: + seq_lens_output[bid] = 0 + elif seq_lens_encoder[bid] > 0: + seq_lens_output[bid] = 1 + else: + seq_lens_output[bid] = seq_lens_this_time[bid] + + cu_seq_lens_q_output = np.zeros(real_bsz + 1, dtype=np.int32) + for i in range(real_bsz): + cu_seq_lens_q_output[i + 1] = cu_seq_lens_q_output[i] + seq_lens_output[i] + + token_num_output_cpu = int(cu_seq_lens_q_output[-1]) + + return { + "top_p": top_p, + "top_k": top_k, + "infer_seed": infer_seed, + "seq_lens_this_time": seq_lens_this_time, + "cu_seq_lens_q_output": cu_seq_lens_q_output, + "token_num_output_cpu": token_num_output_cpu, + } + + +def run_and_compare(tc, inputs, increment_value): + """ + Call GPU op and Python reference, compare all outputs. + """ + t_top_p = paddle.to_tensor(inputs["top_p"], dtype="float32") + t_top_k = paddle.to_tensor(inputs["top_k"], dtype="int64") + t_infer_seed = paddle.to_tensor(inputs["infer_seed"], dtype="int64") + t_seq_lens_this_time = paddle.to_tensor(inputs["seq_lens_this_time"], dtype="int32") + t_cu_seq_lens_q_output = paddle.to_tensor(inputs["cu_seq_lens_q_output"], dtype="int32") + token_num_output_cpu = inputs["token_num_output_cpu"] + + gpu_outs = build_sampling_params( + t_top_p, + t_top_k, + t_infer_seed, + t_seq_lens_this_time, + t_cu_seq_lens_q_output, + token_num_output_cpu, + increment_value, + ) + + ref_outs = build_sampling_params_ref( + inputs["top_p"], + inputs["top_k"], + inputs["infer_seed"], + inputs["cu_seq_lens_q_output"], + token_num_output_cpu, + increment_value, + ) + + np.testing.assert_allclose(gpu_outs[0].numpy(), ref_outs[0], rtol=1e-6, err_msg="Mismatch in top_p_padding") + np.testing.assert_allclose(gpu_outs[1].numpy(), ref_outs[1], err_msg="Mismatch in top_k_padding") + np.testing.assert_allclose(gpu_outs[2].numpy(), ref_outs[2], err_msg="Mismatch in topp_seed") + np.testing.assert_allclose(t_infer_seed.numpy(), ref_outs[3], err_msg="Mismatch in infer_seed (in-place update)") + + +class TestBuildSamplingParams(unittest.TestCase): + """Unit tests for build_sampling_params custom operator.""" + + # ---------------------------------------------------------------- + # Test 1: exact golden values — mixed prefill and decode + # bid=0: decode, seq_lens_this_time=2 => output=2 + # bid=1: prefill, seq_lens_this_time=10 => output=1 + # ---------------------------------------------------------------- + def test_exact_golden_values(self): + top_p = np.array([0.9, 0.5], dtype=np.float32) + top_k = np.array([50, 10], dtype=np.int64) + infer_seed = np.array([100, 200], dtype=np.int64) + cu_seq_lens_q_output = np.array([0, 2, 3], dtype=np.int32) + seq_lens_this_time = np.array([2, 10], dtype=np.int32) + + t_top_p = paddle.to_tensor(top_p, dtype="float32") + t_top_k = paddle.to_tensor(top_k, dtype="int64") + t_infer_seed = paddle.to_tensor(infer_seed, dtype="int64") + t_seq_lens_this_time = paddle.to_tensor(seq_lens_this_time, dtype="int32") + t_cu_seq_lens_q_output = paddle.to_tensor(cu_seq_lens_q_output, dtype="int32") + + gpu_outs = build_sampling_params( + t_top_p, + t_top_k, + t_infer_seed, + t_seq_lens_this_time, + t_cu_seq_lens_q_output, + 3, + 1, + ) + + np.testing.assert_allclose(gpu_outs[0].numpy().flatten(), [0.9, 0.9, 0.5], rtol=1e-6) + np.testing.assert_allclose(gpu_outs[1].numpy().flatten(), [50, 50, 10]) + # topp_seed: bi=0 tid=0 => 100, bi=0 tid=1 => 104; bi=1 tid=0 => 200 + np.testing.assert_allclose(gpu_outs[2].numpy().flatten(), [100, 104, 200]) + np.testing.assert_allclose(t_infer_seed.numpy(), [101, 201]) + + # ---------------------------------------------------------------- + # Test 2: mixed prefill/decode batch with reference comparison + # bid=0: decode, seq_lens_this_time=3 => output=3 + # bid=1: prefill, seq_lens_this_time=50 => output=1 + # bid=2: decode, seq_lens_this_time=5 => output=5 + # bid=3: prefill, seq_lens_this_time=100 => output=1 + # bid=4: empty slot => output=0 + # ---------------------------------------------------------------- + def test_mixed_prefill_decode(self): + inputs = build_inputs( + real_bsz=5, + seq_lens_this_time_list=[3, 50, 5, 100, 0], + seq_lens_encoder_list=[0, 50, 0, 100, 0], + seed=300, + ) + self.assertEqual(inputs["token_num_output_cpu"], 10) + run_and_compare(self, inputs, increment_value=5) + + # ---------------------------------------------------------------- + # Test 3: random stress test with mixed prefill/decode configs + # ---------------------------------------------------------------- + def test_random_configs(self): + configs = [ + {"real_bsz": 8, "max_seq_len": 4, "increment_value": 1, "seed": 700}, + {"real_bsz": 32, "max_seq_len": 16, "increment_value": 16, "seed": 800}, + ] + for cfg in configs: + with self.subTest(**cfg): + rng = np.random.default_rng(cfg["seed"]) + real_bsz = cfg["real_bsz"] + max_seq_len = cfg["max_seq_len"] + seq_lens_this_time_list = rng.integers(0, max_seq_len + 1, size=real_bsz).tolist() + seq_lens_encoder_list = [] + for s in seq_lens_this_time_list: + if s > 0 and rng.random() < 0.3: + seq_lens_encoder_list.append(s) + else: + seq_lens_encoder_list.append(0) + + inputs = build_inputs( + real_bsz=real_bsz, + seq_lens_this_time_list=seq_lens_this_time_list, + seq_lens_encoder_list=seq_lens_encoder_list, + seed=cfg["seed"], + ) + if inputs["token_num_output_cpu"] == 0: + continue + run_and_compare(self, inputs, increment_value=cfg["increment_value"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/test_speculate_pre_process.py b/tests/operators/test_speculate_pre_process.py index c64f0fcf48..3b80044174 100644 --- a/tests/operators/test_speculate_pre_process.py +++ b/tests/operators/test_speculate_pre_process.py @@ -249,6 +249,16 @@ class TestSpeculatePreProcess(unittest.TestCase): 0, t_input_ids, t_seq_lens, t_draft_tokens, t_seq_lens_encoder, t_seq_lens_decoder ) self.assertEqual(len(gpu_outs), 7) + self.assertIsNotNone(gpu_outs[-3]) + self.assertIsNotNone(gpu_outs[-2]) + self.assertIsNotNone(gpu_outs[-1]) + # test copy + fake_cu_seqlens_q_output = paddle.empty([real_bsz + 1], dtype="int32") + fake_batch_id_per_token_output = paddle.empty([real_bsz], dtype="int32") + fake_cu_seqlens_q_output.copy_(gpu_outs[-3]) + fake_batch_id_per_token_output.copy_(gpu_outs[-2]) + # test slice + fake_batch_id_per_token_output[: gpu_outs[-1].item()] # ---------------------------------------------------------------- # Test 3: exact token values — manually verify ids_remove_padding