Files
FastDeploy/fastdeploy/spec_decode/mtp.py
T
GoldPancake 650d1e49aa [Cherry-Pick][Speculative Decoding] Add MTP logprob support for PD disaggregation (#7442) (#7464)
* support mtp logprob in pd

* fix

* fix

* fix

* fix xpu bugs
2026-04-17 21:37:42 +08:00

1321 lines
62 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import time
from typing import TYPE_CHECKING, List
import numpy as np
import paddle
from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models import ModelForCasualLM
from fastdeploy.platforms import current_platform
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
draft_model_postprocess,
draft_model_preprocess,
draft_model_update,
eagle_get_hidden_states,
eagle_get_self_hidden_states,
mtp_save_first_token,
mtp_step_paddle,
set_data_ipc,
share_external_data,
update_attn_mask_offsets,
)
from fastdeploy.model_executor.xpu_pre_and_post_process import (
xpu_pre_process,
xpu_process_output,
)
else:
from fastdeploy.model_executor.ops.gpu import (
draft_model_postprocess,
draft_model_preprocess,
draft_model_update,
eagle_get_hidden_states,
eagle_get_self_hidden_states,
eagle_gather_hidden_states,
hybrid_mtp_ngram,
mtp_step_paddle,
share_external_data,
speculate_get_logits,
speculate_save_output_topk,
update_attn_mask_offsets,
set_data_ipc,
unset_data_ipc,
)
from fastdeploy.model_executor.pre_and_post_process import async_set_value, pre_process
from fastdeploy.worker.input_batch import (
ProposerInputBatch,
recover_batch_index_for_output,
recover_batch_index_for_sampler_output,
reorder_split_prefill_and_decode_form_index_to_batch_id,
)
from .base import Proposer
if TYPE_CHECKING:
from fastdeploy.config import FDConfig
class MTPProposer(Proposer):
"""
Proposer for Multi-Token-Prediction(MTP)
"""
def __init__(
    self,
    fd_config: "FDConfig",
    main_model: ModelForCasualLM,
    local_rank: int,
    device_id: int,  # physical device id
    target_model_inputs,  # main model share inputs
):
    """
    Build the Multi-Token-Prediction (MTP) proposer.

    Args:
        fd_config: Global FastDeploy configuration.
        main_model: The target (main) model; shared with the MTP layer via
            ``_update_mtp_config`` (e.g. for weight sharing).
        local_rank: Local rank of this worker process.
        device_id: Physical device id the proposer runs on.
        target_model_inputs: Shared input buffers of the target model.
    """
    super().__init__(fd_config)
    # Record the main model's depth before _update_mtp_config() overwrites
    # model_config.num_hidden_layers with the MTP layer count (1).
    self.num_main_model_layers = self.model_config.num_hidden_layers
    self.local_rank = local_rank
    self.device_id = device_id
    # Attention-mask offsets are only used when multimodal mode is enabled.
    self.use_attn_mask_offset = self.enable_mm
    self._update_mtp_config(main_model)
    self._load_model()
    self.target_model_inputs = target_model_inputs
    self.mtp_strategy = self.speculative_config.mtp_strategy
    # Hybrid mode: pad MTP drafts with ngram matches when the requested draft
    # length exceeds the number of MTP model steps.
    self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
    self.enable_logprob = self.model_config.enable_logprob
    self.enable_draft_logprob = self.speculative_config.enable_draft_logprob
    # name -> cache tensor; bookkeeping for IPC attach/detach (see
    # initialize_kv_cache / clear_mtp_cache).
    self.cache_kvs_map = {}
    # [mixed, prefill, decoder]
    self.role = self.scheduler_config.splitwise_role
    self.pd_disaggregation_mode = fd_config.parallel_config.pd_disaggregation_mode
    # Bind the platform-specific prepare/propose implementations once.
    if current_platform.is_xpu():
        self._prepare_inputs = self._prepare_inputs_xpu
        self._propose = self._propose_xpu
    elif current_platform.is_cuda() or current_platform.is_maca():
        self._prepare_inputs = self._prepare_inputs_cuda
        self._propose = self._propose_cuda
    else:
        raise RuntimeError(
            f"Unsupported platform for MTP: {current_platform}. " f"Supported platforms: CUDA, MACA, XPU"
        )
    self.sampler = MTPSampler(fd_config)
    self.model_inputs = ProposerInputBatch(self.fd_config, self.target_model_inputs)
    self.model_inputs.init_share_inputs()
    # CUDA Graph
    self.draft_model_use_cudagraph = self.graph_opt_config.draft_model_use_cudagraph
    # Capture sizes are visited in descending order (large -> small).
    self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
    self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
    self.attn_backends: list[AttentionBackend] = []
    self._initialize_attn_backend()
    # Forward meta store the global meta information of the forward
    self.forward_meta = None
    # True while the current batch contains at least one prefill task.
    self.exist_prefill_flag = False
def _update_mtp_config(self, main_model):
    """
    Update config for MTP from global config.

    Mutates ``self.model_config`` / ``self.speculative_config`` in place so
    that the loader builds a single MTP layer stacked after the main model's
    layers.
    """
    self.forward_meta: ForwardMeta = None
    # Map the architecture name to its MTP counterpart, e.g. "...Moe..." -> "...MTP...".
    self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
    # Let the MTP layer share modules/weights with the main model.
    self.speculative_config.sharing_model = main_model
    # TODO (wangyanpeng): The number of MTP layers should be read from model config
    self.model_config.num_hidden_layers = 1
    self.model_config.model = self.speculative_config.model
    if "Ernie" in self.model_config.architectures[0]:
        # Ernie checkpoints store the MTP weights under the "mtp_block" prefix.
        self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
        self.model_config.prefix_layer_name = "mtp_block"
    if self.speculative_config.quantization != "":
        self.model_config.quantization = self.speculative_config.quantization
    # MTP layers are indexed right after the main model's layers.
    self.model_config.start_layer_index = self.num_main_model_layers
    self.speculative_config.model_type = "mtp"
    if not self.use_attn_mask_offset:
        # Without multimodal mask offsets, plain causal attention applies.
        self.model_config.causal = True
def _load_model(self):
    """Instantiate the configured loader and load the MTP layer's weights."""
    loader = get_model_loader(load_config=self.fd_config.load_config)
    self.model = loader.load_model(fd_config=self.fd_config)
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
    """
    Set dummy prefill inputs to model_inputs (used for profiling / warmup).

    Args:
        num_tokens: Total token budget; split evenly across the batch.
        batch_size: Number of synthetic requests to populate.
        expected_decode_len: Decode steps expected after prefill; bounds
            max_dec_len and the usable prompt length.

    Fixes over the previous version: the batch-invariant ``eos_token_id``
    assignment is hoisted out of the per-request loop, and the redundant
    ``idx = i`` aliasing is dropped.
    """
    max_dec_len = expected_decode_len + 1
    # Leave room for decoding within the model's context window.
    input_length = min(
        num_tokens // batch_size,
        self.model_config.max_model_len - max_dec_len,
    )
    # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
    if self.fd_config.parallel_config.enable_expert_parallel:
        input_length = min(input_length, 32)
    # Ceil-divide prompt length into cache blocks, plus the enc/dec reserve.
    block_num = (
        input_length + self.cache_config.block_size - 1
    ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
    # Batch-invariant: set once, not per request.
    self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
    for idx in range(batch_size):
        # Dummy prompt: token id 5 repeated input_length times.
        self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
        self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = input_length
        self.model_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
        self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
        self.model_inputs["step_idx"][idx : idx + 1] = 0
        self.model_inputs["max_dec_len"][idx : idx + 1] = max_dec_len
        self.model_inputs["stop_flags"][idx : idx + 1] = False
        self.model_inputs["encoder_block_lens"][idx : idx + 1] = block_num
        # Consecutive, non-overlapping block ranges per request.
        self.model_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
            idx * block_num, (idx + 1) * block_num, 1
        )
    self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
    """
    Initialize kv cache.

    Depending on deployment mode, either attaches to cache tensors created
    by external cache managers (looked up by IPC name) or creates the
    tensors locally and publishes them over IPC.

    Args:
        main_model_num_blocks: Block count of the main model's KV cache; the
            MTP cache size is derived from it via num_gpu_block_expand_ratio.
        profile: True during memory profiling, which always creates a local
            cache instead of attaching to an external one.
    """
    self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio)
    self.cache_kvs = {}
    # Get kv cache dtype
    cache_type = self.model_config.dtype
    kv_cache_quant_type = None
    if (
        self.quant_config
        and hasattr(self.quant_config, "kv_cache_quant_type")
        and self.quant_config.kv_cache_quant_type is not None
    ):
        cache_type = self._get_cache_type()
        kv_cache_quant_type = self.quant_config.kv_cache_quant_type
    # Get kv cache shape
    key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape(
        max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
    )
    if kv_cache_quant_type == "block_wise_fp8":
        # Scale tensors take the cache shape's first three dims — presumably
        # dropping the head-dim axis; confirm against the attention backend.
        kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]
    local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
    cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
    # Attach (create=False) to the signal the cache managers flip when ready.
    cache_ready_signal = IPCSignal(
        name="cache_ready_signal",
        array=cache_ready_signal_data,
        dtype=np.int32,
        suffix=self.parallel_config.local_engine_worker_queue_port,
        create=False,
    )
    # Check if gpu runner needs to create kv cache
    # 1. During profiling, it creates its own kv cache.
    # 2. If no need to profile, create kv cache if cache managers do not exist.
    create_cache_tensor = profile or not (
        self.fd_config.cache_config.num_cpu_blocks > 0
        or self.fd_config.cache_config.kvcache_storage_backend
        or self.fd_config.scheduler_config.splitwise_role != "mixed"
    )
    if not create_cache_tensor:
        # Block until the external cache manager for this rank signals ready.
        logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
        while cache_ready_signal.value[local_rank] != 1:
            time.sleep(1)
        logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
    logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
    if not create_cache_tensor:
        # Attach branch: map in the externally created tensors by IPC name.
        cache_kvs_list = []
        # MTP layers live right after the main model's layers, so their
        # cache names continue the main model's layer indexing.
        for i in range(
            self.num_main_model_layers,
            self.num_main_model_layers + self.model_config.num_hidden_layers,
        ):
            logger.info(
                f"..attaching kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}"
            )
            # Placeholder tensor; replaced by the shared external buffer.
            key_cache = paddle.empty(shape=[], dtype=cache_type)
            key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
            key_cache = self._share_external_data(key_cache, key_cache_name, key_cache_shape)
            self.cache_kvs_map[key_cache_name] = key_cache
            cache_kvs_list.append(key_cache)
            value_cache = paddle.empty(shape=[], dtype=cache_type)
            value_cache = self._share_external_data(value_cache, val_cache_name, value_cache_shape)
            self.cache_kvs_map[val_cache_name] = value_cache
            cache_kvs_list.append(value_cache)
            if kv_cache_quant_type == "block_wise_fp8":
                # Also attach the per-block dequant scale tensors.
                scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
                scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
                key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
                key_scale_cache = self._share_external_data(
                    key_scale_cache, scale_key_cache_name, kv_cache_scale_shape
                )
                self.cache_kvs_map[scale_key_cache_name] = key_scale_cache
                cache_kvs_list.append(key_scale_cache)
                value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
                value_scale_cache = self._share_external_data(
                    value_scale_cache, scale_val_cache_name, kv_cache_scale_shape
                )
                self.cache_kvs_map[scale_val_cache_name] = value_scale_cache
                cache_kvs_list.append(value_scale_cache)
        self.model_inputs["caches"] = cache_kvs_list
    else:
        # Create branch: allocate zero-filled tensors and publish them via IPC.
        cache_kvs_list = []
        for i in range(
            self.num_main_model_layers,
            self.num_main_model_layers + self.model_config.num_hidden_layers,
        ):
            logger.info(f"..creating kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
            key_cache = paddle.full(
                shape=key_cache_shape,
                fill_value=0,
                dtype=cache_type,
            )
            key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
            set_data_ipc(key_cache, key_cache_name)
            self.cache_kvs_map[key_cache_name] = key_cache
            cache_kvs_list.append(key_cache)
            val_cache = paddle.full(
                shape=value_cache_shape,
                fill_value=0,
                dtype=cache_type,
            )
            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
            set_data_ipc(val_cache, val_cache_name)
            self.cache_kvs_map[val_cache_name] = val_cache
            cache_kvs_list.append(val_cache)
            if kv_cache_quant_type == "block_wise_fp8":
                # Also create + publish the per-block dequant scale tensors.
                key_cache_scales = paddle.full(
                    shape=kv_cache_scale_shape,
                    fill_value=0,
                    dtype=paddle.get_default_dtype(),
                )
                key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
                set_data_ipc(key_cache_scales, key_cache_scales_name)
                self.cache_kvs_map[key_cache_scales_name] = key_cache_scales
                cache_kvs_list.append(key_cache_scales)
                val_cache_scales = paddle.full(
                    shape=kv_cache_scale_shape,
                    fill_value=0,
                    dtype=paddle.get_default_dtype(),
                )
                val_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
                set_data_ipc(val_cache_scales, val_cache_scales_name)
                self.cache_kvs_map[val_cache_scales_name] = val_cache_scales
                cache_kvs_list.append(val_cache_scales)
        self.model_inputs["caches"] = cache_kvs_list
    # Release any allocator-held memory freed during cache construction.
    self._empty_cache()
def _initialize_attn_backend(
    self,
) -> None:
    """
    Initialize attention backends and forward metadata.

    Mirrors the target model's attention scheduling buffers with zeroed
    same-shaped copies so the draft model has independent scratch space,
    then constructs the platform attention backend.
    """
    assert len(self.attn_backends) == 0
    # Per-rank head counts under tensor parallelism.
    num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
    self.model_config.kv_num_heads = max(
        1,
        int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
    )
    head_dim = self.model_config.head_dim
    # Initialize AttentionBackend buffers
    encoder_block_shape_q = 64
    decoder_block_shape_q = 16
    self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"])
    self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
        self.target_model_inputs["decoder_tile_ids_per_batch"]
    )
    if current_platform.is_xpu() or current_platform.is_maca():
        # XPU/MACA have no pinned-memory support: keep a plain host copy.
        self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
            self.target_model_inputs["decoder_num_blocks_cpu"]
        ).cpu()
    else:
        # Pinned host memory enables fast async device<->host transfers.
        self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
            self.target_model_inputs["decoder_num_blocks_cpu"]
        ).pin_memory()
    self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like(
        self.target_model_inputs["decoder_num_blocks_device"]
    )
    self.model_inputs["decoder_chunk_size_device"] = paddle.zeros_like(
        self.target_model_inputs["decoder_chunk_size_device"]
    )
    self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
        self.target_model_inputs["max_len_tensor_cpu"]
    ).cpu()
    self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"])
    self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like(
        self.target_model_inputs["encoder_tile_ids_per_batch"]
    )
    self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like(
        self.target_model_inputs["encoder_num_blocks_x_cpu"]
    ).cpu()
    self.model_inputs["kv_batch_ids"] = paddle.zeros_like(self.target_model_inputs["kv_batch_ids"])
    self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like(
        self.target_model_inputs["kv_tile_ids_per_batch"]
    )
    self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like(
        self.target_model_inputs["kv_num_blocks_x_cpu"]
    ).cpu()
    # Get the attention backend
    attn_cls = get_attention_backend()
    attn_backend = attn_cls(
        self.fd_config,
        kv_num_heads=self.model_config.kv_num_heads,
        num_heads=num_heads,
        head_dim=head_dim,
        encoder_block_shape_q=encoder_block_shape_q,
        decoder_block_shape_q=decoder_block_shape_q,
    )
    if attn_backend is None:
        raise NotImplementedError(
            "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
        )
    self.attn_backends.append(attn_backend)
def clear_mtp_cache(self, profile=False):
    """
    Clear allocated cacheKV.

    Args:
        profile: Should match the value passed to initialize_kv_cache() so
            the create-vs-attach decision is recomputed identically.
    """
    # Recompute the same condition as initialize_kv_cache() to determine
    # whether the caches were created locally or attached via IPC.
    create_cache_tensor = profile or not (
        self.fd_config.cache_config.num_cpu_blocks > 0
        or self.fd_config.cache_config.kvcache_storage_backend
        or self.fd_config.scheduler_config.splitwise_role != "mixed"
    )
    if not create_cache_tensor:
        # Detach tensors that were shared in from the external cache manager.
        # NOTE(review): unset_data_ipc is only imported on the GPU branch at
        # the top of this file — presumably this path never runs on XPU;
        # confirm before relying on it there.
        for name, tensor in self.cache_kvs_map.items():
            unset_data_ipc(tensor, name, True, False)
    self.cache_kvs_map.clear()
    del self.model_inputs["caches"]
    if self.forward_meta is not None:
        # forward_meta holds a reference to the caches too; drop it so the
        # memory can actually be released.
        del self.forward_meta.caches
def update_mtp_block_num(self, num_gpu_blocks) -> None:
    """
    Re-derive the MTP KV-cache block pool from the final global block count.

    Rebuilds the kv cache against the global block number, then resets the
    free-block list handed out to new requests.
    """
    # Rebuild block table and kv cache with the global block num.
    self.main_model_num_gpu_blocks = num_gpu_blocks
    self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks)
    # Free blocks are handed out from the top of the pool downwards, stopping
    # above the region reserved by kv_cache_ratio.
    lower_bound = int(self.main_model_num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1
    free_list = list(range(self.num_gpu_blocks - 1, lower_bound, -1))
    self.free_list_len = len(free_list)
    self.model_inputs.update(
        {
            "free_list": paddle.to_tensor(free_list, dtype="int32"),
            "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
        }
    )
def insert_tasks_v1(
    self, req_dicts: List[Request], num_running_requests: int, target_model_index_to_batch_id: dict = None
):
    """
    Insert scheduled tasks into model_inputs (v1 KV-cache scheduler path).

    Args:
        req_dicts: Requests to insert; each carries a task_type of PREFILL,
            DECODE, or anything else (treated as stop/free-slot).
        num_running_requests: Number of requests currently running.
        target_model_index_to_batch_id: Optional mapping from target-model
            buffer index to batch id, copied into model_inputs when provided.
            Default is None — fixed from the previous mutable ``{}`` default
            (a shared-across-calls anti-pattern); ``None`` and ``{}`` are
            both falsy, so behavior is unchanged.
    """
    if "caches" not in self.model_inputs:
        self.initialize_kv_cache()
    req_len = len(req_dicts)
    self.model_inputs["num_running_requests"] = num_running_requests
    self.model_inputs["running_requests_ids"] = range(num_running_requests)
    if target_model_index_to_batch_id:
        # Keep our own copy so later caller-side mutations don't leak in.
        self.model_inputs.index_to_batch_id = dict(target_model_index_to_batch_id)
    for i in range(req_len):
        request = req_dicts[i]
        logger.debug(f"{i}th request-{request.request_id}: {request}")
        idx = self.model_inputs.get_index_by_batch_id(request.idx)
        if request.task_type.value == RequestType.PREFILL.value:  # prefill task
            prefill_start_index = request.prefill_start_index
            prefill_end_index = request.prefill_end_index
            length = prefill_end_index - prefill_start_index
            input_ids = request.prompt_token_ids + request.output_token_ids
            self.model_inputs["input_ids_len"][idx] = length - 1
            self.model_inputs["pre_ids"][idx : idx + 1] = -1
            # MTP consumes the target model's tokens shifted left by one.
            self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
                idx : idx + 1, 1:length
            ]
            self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                "input_ids"
            ][idx : idx + 1, 1:length].cpu()
            encoder_block_num = len(request.block_tables)
            self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
            self.model_inputs["block_tables"][idx : idx + 1, :] = -1
            self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                request.block_tables, dtype="int32"
            )
            self.model_inputs["stop_flags"][idx : idx + 1] = False
            self.model_inputs["batch_drop"][idx : idx + 1] = False
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
            self.exist_prefill_flag = True
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
            self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
            # step_idx advances only once the full prompt has been prefetched.
            self.model_inputs["step_idx"][idx : idx + 1] = (
                len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
            )
            if self.use_attn_mask_offset:
                inputs = request.multimodal_inputs
                self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = (
                    paddle.to_tensor(
                        inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], dtype="int32"
                    )
                )
                # GPU don't need it anymore
                # NOTE: XPU backend needs decoder attention mask offset; GPU backend does not use it
                if current_platform.is_xpu():
                    self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = (
                        inputs["attention_mask_offset"][prefill_end_index - 1] + 1
                    )
            if (
                self.fd_config.scheduler_config.splitwise_role == "decode"
            ):  # In PD, we continue to decode after P generates first token
                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                self.exist_prefill_flag = False
                self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
                self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1
                # NOTE(liuzichang):
                # extra 1 : P-D split need rollback one step
                self.model_inputs["mask_rollback"][idx : idx + 1] = 1
            # has_prefill_task = True
        elif request.task_type.value == RequestType.DECODE.value:  # decode task
            encoder_block_num = len(request.block_tables)
            self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
            self.model_inputs["block_tables"][idx : idx + 1, :] = -1
            if current_platform.is_cuda():
                # Async copy avoids forcing a host/device synchronization.
                async_set_value(
                    self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
                )
            else:
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )
            # if self.model_inputs["is_block_step"][idx]:  # has tasks to continue to decode
            #     has_decode_task = True
            #     continue
        else:
            # Neither prefill nor decode: mark the slot stopped and reset it.
            self.model_inputs["block_tables"][idx : idx + 1, :] = -1
            self.model_inputs["stop_flags"][idx : idx + 1] = True
            self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
            self.model_inputs["is_block_step"][idx : idx + 1] = False
            continue
    # TODO(liuzichang): Solve splitewise-p bug to restore
    # self.model_inputs["seq_lens_this_time"] = self.model_inputs["seq_lens_this_time_buffer"][:num_running_requests]
    self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
    """
    Process inputs for prefill tasks and insert it to model_inputs buffer.

    Args:
        req_dicts: Requests to insert. The last request's disaggregate_info
            decides the splitwise role of this node.
        num_running_requests: Number of requests currently running.
    """
    # TODO:Init role in initialize process
    if req_dicts[-1].disaggregate_info is not None:
        if req_dicts[-1].disaggregate_info["role"] == "prefill":
            self.role = "prefill"
            # A prefill-only node stops after producing the first token.
            os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
        elif req_dicts[-1].disaggregate_info["role"] == "decode":
            self.role = "decode"
    else:
        self.role = "mixed"
    req_len = len(req_dicts)
    for i in range(req_len):
        request = req_dicts[i]
        idx = request.idx
        length = len(request.prompt_token_ids)
        self.model_inputs.input_ids_len[idx] = length - 1
        if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
            # Decode node in PD disaggregation: the prefill node already
            # produced the first token, so start directly in decode state.
            length = len(request.prompt_token_ids)
            if length > 1:
                # MTP consumes the prompt shifted left by one token.
                self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                    "input_ids"
                ][idx : idx + 1, 1:length]
                self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
                    request.prompt_token_ids
                )[1:]
            self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
            prefill_token_num = self.max_draft_token_num + 1
            # Seed the draft buffer with the first token produced by P.
            self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
                request.draft_token_ids[1:2], dtype="int64"
            )
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = length
            self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = prefill_token_num
            self.model_inputs["stop_flags"][idx : idx + 1] = False
            self.model_inputs["batch_drop"][idx : idx + 1] = False
            # step 0 already happened on the prefill node.
            self.model_inputs["step_idx"][idx : idx + 1] = 1
            encoder_block_num = len(request.block_tables)
            self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
            self.model_inputs["block_tables"][idx : idx + 1, :] = -1
            self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                request.block_tables, dtype="int32"
            )
        else:
            # Regular prefill (mixed or prefill role).
            length = len(request.prompt_token_ids)
            if length > 1:
                self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                    "input_ids"
                ][idx : idx + 1, 1:length]
                self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
                    request.prompt_token_ids
                )[1:]
            self.model_inputs["pre_ids"][idx : idx + 1] = -1
            self.model_inputs["step_idx"][idx : idx + 1] = 0
            if self.cache_config.enable_chunked_prefill:
                # Only the first chunk is scheduled now.
                token_chunk_size = request.prefill_chunk_info[0]
                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = token_chunk_size
            else:
                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
                self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
            self.model_inputs["stop_flags"][idx : idx + 1] = False
            self.model_inputs["batch_drop"][idx : idx + 1] = False
            encoder_block_num = len(request.get("block_tables"))
            self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
            self.model_inputs["block_tables"][idx : idx + 1, :] = -1
            self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                request.get("block_tables"), dtype="int32"
            )
    self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
    """
    Initialize forward meta and attention meta data.

    Args:
        step_use_cudagraph: Whether the target model decided to use CUDA
            Graph for this step.
        is_dummy_run: True during warmup / graph-capture dummy runs.
        substep: Index of the current MTP substep within one proposal round.
    """
    # Initialize forward meta
    self.forward_meta = ForwardMeta(
        ids_remove_padding=self.model_inputs["ids_remove_padding"],
        rotary_embs=self.model_inputs["rope_emb"],
        attn_backend=self.attn_backends[0],
        decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
        decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
        decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
        decoder_num_blocks_device=self.model_inputs["decoder_num_blocks_device"],
        decoder_chunk_size_device=self.model_inputs["decoder_chunk_size_device"],
        max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
        seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
        seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
        seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
        batch_id_per_token=self.model_inputs["batch_id_per_token"],
        cu_seqlens_q=self.model_inputs["cu_seqlens_q"],
        cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
        block_tables=self.model_inputs["block_tables"],
        caches=self.model_inputs["caches"],
        encoder_batch_ids=self.model_inputs["encoder_batch_ids"],
        encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"],
        encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"],
        kv_batch_ids=self.model_inputs["kv_batch_ids"],
        kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"],
        kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"],
        # Mask offsets only exist in multimodal mode.
        attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None,
    )
    # Initialize attention meta data
    for attn_backend in self.attn_backends:
        attn_backend.init_attention_metadata(self.forward_meta)
    # Notes(liuzichang):
    # 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
    # 2. In multi-step execution, only the first step should be captured.
    self.forward_meta.step_use_cudagraph = (
        step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
    )
def _initialize_forward_meta_xpu(self):
    """
    Rebind per-step buffers on the existing forward_meta for the XPU path
    and (re)initialize the attention backend metadata.

    Bug fix: the previous implementation assigned every buffer as a
    one-element tuple via a stray trailing comma, e.g.
    ``self.forward_meta.decoder_batch_ids = (tensor,)`` — so downstream code
    received tuples instead of tensors. Compare ``_initialize_forward_meta``,
    which stores plain tensors. The trailing commas are removed.
    """
    self.forward_meta.decoder_batch_ids = self.model_inputs["decoder_batch_ids"]
    self.forward_meta.decoder_tile_ids_per_batch = self.model_inputs["decoder_tile_ids_per_batch"]
    self.forward_meta.decoder_num_blocks_cpu = self.model_inputs["decoder_num_blocks_cpu"]
    self.forward_meta.decoder_num_blocks_device = self.model_inputs["decoder_num_blocks_device"]
    self.forward_meta.decoder_chunk_size_device = self.model_inputs["decoder_chunk_size_device"]
    self.forward_meta.max_len_tensor_cpu = self.model_inputs["max_len_tensor_cpu"]
    self.forward_meta.encoder_batch_ids = self.model_inputs["encoder_batch_ids"]
    self.forward_meta.encoder_tile_ids_per_batch = self.model_inputs["encoder_tile_ids_per_batch"]
    self.forward_meta.encoder_num_blocks_x_cpu = self.model_inputs["encoder_num_blocks_x_cpu"]
    self.forward_meta.kv_batch_ids = self.model_inputs["kv_batch_ids"]
    self.forward_meta.kv_tile_ids_per_batch = self.model_inputs["kv_tile_ids_per_batch"]
    self.forward_meta.kv_num_blocks_x_cpu = self.model_inputs["kv_num_blocks_x_cpu"]
    self.forward_meta.attn_backend = self.attn_backends[0]
    if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
        # The draft model reuses the target model's KV signal sender in PD
        # disaggregation mode.
        self.forward_meta.kv_signal_sender = self.target_model_inputs["kv_signal_sender"]
    self.forward_meta.is_draft = True
    # Initialize attention meta data
    for attn_backend in self.attn_backends:
        attn_backend.init_attention_metadata(self.forward_meta)
def exist_prefill(self):
    """Report whether the current batch contains any prefill-stage request."""
    return bool(self.exist_prefill_flag)
def _prepare_inputs_cuda(self, full_hidden_states):
    """
    Prepare MTP inputs.
    MTP state (seq_lens_decoder, step_idx) is "shadow state":
    - Initialized from target model state each round
    - Used for MTP forward, but not committed until verify
    - No rollback needed since it's always re-initialized

    Args:
        full_hidden_states: Hidden states from the target model's forward
            pass for the whole (unpadded) token batch.
    """
    # Re-seed the draft model's shadow state from the target model's
    # verified state (accepted tokens, sequence lengths, stop flags).
    # NOTE: argument order is fixed by the custom op's signature.
    draft_model_preprocess(
        self.model_inputs["draft_tokens"],
        self.model_inputs["input_ids"],
        self.model_inputs["stop_flags"],
        self.model_inputs["seq_lens_this_time"],
        self.model_inputs["seq_lens_encoder"],
        self.model_inputs["seq_lens_decoder"],
        self.model_inputs["step_idx"],
        self.model_inputs["not_need_stop_device"],
        self.model_inputs["pre_ids"],
        self.target_model_inputs["accept_tokens"],
        self.target_model_inputs["accept_num"],
        self.target_model_inputs["seq_lens_encoder"],
        self.target_model_inputs["seq_lens_decoder"],
        self.target_model_inputs["step_idx"],
        self.target_model_inputs["stop_flags"],
        self.model_inputs["max_dec_len"],
        self.target_model_inputs["draft_tokens"],
        self.num_model_steps,
        self.role == "prefill",  # is_splitwise_prefill
    )
    # Select the target hidden states corresponding to accepted tokens.
    target_hidden_states, _ = eagle_get_hidden_states(
        full_hidden_states,
        self.model_inputs["seq_lens_this_time"],
        self.model_inputs["seq_lens_encoder"],
        self.model_inputs["seq_lens_decoder"],
        self.model_inputs["stop_flags"],
        self.target_model_inputs["accept_num"],
        self.target_model_inputs["seq_lens_this_time"],
        self.target_model_inputs["seq_lens_encoder"],
        self.num_model_steps,
    )
    # Non-blocking copy into the persistent input buffer.
    self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
def _prepare_inputs_xpu(self, full_hidden_states):
    """
    Prepare MTP inputs (XPU variant).

    Same role as ``_prepare_inputs_cuda`` but calls the XPU custom op, whose
    signature takes extra state (batch_drop, is_block_step, mask_rollback,
    recompute_token_num, and the v1 cache-scheduler flag).

    Args:
        full_hidden_states: Hidden states from the target model's forward
            pass for the whole (unpadded) token batch.
    """
    use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER)
    # NOTE: argument order is fixed by the XPU custom op's signature.
    draft_model_preprocess(
        self.model_inputs["draft_tokens"],
        self.model_inputs["input_ids"],
        self.model_inputs["stop_flags"],
        self.model_inputs["seq_lens_this_time"],
        self.model_inputs["seq_lens_encoder"],
        self.model_inputs["seq_lens_decoder"],
        self.model_inputs["step_idx"],
        self.model_inputs["not_need_stop"],
        self.model_inputs["batch_drop"],
        self.model_inputs["is_block_step"],
        self.model_inputs["pre_ids"],
        self.model_inputs["mask_rollback"],
        self.model_inputs["recompute_token_num"],
        self.target_model_inputs["accept_tokens"],
        self.target_model_inputs["accept_num"],
        self.target_model_inputs["seq_lens_this_time"],
        self.target_model_inputs["seq_lens_encoder"],
        self.target_model_inputs["seq_lens_decoder"],
        self.target_model_inputs["step_idx"],
        self.target_model_inputs["stop_flags"],
        self.target_model_inputs["is_block_step"],
        self.target_model_inputs["draft_tokens"],
        self.num_model_steps,
        True,
        self.role == "prefill",
        use_v1_cache_scheduler,
    )
    # XPU op returns the gathered tensor directly (no auxiliary output, unlike
    # the GPU variant which returns a tuple).
    target_hidden_states = eagle_get_hidden_states(
        full_hidden_states,
        self.model_inputs["seq_lens_this_time"],
        self.model_inputs["seq_lens_encoder"],
        self.model_inputs["seq_lens_decoder"],
        self.model_inputs["stop_flags"],
        self.target_model_inputs["accept_num"],
        self.target_model_inputs["seq_lens_this_time"],
        self.target_model_inputs["seq_lens_encoder"],
        self.num_model_steps,
    )
    # Non-blocking copy into the persistent input buffer.
    self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
def _post_process(self, sampled_token_ids):
    """
    PostProcess for generation.

    Commits the freshly sampled draft tokens into the MTP buffers and, on a
    prefill-role node (rank 0), saves the first generated token (XPU path
    only; the GPU path saves it in model_runner).

    Args:
        sampled_token_ids: Token ids sampled from the draft model's logits.
    """
    # NOTE: argument order is fixed by the custom op's signature.
    draft_model_update(
        sampled_token_ids,
        self.model_inputs["draft_tokens"],
        self.model_inputs["pre_ids"],
        self.model_inputs["seq_lens_this_time"],
        self.model_inputs["seq_lens_encoder"],
        self.model_inputs["seq_lens_decoder"],
        self.model_inputs["step_idx"],
        # Note(ZKK): consider renaming `output_cum_offsets` in the XPU
        # backend to match the GPU naming; see PR #6358.
        self.model_inputs["cu_seqlens_q_output"],
        self.model_inputs["stop_flags"],
        # CUDA keeps the stop flag on device; other platforms use host memory.
        (
            self.model_inputs["not_need_stop_device"]
            if current_platform.is_cuda()
            else self.model_inputs["not_need_stop"]
        ),
        self.model_inputs["max_dec_len"],
        self.model_inputs["eos_token_id"],
        self.model_inputs["base_model_draft_tokens"],
        self.max_model_len,
        self.model_inputs["substep"],
    )
    if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
        if current_platform.is_xpu():
            # Note(wangyanpeng): mtp_save_first_token for GPU platforms has been moved to model_runner.
            # Only XPU platform is retained here.
            skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
            # Undo any PD reordering so outputs line up with batch ids.
            recover_model_output_map = recover_batch_index_for_output(
                self.model_inputs,
                self.model_inputs.index_to_batch_id,
                self.model_inputs.enable_pd_reorder,
                ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"],
            )
            mtp_save_first_token(
                recover_model_output_map["base_model_draft_tokens"],
                self.model_inputs["not_need_stop"],
                recover_model_output_map["seq_lens_decoder"],
                recover_model_output_map["prompt_lens"],
                recover_model_output_map["step_idx"],
                self.local_rank,
                self.parallel_config.use_ep,
                skip_save,
            )
        # Ensure only save first token once.
        # (Zeroes step_idx for stopped requests so the first token isn't
        # saved again on the next round.)
        paddle.assign(
            paddle.where(
                self.model_inputs["stop_flags"],
                paddle.zeros_like(self.model_inputs["step_idx"]),
                self.model_inputs["step_idx"],
            ),
            self.model_inputs["step_idx"],
        )
    def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
        """
        Main process for MTP inference (CUDA path).

        Runs `num_model_steps` draft substeps. Each substep removes padding,
        rebuilds forward metadata, runs the draft model on the target model's
        hidden states, computes logits, samples draft tokens and writes them
        back via `_post_process`.

        Args:
            step_use_cudagraph: bool
                Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
            is_dummy_run: bool
                True for warm-up/capture runs; saving of logprob outputs is skipped.
            real_bsz: int
                Actual (unpadded) batch size of the current step.
        """
        # When overlap scheduling is disabled, or this is a dummy/prefill/empty
        # step, sync seq_lens_this_time to host for an exact token count.
        is_blocking = (
            (not self.fd_config.scheduler_config.enable_overlap_schedule)
            or is_dummy_run
            or self.exist_prefill()
            or real_bsz == 0
        )
        for substep in range(self.num_model_steps):
            if is_blocking:
                # Exact host-side token count (device -> host sync).
                token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
            else:
                # Overlap schedule: token count is analytic. Substep 0 carries
                # up to max_draft_token_num + 1 tokens per slot, later substeps
                # one token per slot.
                if substep == 0:
                    token_num_cpu = real_bsz * (self.max_draft_token_num + 1)
                else:
                    token_num_cpu = real_bsz
            if token_num_cpu > 0:
                self.model_inputs["substep"] = substep
                # Remove padding
                (
                    ids_remove_padding,
                    batch_id_per_token,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    cu_seqlens_q_output,
                    batch_id_per_token_output,
                    real_output_token_num,
                ) = pre_process(
                    token_num_cpu,
                    self.model_inputs["input_ids"],
                    self.model_inputs["seq_lens_this_time"],
                    True,
                    self.model_inputs["draft_tokens"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["seq_lens_decoder"],
                )
                # Refresh attention-mask offsets when the backend requires them.
                if self.use_attn_mask_offset:
                    attn_mask_offsets = update_attn_mask_offsets(
                        ids_remove_padding,
                        getattr(
                            self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"]
                        ),
                        self.model_inputs["seq_lens_encoder"],
                        self.model_inputs["seq_lens_decoder"],
                        cu_seqlens_q,
                        self.model_inputs["attn_mask_offsets_full"],
                        self.model_inputs["is_block_step"],
                        self.model_inputs["decode_states"],
                    )
                    self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)
                # Copy the de-padded inputs into the persistent input buffers.
                self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
                self.model_inputs["batch_id_per_token"][:] = -1
                self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
                self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
                # For speculative decoding
                self.model_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
                self.model_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)
                # Initialize forward meta data
                self._initialize_forward_meta(
                    step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
                )
                self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
                self.forward_meta.real_bsz = real_bsz
                # Padding inputs for cuda graph
                self.padding_cudagraph_inputs()
                # Get sampling metadata
                self.sampling_metadata = SamplingMetadata(
                    temperature=self.model_inputs["temperature"],
                    top_p=self.model_inputs["top_p"],
                    top_k=self.model_inputs["top_k"],
                    seed=self.model_inputs["infer_seed"],
                    step_idx=self.model_inputs["step_idx"],
                    token_ids_all=self.model_inputs["token_ids_all"],
                    pre_token_ids=self.model_inputs["pre_ids"],
                    prompt_lens=self.model_inputs["prompt_lens"],
                    fake_prompt_lens=self.model_inputs["fake_prompt_lens"],
                    frequency_penalties=self.model_inputs["frequency_score"],
                    presence_penalties=self.model_inputs["presence_score"],
                    repetition_penalties=self.model_inputs["penalty_score"],
                    min_dec_lens=self.model_inputs["min_dec_len"],
                    bad_words_token_ids=self.model_inputs["bad_tokens"],
                    bad_words_token_len=self.model_inputs["bad_tokens_len"],
                    eos_token_ids=self.model_inputs["eos_token_id"],
                    max_num_logprobs=20 if self.enable_logprob else None,
                    temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
                    top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                    share_inputs=self.model_inputs,
                )
                # Feed only the target hidden states of the real (de-padded) tokens.
                real_num = self.model_inputs["ids_remove_padding"].shape[0]
                target_hidden_states = self.model_inputs["target_hidden_states"][:real_num]
                model_output = self.model(
                    ids_remove_padding=self.model_inputs["ids_remove_padding"],
                    previous_hidden_states=target_hidden_states,
                    forward_meta=self.forward_meta,
                )
                # Drop the CUDA-graph padding appended by padding_cudagraph_inputs.
                if self.forward_meta.step_use_cudagraph:
                    model_output = model_output[: self.real_token_num]
                # Gather only the hidden states that produce output tokens.
                hidden_states, _ = eagle_gather_hidden_states(
                    model_output,
                    self.model_inputs["cu_seqlens_q"],
                    self.model_inputs["seq_lens_this_time"],
                    self.model_inputs["seq_lens_decoder"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["batch_id_per_token_output"],
                    self.model_inputs["cu_seqlens_q_output"],
                    real_output_token_num,
                )
                # 4. Compute logits, Sample
                logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
                # Draft-token logprobs are only collected on the first substep.
                if self.enable_logprob and self.enable_draft_logprob and substep == 0:
                    first_token_logits = self.model.compute_logits(
                        self.model_inputs["first_token_hidden_states"], forward_meta=self.forward_meta
                    )
                    speculate_get_logits(
                        self.model_inputs["draft_logits"],
                        self.model_inputs["next_token_num"],
                        self.model_inputs["batch_token_num"],
                        self.model_inputs["cu_next_token_offset"],
                        self.model_inputs["cu_batch_token_offset"],
                        logits,
                        first_token_logits,
                        self.model_inputs["seq_lens_this_time"],
                        self.model_inputs["seq_lens_encoder"],
                    )
                sampled_token_ids, sampler_output = self.sampler(
                    logits,
                    self.sampling_metadata,
                    self.max_model_len,
                    self.model_inputs,
                )
                # Save top-k logprob outputs on TP rank 0, real runs only.
                if (
                    not is_dummy_run
                    and self.parallel_config.tensor_parallel_rank == 0
                    and substep == 0
                    and sampler_output.logprobs_tensors is not None
                ):
                    real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                    # Restore original batch order before writing outputs.
                    recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id)
                    recover_model_output_map = recover_batch_index_for_output(
                        self.model_inputs,
                        self.model_inputs.index_to_batch_id,
                        self.model_inputs.enable_pd_reorder,
                        ["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"],
                    )
                    speculate_save_output_topk(
                        sampler_output.sampled_token_ids,
                        sampler_output.logprobs_tensors.logprob_token_ids,
                        sampler_output.logprobs_tensors.logprobs,
                        sampler_output.logprobs_tensors.selected_token_ranks,
                        recover_model_output_map["batch_token_num"][:real_bsz],
                        recover_model_output_map["cu_batch_token_offset"][:real_bsz],
                        self.model_inputs["not_need_stop"],
                        recover_model_output_map["seq_lens_decoder"],
                        recover_model_output_map["prompt_lens"],
                        4,  # mtype
                        self.local_rank,
                        self.parallel_config.use_ep,
                    )
                # Sync sampled tokens across TP ranks from this DP group's rank 0.
                if self.parallel_config.tensor_parallel_size > 1:
                    paddle.distributed.broadcast(
                        sampled_token_ids,
                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
                        group=self.parallel_config.tp_group,
                    )
                self._post_process(sampled_token_ids)
                # Draft model's own hidden states feed the next substep.
                self.model_inputs["target_hidden_states"].copy_(hidden_states, False)
            else:
                # No tokens this substep on this rank; run the empty forward
                # (presumably to keep distributed collectives aligned — confirm).
                if hasattr(self.model, "empty_input_forward") and not is_dummy_run:
                    self.model.empty_input_forward(forward_meta=self.forward_meta)
        self.exist_prefill_flag = False
    def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
        """
        Main process for MTP inference (XPU path).

        Args:
            step_use_cudagraph: bool
                Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
                Kept for signature parity with the CUDA path; not read here.
            is_dummy_run: bool
                True for warm-up runs; empty_input_forward is skipped.
            real_bsz: int
                Kept for signature parity with the CUDA path; recomputed
                locally before use.
        """
        for substep in range(self.num_model_steps):
            if self.model_inputs["not_need_stop"]:
                self.model_inputs["substep"] = substep
                # Remove padding
                self.forward_meta = xpu_pre_process(
                    self.model_inputs["input_ids"],
                    self.model_inputs["seq_lens_this_time"],
                    self.model_inputs,
                    True,
                    self.cache_config.block_size,
                    self.model_inputs["draft_tokens"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["seq_lens_decoder"],
                )
                # Multimodal models need attention-mask offsets refreshed.
                if self.enable_mm:
                    attn_mask_offsets = update_attn_mask_offsets(
                        self.model_inputs["ids_remove_padding"],
                        getattr(
                            self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"]
                        ),
                        self.model_inputs["seq_lens_encoder"],
                        self.model_inputs["seq_lens_decoder"],
                        self.model_inputs["cu_seqlens_q"],
                        self.model_inputs["attn_mask_offsets_full"],
                        self.model_inputs["attn_mask_offsets_decoder"],
                        self.model_inputs["is_block_step"],
                        self.model_inputs["decode_states"],
                        self.model_inputs["mask_rollback"],
                    )
                    self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)
                self._initialize_forward_meta_xpu()
                # Get sampling metadata
                self.sampling_metadata = SamplingMetadata(
                    temperature=self.model_inputs["temperature"],
                    top_p=self.model_inputs["top_p"],
                    top_k=self.model_inputs["top_k"],
                    seed=self.model_inputs["infer_seed"],
                    step_idx=self.model_inputs["step_idx"],
                    pre_token_ids=self.model_inputs["pre_ids"],
                    frequency_penalties=self.model_inputs["frequency_score"],
                    presence_penalties=self.model_inputs["presence_score"],
                    repetition_penalties=self.model_inputs["penalty_score"],
                    min_dec_lens=self.model_inputs["min_dec_len"],
                    bad_words_token_ids=self.model_inputs["bad_tokens"],
                    eos_token_ids=self.model_inputs["eos_token_id"],
                    max_num_logprobs=20 if self.enable_logprob else None,
                    temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
                    top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                    share_inputs=self.model_inputs,
                )
                # Snapshot seq lens: needed by eagle_get_self_hidden_states
                # between substeps when running multiple model steps.
                if self.num_model_steps > 1:
                    self.model_inputs.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
                model_output = self.model(
                    ids_remove_padding=self.model_inputs["ids_remove_padding"],
                    previous_hidden_states=self.model_inputs["target_hidden_states"],
                    forward_meta=self.forward_meta,
                )
                hidden_states = xpu_process_output(model_output, self.forward_meta, self.model_inputs)
                # 4. Compute logits, Sample
                logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
                sampled_token_ids, sampler_output = self.sampler(
                    logits,
                    self.sampling_metadata,
                    self.max_model_len,
                    self.model_inputs,
                )
                # Save top-k logprob outputs on the first substep only.
                if substep == 0 and sampler_output.logprobs_tensors is not None:
                    real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                    # Restore original batch order before writing outputs.
                    recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id)
                    recover_model_output_map = recover_batch_index_for_output(
                        self.model_inputs,
                        self.model_inputs.index_to_batch_id,
                        self.model_inputs.enable_pd_reorder,
                        ["batch_token_num", "cu_batch_token_offset"],
                    )
                    speculate_save_output_topk(
                        sampler_output.sampled_token_ids,
                        sampler_output.logprobs_tensors.logprob_token_ids,
                        sampler_output.logprobs_tensors.logprobs,
                        sampler_output.logprobs_tensors.selected_token_ranks,
                        recover_model_output_map["batch_token_num"][:real_bsz],
                        recover_model_output_map["cu_batch_token_offset"][:real_bsz],
                        self.model_inputs["not_need_stop"],
                        4,  # mtype
                        self.local_rank,
                    )
                # Sync sampled tokens across TP ranks from this DP group's rank 0.
                if self.parallel_config.tensor_parallel_size > 1:
                    paddle.distributed.broadcast(
                        sampled_token_ids,
                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
                        group=self.parallel_config.tp_group,
                    )
                self._post_process(sampled_token_ids)
                # All substeps except the last reuse the draft model's own
                # hidden states as the next substep's target hidden states.
                if substep != self.num_model_steps - 1:
                    self._get_self_hidden_states_xpu(hidden_states)
            else:
                # Nothing to generate; run the empty forward (presumably to
                # keep distributed collectives aligned — confirm).
                if hasattr(self.model, "empty_input_forward") and not is_dummy_run:
                    self.model.empty_input_forward(self.forward_meta)
def _get_self_hidden_states_xpu(self, hidden_states):
target_hidden_states = eagle_get_self_hidden_states(
hidden_states,
self.model_inputs.last_seq_lens_this_time,
self.model_inputs["seq_lens_this_time"],
self.model_inputs["step_idx"],
)
self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
def update_task_chunk_prefill(self, task):
"""
Update single task's chunk_prefill info
"""
idx = self.model_inputs.get_index_by_batch_id(task.idx)
start_idx = sum(task.prefill_chunk_info[: task.chunk_idx])
if task.chunk_idx == len(task.prefill_chunk_info):
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.model_inputs["step_idx"][idx : idx + 1] = 1
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
else:
token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
if task.chunk_idx < len(task.prefill_chunk_info) - 1:
self.model_inputs["input_ids"][idx, :token_chunk_size] = np.array(
task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size + 1]
)
# Last prefill
else:
self.model_inputs["input_ids"][idx, : token_chunk_size - 1] = np.array(
task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size]
)
self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
self.model_inputs["step_idx"][idx : idx + 1] = 0
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
    def _update_status(self):
        """
        Update main-model's forward info in next step.
        Allocate/Free KV-cache blocks of MTP (legacy scheduler only).
        """
        # Fold the accepted draft tokens back into the target model's buffers.
        draft_model_postprocess(
            self.target_model_inputs["draft_tokens"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.target_model_inputs["stop_flags"],
        )
        # With the v1 KV-cache scheduler, block management is handled
        # elsewhere; only the legacy path manages MTP blocks here.
        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            mtp_step_paddle(
                self.target_model_inputs["stop_flags"],
                self.model_inputs["stop_flags"],
                self.model_inputs["batch_drop"],
                self.model_inputs["seq_lens_this_time"],
                self.model_inputs["seq_lens_encoder"],
                self.model_inputs["seq_lens_decoder"],
                self.model_inputs["block_tables"],
                self.model_inputs["encoder_block_lens"],
                self.model_inputs["used_list_len"],
                self.model_inputs["free_list"],
                self.model_inputs["free_list_len"],
                self.cache_config.block_size,
                self.max_draft_token_num,
            )
    def _extend_draft_token_with_ngram_match(self):
        """
        Hybrid speculation: extend the MTP draft tokens by matching n-grams of
        the generated prefix against the prompt on CPU, then copy the extended
        drafts and lengths back to the target model's device buffers.
        """
        # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce latency
        device = paddle.CUDAPinnedPlace()
        # Bring the tensors the CPU kernel mutates to host memory.
        draft_tokens = self.target_model_inputs["draft_tokens"].cpu()
        seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
        seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
        hybrid_mtp_ngram(
            self.model_inputs["input_ids_cpu"],
            self.model_inputs["input_ids_len"],
            self.model_inputs["pre_ids"]._copy_to(device, True),
            self.model_inputs["step_idx"].cpu(),
            self.target_model_inputs["actual_draft_token_num"].cpu(),
            draft_tokens,
            seq_lens_this_time,
            seq_lens_decoder,
            self.model_inputs["max_dec_len"].cpu(),
            self.max_ngram_size,
            self.min_ngram_size,
            self.max_draft_token_num,
        )
        # Push the extended drafts and updated lengths back to device.
        self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
        self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
    def _run_impl(
        self,
        full_hidden_states: paddle.Tensor,
        step_use_cudagraph: bool = False,
        is_dummy_run: bool = False,
        real_bsz: int = 0,
    ):
        """Execute Draft Model: prepare inputs from the target model's hidden
        states, run the propose loop, then sync status back to the target model.

        Args:
            full_hidden_states: Hidden states produced by the target model.
            step_use_cudagraph: Whether this step replays the CUDA graph.
            is_dummy_run: True for warm-up/capture runs.
            real_bsz: Actual (unpadded) batch size.
        """
        self._prepare_inputs(full_hidden_states)
        self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, real_bsz=real_bsz)
        self._update_status()
        # Hybrid mode additionally extends drafts via n-gram matching.
        if self.hybrid_mode:
            self._extend_draft_token_with_ngram_match()
    def is_chunk_prefill_enabled(self):
        """Whether the MTP proposer supports chunked prefill (always True)."""
        return True
def padding_cudagraph_inputs(self) -> None:
"""
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
"""
# In init_attention_metadata, the decode buffer has already been cleared
# To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
if self.forward_meta.step_use_cudagraph:
self.forward_meta.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
return
def _empty_cache(self):
if current_platform.is_cuda():
paddle.device.cuda.empty_cache()
elif current_platform.is_xpu():
paddle.device.xpu.empty_cache()
else:
paddle.device.empty_cache()
def _get_cache_type(self):
cache_type = None
if current_platform.is_cuda():
cache_type = "uint8"
elif current_platform.is_xpu():
cache_type = "int8"
else:
raise NotImplementedError
return cache_type
    def reorder_inputs(self, target_model_input_batch):
        """
        Reorder inputs to split prefill and decode.

        Args:
            target_model_input_batch: batch layout from the target model —
                presumably used to mirror its prefill/decode ordering into the
                MTP inputs; verify against the helper's definition.
        """
        reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs, target_model_input_batch)
def _share_external_data(self, cache, cache_name, cache_shape):
if current_platform.is_xpu():
return share_external_data(cache, cache_name, cache_shape, False)
else:
return share_external_data(cache, cache_name, cache_shape)