mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support mtp overlap schedule (#7001)
This commit is contained in:
@@ -137,6 +137,7 @@ __global__ void apply_token_enforce_generation_scores_kernel(
|
||||
int tid = threadIdx.x;
|
||||
|
||||
const int bs_idx = batch_id_per_token_output[token_idx];
|
||||
if (bs_idx < 0) return;
|
||||
const int query_start_token_idx = cu_seqlens_q_output[bs_idx];
|
||||
bool is_batch_first_token = (token_idx == query_start_token_idx);
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
|
||||
i += gridDim.x * blockDim.x * VecSize) {
|
||||
const int out_token_id = i / dim_embed;
|
||||
const int bi = batch_id_per_token_output[out_token_id];
|
||||
if (bi < 0) continue;
|
||||
if (seq_len_this_time[bi] == 0) continue;
|
||||
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
|
||||
|
||||
|
||||
@@ -176,8 +176,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
int pre_ids_len = pre_ids.shape()[1];
|
||||
auto cu_stream = seq_lens_this_time.stream();
|
||||
int target_model_draft_tokens_len = target_model_draft_tokens.shape()[1];
|
||||
auto not_need_stop_gpu =
|
||||
not_need_stop.copy_to(seq_lens_this_time.place(), false);
|
||||
|
||||
draft_model_preprocess_kernel<kBlockSize><<<1, kBlockSize, 0, cu_stream>>>(
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
@@ -187,7 +185,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(not_need_stop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
@@ -205,10 +203,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
target_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
is_splitwise_prefill);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
|
||||
@@ -123,8 +123,6 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
auto seq_lens_this_time_shape = seq_lens_this_time.shape();
|
||||
auto cu_stream = seq_lens_this_time.stream();
|
||||
const int real_bsz = seq_lens_this_time_shape[0];
|
||||
auto not_need_stop_gpu =
|
||||
not_need_stop.copy_to(seq_lens_this_time.place(), false);
|
||||
const int end_ids_len = end_ids.shape()[0];
|
||||
const int max_draft_token = draft_tokens.shape()[1];
|
||||
const int pre_id_length = pre_ids.shape()[1];
|
||||
@@ -149,7 +147,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
cu_seqlens_q_output.data<int>(),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
not_need_stop_gpu.data<bool>(),
|
||||
const_cast<bool*>(not_need_stop.data<bool>()),
|
||||
max_dec_len.data<int64_t>(),
|
||||
end_ids.data<int64_t>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
@@ -161,11 +159,6 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
max_seq_len,
|
||||
substep,
|
||||
prefill_one_step_stop);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(draft_model_update)
|
||||
|
||||
@@ -30,6 +30,7 @@ __global__ inline void min_length_logits_process(
|
||||
const int token_idx = threadIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
if (bi < 0) return;
|
||||
if (bi >= bs) return;
|
||||
const int query_start_token_idx = cu_seqlens_q_output[bi];
|
||||
|
||||
@@ -59,6 +60,7 @@ __global__ inline void min_length_logits_process<half>(
|
||||
const int token_idx = threadIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
if (bi < 0) return;
|
||||
if (bi >= bs) return;
|
||||
const int query_start_token_idx = cu_seqlens_q_output[bi];
|
||||
|
||||
@@ -85,6 +87,7 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
if (bi < 0) return;
|
||||
if (bi >= bs) return;
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
@@ -115,6 +118,7 @@ __global__ void update_value_by_repeat_times(
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
if (bi < 0) return;
|
||||
if (bi >= bs) return;
|
||||
int tid = threadIdx.x;
|
||||
T *logits_now = logits + token_idx * length;
|
||||
@@ -146,7 +150,7 @@ __global__ void ban_bad_words(T *logits,
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
|
||||
if (bi < 0) return;
|
||||
if (bi >= bs) return;
|
||||
int tid = threadIdx.x;
|
||||
T *logits_now = logits + token_idx * length;
|
||||
|
||||
@@ -146,10 +146,10 @@ std::vector<paddle::Tensor> SpeculatePreProcess(
|
||||
const int bsz = seq_len.shape()[0];
|
||||
const int max_seq_len = input_ids_shape[1];
|
||||
const int token_num_data = cpu_token_num;
|
||||
auto ids_remove_padding = paddle::empty(
|
||||
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
||||
auto batch_id_per_token = paddle::empty(
|
||||
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
||||
auto ids_remove_padding = paddle::full(
|
||||
{token_num_data}, 2, paddle::DataType::INT64, input_ids.place());
|
||||
auto batch_id_per_token = paddle::full(
|
||||
{token_num_data}, -1, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_q =
|
||||
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
@@ -170,9 +170,10 @@ std::vector<paddle::Tensor> SpeculatePreProcess(
|
||||
auto cu_seq_lens_q_output =
|
||||
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
|
||||
auto batch_id_per_token_output =
|
||||
paddle::empty({bsz * max_draft_tokens_per_batch},
|
||||
paddle::DataType::INT32,
|
||||
input_ids.place());
|
||||
paddle::full({bsz * max_draft_tokens_per_batch},
|
||||
-1,
|
||||
paddle::DataType::INT32,
|
||||
input_ids.place());
|
||||
auto real_output_token_num =
|
||||
paddle::empty({1}, paddle::DataType::INT32, input_ids.place());
|
||||
|
||||
|
||||
@@ -134,7 +134,6 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
}
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
speculate_schedula_cache<BlockSize>
|
||||
<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
draft_tokens.data<int64_t>(),
|
||||
@@ -150,7 +149,7 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool *>(not_need_stop.data<bool>()),
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
@@ -159,11 +158,6 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
prefill_one_step_stop);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
|
||||
@@ -419,6 +419,7 @@ __global__ void KeMatrixTopPBeamTopKFt(
|
||||
const int lane = tid % 32;
|
||||
const int token_id = blockIdx.x;
|
||||
const int bid = batch_id_per_token_output[token_id];
|
||||
if (bid < 0) return;
|
||||
|
||||
int top_num = TopPBeamTopK;
|
||||
float top_p_value = static_cast<float>(top_ps[bid]);
|
||||
|
||||
@@ -190,13 +190,11 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||
constexpr int BlockSize = 1024;
|
||||
|
||||
// has_running_seqs is CPU tensor, need to copy to GPU first
|
||||
auto has_running_seqs_gpu =
|
||||
has_running_seqs.copy_to(seq_lens_this_time.place(), false);
|
||||
unified_update_model_status_kernel<BlockSize>
|
||||
<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<bool *>(has_running_seqs_gpu.data<bool>()),
|
||||
const_cast<bool *>(has_running_seqs.data<bool>()),
|
||||
const_cast<int64_t *>(step_input_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
|
||||
const_cast<int *>(step_output_len.data<int>()),
|
||||
@@ -213,11 +211,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||
max_step_tokens,
|
||||
max_model_len,
|
||||
num_end_tokens);
|
||||
// Copy result back to CPU
|
||||
auto has_running_seqs_cpu =
|
||||
has_running_seqs_gpu.copy_to(has_running_seqs.place(), false);
|
||||
bool *out_data = const_cast<bool *>(has_running_seqs.data<bool>());
|
||||
out_data[0] = has_running_seqs_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(unified_update_model_status)
|
||||
|
||||
@@ -573,7 +573,7 @@ class EngineArgs:
|
||||
self.enable_prefix_caching = False
|
||||
if (
|
||||
not current_platform.is_cuda()
|
||||
or self.speculative_config is not None
|
||||
or (self.speculative_config is not None and self.enable_logprob)
|
||||
or self.splitwise_role == "prefill"
|
||||
or self.dynamic_load_weight
|
||||
):
|
||||
|
||||
@@ -160,6 +160,8 @@ class ForwardMeta:
|
||||
|
||||
position_ids: Optional[paddle.Tensor] = None
|
||||
|
||||
real_bsz: int = 0
|
||||
|
||||
def clear_caches(self):
|
||||
"""Safely clean up the caches"""
|
||||
if self.caches:
|
||||
|
||||
@@ -155,12 +155,15 @@ class CudaGraphPiecewiseBackend:
|
||||
|
||||
def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
|
||||
# Get real shape (total num tokens)
|
||||
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
|
||||
real_shape = ids_remove_padding.shape[0]
|
||||
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
|
||||
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
|
||||
num_running_requests = int((seq_lens_this_time.flatten() > 0).sum().item())
|
||||
real_bsz = kwargs["forward_meta"].real_bsz
|
||||
num_running_requests = real_bsz if real_bsz > 0 else int((seq_lens_this_time.flatten() > 0).sum().item())
|
||||
num_running_requests = max(1, num_running_requests)
|
||||
real_shape = self.real_bsz_to_captured_size[num_running_requests]
|
||||
else:
|
||||
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
|
||||
real_shape = ids_remove_padding.shape[0]
|
||||
exist_prefill = kwargs["forward_meta"].exist_prefill
|
||||
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
|
||||
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
|
||||
|
||||
@@ -123,7 +123,7 @@ def gather_logprobs(
|
||||
indices = token_ids
|
||||
top_logprobs = token_logprobs
|
||||
|
||||
return LogprobsTensors(indices, top_logprobs, token_ranks)
|
||||
return LogprobsTensors(indices.cpu(), top_logprobs.cpu(), token_ranks.cpu())
|
||||
|
||||
|
||||
def build_output_logprobs(
|
||||
|
||||
@@ -1041,7 +1041,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
)
|
||||
sampler_output.logprobs_tensors = logprobs_tensors
|
||||
if cu_batch_token_offset is not None:
|
||||
sampler_output.cu_batch_token_offset = cu_batch_token_offset
|
||||
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
|
||||
return sampler_output
|
||||
|
||||
def forward_xpu(
|
||||
|
||||
@@ -437,8 +437,6 @@ def post_process_specualate(
|
||||
model_output: ModelOutputData,
|
||||
share_inputs: InputBatch,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
think_end_id: int = -1,
|
||||
splitwise_role_is_decode: bool = False,
|
||||
enable_entropy: bool = False,
|
||||
@@ -508,7 +506,7 @@ def post_process_specualate(
|
||||
unified_update_model_status(
|
||||
model_output.seq_lens_encoder, # seq_lens_encoder
|
||||
model_output.seq_lens_decoder, # seq_lens_decoder
|
||||
model_output.not_need_stop, # has_running_seqs
|
||||
model_output.not_need_stop_device, # has_running_seqs
|
||||
model_output.draft_tokens, # step_input_ids
|
||||
model_output.accept_tokens, # step_output_ids (read-write)
|
||||
model_output.accept_num, # step_output_len (read-write)
|
||||
@@ -522,24 +520,35 @@ def post_process_specualate(
|
||||
model_output.max_dec_len, # max_dec_len
|
||||
)
|
||||
|
||||
|
||||
def save_output_specualate(
|
||||
sampler_output: SamplerOutput,
|
||||
model_output: ModelOutputData,
|
||||
share_inputs: InputBatch,
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
):
|
||||
if not skip_save_output:
|
||||
if sampler_output.logprobs_tensors is None:
|
||||
recover_model_output_map = recover_batch_index_for_output(
|
||||
model_output,
|
||||
recover_share_inputs = recover_batch_index_for_output(
|
||||
share_inputs,
|
||||
model_output.index_to_batch_id,
|
||||
model_output.enable_pd_reorder,
|
||||
["accept_tokens", "accept_num", "seq_lens_decoder", "prompt_lens"],
|
||||
)
|
||||
recover_share_inputs = recover_batch_index_for_output(
|
||||
share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
|
||||
[
|
||||
"accept_tokens_cpu",
|
||||
"accept_num_cpu",
|
||||
"seq_lens_decoder_cpu",
|
||||
"prompt_lens_cpu",
|
||||
"last_preempted_idx",
|
||||
],
|
||||
)
|
||||
speculate_save_output(
|
||||
recover_model_output_map["accept_tokens"],
|
||||
recover_model_output_map["accept_num"],
|
||||
recover_share_inputs["accept_tokens_cpu"],
|
||||
recover_share_inputs["accept_num_cpu"],
|
||||
model_output.not_need_stop,
|
||||
recover_model_output_map["seq_lens_decoder"],
|
||||
recover_model_output_map["prompt_lens"],
|
||||
recover_share_inputs["preempted_idx"],
|
||||
recover_share_inputs["seq_lens_decoder_cpu"],
|
||||
recover_share_inputs["prompt_lens_cpu"],
|
||||
recover_share_inputs["last_preempted_idx"],
|
||||
model_output.mp_rank,
|
||||
save_each_rank,
|
||||
bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
|
||||
@@ -548,30 +557,35 @@ def post_process_specualate(
|
||||
recover_batch_index_for_sampler_output(
|
||||
sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
|
||||
)
|
||||
recover_model_output_map = recover_batch_index_for_output(
|
||||
model_output,
|
||||
recover_share_inputs = recover_batch_index_for_output(
|
||||
share_inputs,
|
||||
model_output.index_to_batch_id,
|
||||
model_output.enable_pd_reorder,
|
||||
["seq_lens_decoder", "prompt_lens"],
|
||||
)
|
||||
recover_share_inputs = recover_batch_index_for_output(
|
||||
share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
|
||||
[
|
||||
"sampled_token_ids",
|
||||
"accept_tokens_cpu",
|
||||
"accept_num_cpu",
|
||||
"seq_lens_decoder_cpu",
|
||||
"prompt_lens_cpu",
|
||||
"last_preempted_idx",
|
||||
],
|
||||
)
|
||||
speculate_save_output_topk(
|
||||
sampler_output.sampled_token_ids,
|
||||
recover_share_inputs["sampled_token_ids"],
|
||||
sampler_output.logprobs_tensors.logprob_token_ids,
|
||||
sampler_output.logprobs_tensors.logprobs,
|
||||
sampler_output.logprobs_tensors.selected_token_ranks,
|
||||
sampler_output.token_num_per_batch,
|
||||
recover_share_inputs["accept_num_cpu"],
|
||||
sampler_output.cu_batch_token_offset,
|
||||
model_output.not_need_stop,
|
||||
recover_model_output_map["seq_lens_decoder"],
|
||||
recover_model_output_map["prompt_lens"],
|
||||
recover_share_inputs["preempted_idx"],
|
||||
recover_share_inputs["seq_lens_decoder_cpu"],
|
||||
recover_share_inputs["prompt_lens_cpu"],
|
||||
recover_share_inputs["last_preempted_idx"],
|
||||
3, # mtype
|
||||
model_output.mp_rank,
|
||||
save_each_rank,
|
||||
)
|
||||
share_inputs["last_preempted_idx"][:] = 0
|
||||
|
||||
|
||||
def post_process(
|
||||
@@ -609,13 +623,12 @@ def post_process(
|
||||
model_output,
|
||||
share_inputs,
|
||||
sampling_metadata,
|
||||
save_each_rank,
|
||||
skip_save_output,
|
||||
think_end_id,
|
||||
splitwise_role_is_decode,
|
||||
enable_entropy,
|
||||
routing_replay_manager,
|
||||
)
|
||||
share_inputs["last_preempted_idx"].copy_(share_inputs["preempted_idx"])
|
||||
else:
|
||||
post_process_normal(
|
||||
sampler_or_pooler_output,
|
||||
|
||||
@@ -71,7 +71,7 @@ else:
|
||||
set_data_ipc,
|
||||
unset_data_ipc,
|
||||
)
|
||||
from fastdeploy.model_executor.pre_and_post_process import pre_process
|
||||
from fastdeploy.model_executor.pre_and_post_process import async_set_value, pre_process
|
||||
|
||||
from fastdeploy.worker.input_batch import (
|
||||
ProposerInputBatch,
|
||||
@@ -143,6 +143,7 @@ class MTPProposer(Proposer):
|
||||
|
||||
# Forward meta store the global meta information of the forward
|
||||
self.forward_meta = None
|
||||
self.exist_prefill_flag = False
|
||||
|
||||
def _update_mtp_config(self, main_model):
|
||||
"""
|
||||
@@ -499,6 +500,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["batch_drop"][idx : idx + 1] = False
|
||||
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
|
||||
self.exist_prefill_flag = True
|
||||
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
|
||||
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
|
||||
self.model_inputs["step_idx"][idx : idx + 1] = (
|
||||
@@ -521,6 +523,7 @@ class MTPProposer(Proposer):
|
||||
self.fd_config.scheduler_config.splitwise_role == "decode"
|
||||
): # In PD, we continue to decode after P generates first token
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
|
||||
self.exist_prefill_flag = False
|
||||
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
|
||||
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1
|
||||
# NOTE(liuzichang):
|
||||
@@ -531,9 +534,14 @@ class MTPProposer(Proposer):
|
||||
encoder_block_num = len(request.block_tables)
|
||||
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
if current_platform.is_cuda():
|
||||
async_set_value(
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
|
||||
)
|
||||
else:
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
|
||||
# has_decode_task = True
|
||||
# continue
|
||||
@@ -631,7 +639,6 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.get("block_tables"), dtype="int32"
|
||||
)
|
||||
self.model_inputs["not_need_stop"][0] = True
|
||||
self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
|
||||
|
||||
def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
|
||||
@@ -706,10 +713,7 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
check whether prefill stage exist
|
||||
"""
|
||||
if np.any(self.share_inputs["seq_lens_encoder"].numpy() > 0):
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
return self.exist_prefill_flag
|
||||
|
||||
def _prepare_inputs_cuda(self, full_hidden_states):
|
||||
"""
|
||||
@@ -729,7 +733,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
self.model_inputs["step_idx"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
self.model_inputs["not_need_stop_device"],
|
||||
self.model_inputs["pre_ids"],
|
||||
self.target_model_inputs["accept_tokens"],
|
||||
self.target_model_inputs["accept_num"],
|
||||
@@ -822,7 +826,11 @@ class MTPProposer(Proposer):
|
||||
else self.model_inputs["output_cum_offsets"]
|
||||
),
|
||||
self.model_inputs["stop_flags"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
(
|
||||
self.model_inputs["not_need_stop_device"]
|
||||
if current_platform.is_cuda()
|
||||
else self.model_inputs["not_need_stop"]
|
||||
),
|
||||
self.model_inputs["max_dec_len"],
|
||||
self.model_inputs["eos_token_id"],
|
||||
self.model_inputs["base_model_draft_tokens"],
|
||||
@@ -858,18 +866,30 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["step_idx"],
|
||||
)
|
||||
|
||||
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
|
||||
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
|
||||
"""
|
||||
Main process for MTP inference.
|
||||
Args:
|
||||
step_use_cudagraph: bool
|
||||
Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
|
||||
"""
|
||||
is_blocking = (
|
||||
(not self.fd_config.scheduler_config.enable_overlap_schedule)
|
||||
or is_dummy_run
|
||||
or self.exist_prefill()
|
||||
or real_bsz == 0
|
||||
)
|
||||
for substep in range(self.num_model_steps):
|
||||
if self.model_inputs["not_need_stop"]:
|
||||
if is_blocking:
|
||||
token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
|
||||
else:
|
||||
if substep == 0:
|
||||
token_num_cpu = real_bsz * (self.max_draft_token_num + 1)
|
||||
else:
|
||||
token_num_cpu = real_bsz
|
||||
if token_num_cpu > 0:
|
||||
self.model_inputs["substep"] = substep
|
||||
# Remove padding
|
||||
token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
|
||||
(
|
||||
ids_remove_padding,
|
||||
batch_id_per_token,
|
||||
@@ -918,6 +938,7 @@ class MTPProposer(Proposer):
|
||||
step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
|
||||
)
|
||||
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
|
||||
self.forward_meta.real_bsz = real_bsz
|
||||
|
||||
# Padding inputs for cuda graph
|
||||
self.padding_cudagraph_inputs()
|
||||
@@ -1034,8 +1055,9 @@ class MTPProposer(Proposer):
|
||||
else:
|
||||
if hasattr(self.model, "empty_input_forward") and not is_dummy_run:
|
||||
self.model.empty_input_forward(forward_meta=self.forward_meta)
|
||||
self.exist_prefill_flag = False
|
||||
|
||||
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
|
||||
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
|
||||
"""
|
||||
Main process for MTP inference.
|
||||
Args:
|
||||
@@ -1241,11 +1263,15 @@ class MTPProposer(Proposer):
|
||||
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
|
||||
|
||||
def _run_impl(
|
||||
self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
|
||||
self,
|
||||
full_hidden_states: paddle.Tensor,
|
||||
step_use_cudagraph: bool = False,
|
||||
is_dummy_run: bool = False,
|
||||
real_bsz: int = 0,
|
||||
):
|
||||
"""Execute Draft Model"""
|
||||
self._prepare_inputs(full_hidden_states)
|
||||
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
|
||||
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, real_bsz=real_bsz)
|
||||
self._update_status()
|
||||
if self.hybrid_mode:
|
||||
self._extend_draft_token_with_ngram_match()
|
||||
|
||||
@@ -97,6 +97,7 @@ from fastdeploy.model_executor.pre_and_post_process import (
|
||||
pre_process,
|
||||
rebuild_padding,
|
||||
save_output_normal,
|
||||
save_output_specualate,
|
||||
)
|
||||
from fastdeploy.output.pooler import PoolerOutput
|
||||
from fastdeploy.worker.model_runner_base import (
|
||||
@@ -270,9 +271,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
# Cached token count for next batch prediction in overlap scheduling.
|
||||
# Used to avoid synchronization overhead when preparing inputs for the next batch.
|
||||
self._cached_launch_token_num = -1
|
||||
self.enable_overlap_schedule = fd_config.scheduler_config.enable_overlap_schedule and (
|
||||
not self.speculative_decoding
|
||||
)
|
||||
self._cached_real_bsz = -1
|
||||
self.enable_overlap_schedule = fd_config.scheduler_config.enable_overlap_schedule
|
||||
if self.enable_overlap_schedule:
|
||||
logger.info("Using overlap schedule")
|
||||
self.current_launch_token_num = 0
|
||||
@@ -305,7 +305,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
return ((seq_lens_decoder > 0) & ~stop_flags).any().cpu().numpy().item()
|
||||
|
||||
def _resolve_current_launch_token_num(
|
||||
self, cached_token_num: int, token_num_event, is_dummy_or_profile_run: bool
|
||||
self, cached_token_num: int, cached_real_bsz: int, token_num_event, is_dummy_or_profile_run: bool
|
||||
) -> int:
|
||||
"""
|
||||
Resolve token count for current batch.
|
||||
@@ -322,10 +322,12 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
or (not self.enable_overlap_schedule)
|
||||
or self.exist_prefill()
|
||||
or cached_token_num <= 0
|
||||
or cached_real_bsz <= 0
|
||||
):
|
||||
token_num_event.synchronize()
|
||||
return self.share_inputs["seq_lens_this_time_cpu"].numpy().sum().item()
|
||||
return cached_token_num
|
||||
seq_lens_this_time_cpu = self.share_inputs["seq_lens_this_time_cpu"].numpy()
|
||||
return seq_lens_this_time_cpu.sum().item(), (seq_lens_this_time_cpu > 0).sum().item()
|
||||
return cached_token_num, cached_real_bsz
|
||||
|
||||
def _predict_next_launch_token_num(self) -> int:
|
||||
"""
|
||||
@@ -338,11 +340,15 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
Returns -1 if prediction is not applicable (non-overlap or prefill exists).
|
||||
"""
|
||||
if self.exist_prefill():
|
||||
return -1
|
||||
return (
|
||||
self.share_inputs["seq_lens_this_time_cpu"].numpy().sum().item()
|
||||
+ self.share_inputs["is_block_step_cpu"].numpy().sum().item()
|
||||
return -1, -1
|
||||
seq_lens_this_time_cpu = self.share_inputs["seq_lens_this_time_cpu"].numpy()
|
||||
is_block_step_cpu = self.share_inputs["is_block_step_cpu"].numpy()
|
||||
next_real_bsz = (seq_lens_this_time_cpu > 0).sum().item() + (is_block_step_cpu > 0).sum().item()
|
||||
token_num_one_step = (self.speculative_config.num_speculative_tokens + 1) if self.speculative_decoding else 1
|
||||
next_launch_token_num = (
|
||||
seq_lens_this_time_cpu.sum().item() + is_block_step_cpu.sum().item() * token_num_one_step
|
||||
)
|
||||
return next_launch_token_num, next_real_bsz
|
||||
|
||||
def only_prefill(self):
|
||||
"""
|
||||
@@ -1112,7 +1118,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
|
||||
|
||||
def _prepare_inputs(self, cached_token_num=-1, is_dummy_or_profile_run=False) -> None:
|
||||
def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_profile_run=False) -> None:
|
||||
"""Prepare the model inputs"""
|
||||
if self.enable_mm and self.share_inputs["image_features_list"] is not None:
|
||||
tensor_feats = [t for t in self.share_inputs["image_features_list"] if isinstance(t, paddle.Tensor)]
|
||||
@@ -1160,7 +1166,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["is_block_step_cpu"].copy_(self.share_inputs["is_block_step"], False)
|
||||
token_num_event = paddle.device.cuda.create_event()
|
||||
token_num_event.record()
|
||||
token_num = self._resolve_current_launch_token_num(cached_token_num, token_num_event, is_dummy_or_profile_run)
|
||||
token_num, real_bsz = self._resolve_current_launch_token_num(
|
||||
cached_token_num, cached_real_bsz, token_num_event, is_dummy_or_profile_run
|
||||
)
|
||||
(
|
||||
ids_remove_padding,
|
||||
batch_id_per_token,
|
||||
@@ -1195,6 +1203,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
# Initialize forward meta data
|
||||
self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run)
|
||||
self.forward_meta.real_bsz = real_bsz
|
||||
|
||||
# Get sampling metadata
|
||||
self.sampling_metadata = SamplingMetadata(
|
||||
@@ -2052,7 +2061,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
) -> None:
|
||||
model_inputs, p_done_idxs, _ = self._preprocess(model_forward_batch, num_running_requests)
|
||||
model_output = self._execute(model_inputs)
|
||||
if model_output is None or self.share_inputs["seq_lens_this_time_cpu"].numpy().sum().item() <= 0:
|
||||
real_bsz = (self.share_inputs["seq_lens_this_time_cpu"].numpy() > 0).sum().item()
|
||||
if model_output is None or real_bsz <= 0:
|
||||
if (
|
||||
self.fd_config.speculative_config.method == SpecMethod.MTP
|
||||
and hasattr(self.proposer.model, "empty_input_forward")
|
||||
@@ -2061,9 +2071,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self._execute_empty_mtp_input(self.forward_meta)
|
||||
return
|
||||
model_output_data, sampler_output, post_process_event = self._postprocess(
|
||||
model_output, p_done_idxs, model_forward_batch, num_running_requests
|
||||
model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz
|
||||
)
|
||||
if model_output_data is not None and not self.speculative_decoding:
|
||||
if model_output_data is not None:
|
||||
# synchronizes the async DtoH copies of sampled_token_ids.
|
||||
post_process_event.synchronize()
|
||||
self._save_model_output(model_output_data, sampler_output)
|
||||
@@ -2075,7 +2085,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
) -> None:
|
||||
# preprocess and execute model (current batch)
|
||||
model_inputs, p_done_idxs, token_num_event = self._preprocess(
|
||||
model_forward_batch, num_running_requests, self._cached_launch_token_num
|
||||
model_forward_batch, num_running_requests, self._cached_launch_token_num, self._cached_real_bsz
|
||||
)
|
||||
model_output = self._execute(model_inputs)
|
||||
# save output (last batch)
|
||||
@@ -2091,10 +2101,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
# synchronizes the async DtoH copies of seq_lens_this_time_cpu and is_block_step_cpu,
|
||||
# ensuring that the token count for the current batch is ready to be computed and reused in the subsequent batch.
|
||||
token_num_event.synchronize()
|
||||
next_launch_token_num = self._predict_next_launch_token_num()
|
||||
if self.share_inputs["seq_lens_this_time_cpu"].numpy().sum().item() > 0 and model_output is not None:
|
||||
next_launch_token_num, next_real_bsz = self._predict_next_launch_token_num()
|
||||
real_bsz = (self.share_inputs["seq_lens_this_time_cpu"].numpy() > 0).sum().item()
|
||||
if real_bsz > 0 and model_output is not None:
|
||||
model_output_data, sampler_output, post_process_event = self._postprocess(
|
||||
model_output, p_done_idxs, model_forward_batch, num_running_requests
|
||||
model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz
|
||||
)
|
||||
self._cached_model_output_data = model_output_data
|
||||
self._cached_sampler_output = sampler_output
|
||||
@@ -2104,12 +2115,14 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self._cached_sampler_output = None
|
||||
self._cached_post_process_event = None
|
||||
self._cached_launch_token_num = next_launch_token_num
|
||||
self._cached_real_bsz = next_real_bsz
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
model_forward_batch: Optional[List[Request]] = None,
|
||||
num_running_requests: int = None,
|
||||
cached_token_num: int = -1,
|
||||
cached_real_bsz: int = -1,
|
||||
) -> None:
|
||||
if self.deterministic_logger is not None:
|
||||
self.deterministic_logger.log_batch_start(model_forward_batch)
|
||||
@@ -2118,7 +2131,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self._process_reorder()
|
||||
|
||||
# Prepare inputs of model and sampler.
|
||||
current_launch_token_num, token_num_event = self._prepare_inputs(cached_token_num)
|
||||
current_launch_token_num, token_num_event = self._prepare_inputs(cached_token_num, cached_real_bsz)
|
||||
self.current_launch_token_num = current_launch_token_num
|
||||
|
||||
# NOTE(sunxin):
|
||||
@@ -2170,12 +2183,13 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
p_done_idxs: List[int],
|
||||
model_forward_batch: Optional[List[Request]] = None,
|
||||
num_running_requests: int = None,
|
||||
real_bsz: int = 0,
|
||||
) -> None:
|
||||
|
||||
if self.speculative_decoding:
|
||||
self.output_token_num_event.synchronize()
|
||||
real_num = int(self._real_output_token_num_host)
|
||||
real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_num]
|
||||
real_output_token_num = int(self._real_output_token_num_host)
|
||||
real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_output_token_num]
|
||||
|
||||
prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
|
||||
if self.is_pooling_model:
|
||||
@@ -2305,7 +2319,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.sampling_metadata,
|
||||
self.model_config.max_model_len,
|
||||
self.share_inputs,
|
||||
int(self._real_output_token_num_host),
|
||||
real_output_token_num,
|
||||
self.increment_value,
|
||||
)
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
@@ -2388,6 +2402,17 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if self.guided_backend is not None and sampler_output is not None:
|
||||
self.sampler.post_process(sampler_output.sampled_token_ids)
|
||||
|
||||
# 5.1. Async cpy
|
||||
post_process_event = paddle.device.cuda.create_event()
|
||||
# if not self.speculative_decoding:
|
||||
self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
|
||||
if self.speculative_decoding:
|
||||
self.share_inputs["accept_tokens_cpu"].copy_(self.share_inputs["accept_tokens"], False)
|
||||
self.share_inputs["accept_num_cpu"].copy_(self.share_inputs["accept_num"], False)
|
||||
self.share_inputs["seq_lens_decoder_cpu"].copy_(self.share_inputs["seq_lens_decoder"], False)
|
||||
self.share_inputs["prompt_lens_cpu"].copy_(self.share_inputs["prompt_lens"], False)
|
||||
post_process_event.record()
|
||||
|
||||
# 6. Speculative decode -- proposer run (method="naive" has proposer=None, skip)
|
||||
# For naive mode: seq_lens_this_time is already reset to 1 inside
|
||||
# unified_update_model_status kernel. For MTP/Ngram, the proposer
|
||||
@@ -2396,7 +2421,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if self.speculative_decoding and self.proposer is not None:
|
||||
if self.spec_method == SpecMethod.MTP:
|
||||
self.proposer.run(
|
||||
full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
|
||||
full_hidden_states=model_output,
|
||||
step_use_cudagraph=self.forward_meta.step_use_cudagraph,
|
||||
real_bsz=real_bsz,
|
||||
)
|
||||
elif self.spec_method == SpecMethod.NAIVE:
|
||||
pass
|
||||
@@ -2422,17 +2449,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["accept_num"],
|
||||
self.share_inputs["accept_tokens"],
|
||||
self.share_inputs["is_block_step"],
|
||||
self.share_inputs["not_need_stop"],
|
||||
self.share_inputs["not_need_stop_device"],
|
||||
self.cache_config.block_size,
|
||||
self.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
# 8. Async cpy
|
||||
post_process_event = paddle.device.cuda.create_event()
|
||||
if not self.speculative_decoding:
|
||||
self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
|
||||
post_process_event.record()
|
||||
|
||||
self.exist_prefill_flag = False
|
||||
return model_output_data, sampler_output, post_process_event
|
||||
|
||||
@@ -2441,13 +2462,23 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
model_output_data,
|
||||
sampler_output,
|
||||
):
|
||||
save_output_normal(
|
||||
model_output=model_output_data,
|
||||
sampler_output=sampler_output,
|
||||
share_inputs=self.share_inputs,
|
||||
async_output_queue=self.async_output_queue,
|
||||
save_each_rank=self.parallel_config.use_ep,
|
||||
)
|
||||
if self.speculative_decoding:
|
||||
skip_save_output = self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill"
|
||||
save_output_specualate(
|
||||
sampler_output=sampler_output,
|
||||
model_output=model_output_data,
|
||||
share_inputs=self.share_inputs,
|
||||
save_each_rank=self.parallel_config.use_ep,
|
||||
skip_save_output=skip_save_output,
|
||||
)
|
||||
else:
|
||||
save_output_normal(
|
||||
model_output=model_output_data,
|
||||
sampler_output=sampler_output,
|
||||
share_inputs=self.share_inputs,
|
||||
async_output_queue=self.async_output_queue,
|
||||
save_each_rank=self.parallel_config.use_ep,
|
||||
)
|
||||
|
||||
def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:
|
||||
num_scheduled_tokens = int(self.share_inputs["seq_lens_this_time"][:num_running_requests].sum())
|
||||
|
||||
@@ -316,6 +316,15 @@ class InputBatch:
|
||||
dtype="float32",
|
||||
)
|
||||
self.cu_batch_token_offset = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32")
|
||||
# For mtp overlap
|
||||
self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
|
||||
self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
|
||||
self.accept_tokens_cpu = paddle.full(
|
||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||
fill_value=0,
|
||||
dtype="int64",
|
||||
).pin_memory()
|
||||
self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
|
||||
if self.enable_mm:
|
||||
head_dim = self.model_config.head_dim
|
||||
if (
|
||||
@@ -435,6 +444,10 @@ class InputBatch:
|
||||
swap_data(self.step_seq_lens_this_time, i1, i2)
|
||||
swap_data(self.draft_logits, i1, i2)
|
||||
swap_data(self.cu_batch_token_offset, i1, i2)
|
||||
swap_data(self.seq_lens_decoder_cpu, i1, i2)
|
||||
swap_data(self.prompt_lens_cpu, i1, i2)
|
||||
swap_data(self.accept_tokens_cpu, i1, i2)
|
||||
swap_data(self.accept_num_cpu, i1, i2)
|
||||
|
||||
if self.enable_mm:
|
||||
if self.image_features_list is not None:
|
||||
@@ -623,6 +636,15 @@ class InputBatch:
|
||||
fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
|
||||
fill_paddle_tensor(self, "draft_logits", -1)
|
||||
fill_paddle_tensor(self, "cu_batch_token_offset", 0)
|
||||
# for mtp overlap
|
||||
self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
|
||||
self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
|
||||
self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
|
||||
self.accept_tokens_cpu = paddle.full(
|
||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||
fill_value=0,
|
||||
dtype="int64",
|
||||
).pin_memory()
|
||||
|
||||
# Reset multimodal related tensors
|
||||
if self.enable_mm:
|
||||
@@ -697,6 +719,7 @@ class ProposerInputBatch(InputBatch):
|
||||
self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"])
|
||||
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
|
||||
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
|
||||
self.not_need_stop_device = paddle.to_tensor([False], dtype="bool")
|
||||
if current_platform.is_cuda():
|
||||
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
|
||||
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
|
||||
@@ -1085,6 +1108,7 @@ def _recover_tensor(recover_tensor, index_to_batch_id_list):
|
||||
"""
|
||||
sort_len = len(index_to_batch_id_list)
|
||||
if isinstance(recover_tensor.place, paddle.CUDAPinnedPlace):
|
||||
recover_tensor = recover_tensor.cpu()
|
||||
recover_res_tensor = paddle.empty_like(recover_tensor, device="cpu")
|
||||
else:
|
||||
recover_res_tensor = paddle.empty_like(recover_tensor)
|
||||
|
||||
@@ -137,7 +137,7 @@ class TestDraftModelPreprocess(unittest.TestCase):
|
||||
seq_lens_encoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
|
||||
seq_lens_decoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
|
||||
step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
|
||||
not_need_stop = paddle.zeros([1], dtype="bool").cpu() # must be CPU: kernel writes back via raw pointer
|
||||
not_need_stop = paddle.zeros([1], dtype="bool")
|
||||
pre_ids = input_ids.clone()
|
||||
|
||||
accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
|
||||
@@ -230,7 +230,7 @@ class TestDraftModelPreprocess(unittest.TestCase):
|
||||
seq_lens_encoder = paddle.zeros([bsz], dtype="int32") # all decode for simplicity
|
||||
seq_lens_decoder = paddle.randint(max_draft_token + 1, 100, [bsz], dtype="int32")
|
||||
step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
|
||||
not_need_stop = paddle.zeros([1], dtype="bool").cpu() # must be CPU: kernel writes back via raw pointer
|
||||
not_need_stop = paddle.zeros([1], dtype="bool")
|
||||
pre_ids = input_ids.clone()
|
||||
|
||||
accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
|
||||
|
||||
@@ -64,13 +64,11 @@ CPU_PLACE = paddle.CPUPlace()
|
||||
|
||||
|
||||
def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert numpy dict → paddle tensors. not_need_stop stays on CPU."""
|
||||
"""Convert numpy dict → paddle tensors."""
|
||||
paddle_inputs = {}
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, (int, bool, float)):
|
||||
paddle_inputs[k] = v
|
||||
elif k == "not_need_stop":
|
||||
paddle_inputs[k] = paddle.to_tensor(v, place=CPU_PLACE)
|
||||
elif v is not None:
|
||||
paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE)
|
||||
return paddle_inputs
|
||||
|
||||
@@ -138,7 +138,7 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
||||
self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool)
|
||||
|
||||
# not_need_stop lives on CPU in the caller; the kernel copies to device internally
|
||||
self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()
|
||||
self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool)
|
||||
|
||||
# Choose threshold so with: bid0 triggers, bid1 already stopped, padding (5-3)=2 -> stop_sum = 1+1+2 = 4
|
||||
|
||||
|
||||
@@ -53,14 +53,11 @@ CPU_PLACE = paddle.CPUPlace()
|
||||
|
||||
|
||||
def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert numpy dict → paddle tensors. has_running_seqs goes to CPU."""
|
||||
"""Convert numpy dict → paddle tensors."""
|
||||
paddle_inputs = {}
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, (int, bool, float, str)):
|
||||
paddle_inputs[k] = v
|
||||
elif k == "has_running_seqs":
|
||||
# Kernel host function: has_running_seqs.copy_to(GPU) → kernel → copy_to(CPU)
|
||||
paddle_inputs[k] = paddle.to_tensor(v, place=CPU_PLACE)
|
||||
elif v is not None:
|
||||
paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE)
|
||||
else:
|
||||
|
||||
@@ -71,6 +71,7 @@ class TestMTPProposer(unittest.TestCase):
|
||||
"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32"),
|
||||
"seq_lens_decoder": paddle.zeros([2, 1], dtype="int32"),
|
||||
"prompt_lens": paddle.zeros([2, 1], dtype="int64"),
|
||||
"prompt_lens_cpu": paddle.zeros([2, 1], dtype="int64").pin_memory(),
|
||||
"step_idx": paddle.zeros([2, 1], dtype="int64"),
|
||||
"stop_flags": paddle.zeros([2, 1], dtype="bool"),
|
||||
"token_ids_all": paddle.zeros([2, 2048], dtype="int64"),
|
||||
@@ -379,11 +380,11 @@ class TestMTPProposer(unittest.TestCase):
|
||||
self.assertEqual(proposer.forward_meta.pos_emb_type, "NORMAL")
|
||||
|
||||
# Test exist_prefill
|
||||
proposer.share_inputs = {"seq_lens_encoder": paddle.ones([2, 1], dtype="int32")}
|
||||
proposer.exist_prefill_flag = True
|
||||
result = proposer.exist_prefill()
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
proposer.share_inputs = {"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32")}
|
||||
proposer.exist_prefill_flag = False
|
||||
result = proposer.exist_prefill()
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user