mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 09:44:10 +08:00
[Speculative Decoding] Refactor Eagle MTP hidden states copy (#6812)
* reformat eagle_get_hidden_states & eagle_get_self_hidden_states * readability * fix xpu bug * fix coverage failure * change launch params & parallelize position_map compute * Fix MTP-related bugs in FastDeploy centralized inference * fix * refactor mtp hidden_states process * fix * add unittest & optimize kernel * remove useless code * fix
This commit is contained in:
@@ -1290,35 +1290,7 @@ class MTPSampler(nn.Layer):
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
raw_logprobs = share_inputs["draft_logits"][:real_token_num, :].clone()
|
||||
|
||||
if sampling_metadata.token_ids_all is not None:
|
||||
token_ids_all = sampling_metadata.token_ids_all
|
||||
prompt_lens = sampling_metadata.prompt_lens
|
||||
else:
|
||||
token_ids_all = sampling_metadata.pre_token_ids
|
||||
prompt_lens = sampling_metadata.fake_prompt_lens
|
||||
|
||||
logits = apply_speculative_penalty_multi_scores(
|
||||
token_ids_all,
|
||||
prompt_lens,
|
||||
logits,
|
||||
sampling_metadata.repetition_penalties,
|
||||
sampling_metadata.frequency_penalties,
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.bad_words_token_len,
|
||||
sampling_metadata.step_idx,
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["batch_id_per_token_output"],
|
||||
share_inputs["cu_seqlens_q_output"],
|
||||
max_model_len,
|
||||
sampling_metadata.pre_token_ids,
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
next_tokens = paddle.argmax(probs, axis=-1)
|
||||
next_tokens = paddle.argmax(logits, axis=-1)
|
||||
|
||||
token_ids = None
|
||||
logprobs_tensors = None
|
||||
|
||||
Reference in New Issue
Block a user