[Speculative Decoding] Refactor Eagle MTP hidden states copy (#6812)

* reformat eagle_get_hidden_states & eagle_get_self_hidden_states

* readability

* fix xpu bug

* fix coverage failure

* change launch params & parallelize position_map compute

* Fix MTP-related bugs in FastDeploy centralized inference

* fix

* refactor mtp hidden_states process

* fix

* add unittest & optimize kernel

* remove useless code

* fix
This commit is contained in:
huicongyao
2026-03-26 13:54:31 +08:00
committed by GitHub
parent 4fd877ed43
commit 25d64efdc4
12 changed files with 1309 additions and 383 deletions
@@ -1290,35 +1290,7 @@ class MTPSampler(nn.Layer):
elif self.logprobs_mode == "raw_logits":
raw_logprobs = share_inputs["draft_logits"][:real_token_num, :].clone()
if sampling_metadata.token_ids_all is not None:
token_ids_all = sampling_metadata.token_ids_all
prompt_lens = sampling_metadata.prompt_lens
else:
token_ids_all = sampling_metadata.pre_token_ids
prompt_lens = sampling_metadata.fake_prompt_lens
logits = apply_speculative_penalty_multi_scores(
token_ids_all,
prompt_lens,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.bad_words_token_len,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["batch_id_per_token_output"],
share_inputs["cu_seqlens_q_output"],
max_model_len,
sampling_metadata.pre_token_ids,
)
probs = F.softmax(logits)
next_tokens = paddle.argmax(probs, axis=-1)
next_tokens = paddle.argmax(logits, axis=-1)
token_ids = None
logprobs_tensors = None