[Cherry-Pick][CI] Support different infer_seed in speculate decoding (#5568) (#5597)

* Fix MTP entropy drop in RL

* Optimize usage and fix unit test

* Optimize padding_sampling_params speed (vectorized)
Author: freeliuzc
Date: 2025-12-17 16:53:47 +08:00
Committed by: GitHub
Parent: c19af496cb
Commit: a7359d1c1d
3 changed files with 99 additions and 14 deletions
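The first commit bullet, "fix MTP entropy drop in RL", refers to draft steps losing randomness when every speculative token reuses the same seed. A toy illustration in plain Python (not the FastDeploy sampler; the 4-per-step offset mirrors the test expectations below):

    import random

    # One shared seed: every draft step redraws the same value, so the
    # sampled tokens repeat and output entropy collapses.
    shared = [random.Random(100).random() for _ in range(3)]
    assert shared[0] == shared[1] == shared[2]

    # Per-step offset seeds (100 + 4*k, as the new tests expect): each
    # draft step draws independently.
    offset = [random.Random(100 + 4 * k).random() for k in range(3)]
    assert len(set(offset)) == 3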
@@ -30,6 +30,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import (
     MTPSampler,
     SpeculativeSampler,
+    padding_sampling_params,
 )
@@ -72,7 +73,7 @@ def _create_default_sampling_metadata(
         bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
         eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
         min_p=paddle.randn([batch_size]),
-        seed=paddle.to_tensor([[2025]]),
+        seed=paddle.full(shape=[batch_size], fill_value=0, dtype="int64"),
     )
     if max_num_logprobs is not None:
         fake_sampling_metadata.max_num_logprobs = max_num_logprobs
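The one-line change above switches the test metadata from a single scalar seed shared by the whole batch to one seed slot per request, matching the new per-request infer_seed support. The two shapes side by side (a restatement of the diff, not new API):

    import paddle

    batch_size = 4
    old_seed = paddle.to_tensor([[2025]])  # shape [1, 1]: one seed for every request
    new_seed = paddle.full(shape=[batch_size], fill_value=0, dtype="int64")  # shape [4]: per-request seeds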
@@ -143,6 +144,19 @@ def _create_share_inputs(max_num_seqs, max_draft_token_num, max_model_len, vocab
     return share_inputs


+def _create_padding_inputs():
+    # batch_size = 4; slot layout: decoder, encoder, empty, decoder
+    top_p = paddle.to_tensor([[0.9], [0.8], [0.7], [1.0]], dtype="float32")
+    top_k = paddle.to_tensor([[10], [20], [30], [40]], dtype="int32")
+    infer_seed = paddle.to_tensor([[100], [200], [300], [400]], dtype="int64")
+    # slot 2 has seq_lens_this_time == 0, so it is skipped during padding
+    seq_lens_encoder = paddle.to_tensor([[0], [5], [0], [0]], dtype="int32")
+    seq_lens_this_time = paddle.to_tensor([[3], [2], [0], [2]], dtype="int32")
+    return top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder
+
+
 def test_speculative_sampler():
     batch_size = 32
     vocab_size = 1024
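A minimal loop-based sketch of the behavior the new tests pin down, inferred from the fixture above and the assertions below. The name padding_sampling_params is real, but this body is illustrative rather than the FastDeploy implementation, and reading the offset step as batch_size (4 here) is an assumption taken from the "100 + 4*k" comments:

    import paddle

    def padding_sampling_params_reference(top_p, top_k, infer_seed,
                                          seq_lens_this_time, seq_lens_encoder):
        bs = top_p.shape[0]
        p_rows, k_rows, s_rows = [], [], []
        for i in range(bs):
            n = int(seq_lens_this_time[i])
            if n == 0:
                continue  # empty slot: contributes no output rows
            if int(seq_lens_encoder[i]) > 0:
                # encoder (prefill) slot: a single row, seed unchanged
                p_rows.append(top_p[i])
                k_rows.append(top_k[i])
                s_rows.append(infer_seed[i])
            else:
                # decoder slot: one row per draft token, seed offset per step
                for k in range(n):
                    p_rows.append(top_p[i])
                    k_rows.append(top_k[i])
                    s_rows.append(infer_seed[i] + bs * k)  # assumed offset step == batch size
        return paddle.stack(p_rows), paddle.stack(k_rows), paddle.stack(s_rows)

On the fixture above this returns the three [6, 1] tensors the tests assert against.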
@@ -220,8 +234,52 @@ def test_mtp_sampler_logprobs():
     sampler(logits, sampling_metadata, max_model_len, share_inputs)


+def test_padding_sampling_params_basic():
+    top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder = _create_padding_inputs()
+    top_p_pad, top_k_pad, seed_pad = padding_sampling_params(
+        top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder
+    )
+    # decoder(3 tokens) + encoder(1) + decoder(2 tokens) = 6 output rows
+    assert top_p_pad.shape == [6, 1]
+    assert top_k_pad.shape == [6, 1]
+    assert seed_pad.shape == [6, 1]
+    # top_p padding check: each slot's value repeats once per output row
+    expected_top_p = [0.9, 0.9, 0.9, 0.8, 1.0, 1.0]
+    assert paddle.allclose(top_p_pad.squeeze(), paddle.to_tensor(expected_top_p, dtype="float32"))
+    # top_k padding check
+    expected_top_k = [10, 10, 10, 20, 40, 40]
+    assert paddle.equal_all(top_k_pad.squeeze(), paddle.to_tensor(expected_top_k, dtype="int32"))
+
+
+def test_padding_sampling_params_seed_offset():
+    top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder = _create_padding_inputs()
+    _, _, seed_pad = padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder)
+    # slot 0, decoder (3 tokens): seeds 100 + 4*k
+    # slot 1, encoder: seed 200 (no offset)
+    # slot 2, empty: skipped
+    # slot 3, decoder (2 tokens): seeds 400 + 4*k
+    expected_seed = [
+        100,
+        104,
+        108,  # first decoder seq (len=3)
+        200,  # encoder
+        400,
+        404,  # second decoder seq (len=2)
+    ]
+    assert paddle.equal_all(seed_pad.squeeze(), paddle.to_tensor(expected_seed, dtype="int64"))
+
+
 if __name__ == "__main__":
     test_speculative_sampler()
     test_speculative_sampler_logprobs()
     test_mtp_sampler()
     test_mtp_sampler_logprobs()
+    test_padding_sampling_params_basic()
+    test_padding_sampling_params_seed_offset()
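The third commit bullet says padding_sampling_params was vectorized for speed. A hedged sketch of what that can look like in Paddle, equivalent on these tests to the loop reference earlier (illustrative only, with the same batch-size offset assumption; the real kernel may differ):

    import paddle

    def padding_sampling_params_vectorized(top_p, top_k, infer_seed,
                                           seq_lens_this_time, seq_lens_encoder):
        bs = top_p.shape[0]
        lens = seq_lens_this_time.reshape([-1])
        is_encoder = seq_lens_encoder.reshape([-1]) > 0
        # encoders emit one row, decoders one row per token, empty slots none
        repeats = paddle.where(is_encoder, paddle.ones_like(lens), lens)
        keep = repeats > 0
        repeats = repeats[keep].astype("int64")
        top_p_pad = paddle.repeat_interleave(top_p[keep], repeats, axis=0)
        top_k_pad = paddle.repeat_interleave(top_k[keep], repeats, axis=0)
        seed_pad = paddle.repeat_interleave(infer_seed[keep], repeats, axis=0)
        # position k of each output row inside its slot: 0..r-1 per block
        row = paddle.arange(int(repeats.sum()), dtype="int64")
        starts = paddle.cumsum(repeats) - repeats
        k_in_slot = row - paddle.repeat_interleave(starts, repeats)
        # decoders are offset by bs * k; encoders keep their seed unchanged
        enc_int = is_encoder.astype("int64")
        enc_rows = paddle.repeat_interleave(enc_int[keep], repeats)
        seed_pad = seed_pad + ((1 - enc_rows) * bs * k_in_slot).unsqueeze(-1)
        return top_p_pad, top_k_pad, seed_pad

Replacing the per-slot Python loop with repeat_interleave plus a cumulative-sum index moves the padding work into a handful of tensor ops, which is the usual shape of the "(vectorized)" speedup the commit message mentions.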