mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
1204 lines
58 KiB
Python
1204 lines
58 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import paddle
|
|
from paddleformers.utils.log import logger
|
|
|
|
from fastdeploy.config import (
|
|
CacheConfig,
|
|
DeployModality,
|
|
FDConfig,
|
|
ModelConfig,
|
|
SpeculativeConfig,
|
|
)
|
|
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
|
|
from fastdeploy.model_executor.logits_processor import build_logits_processors
|
|
from fastdeploy.platforms import current_platform
|
|
|
|
|
|
class InputBatch:
|
|
def __getitem__(self, key):
|
|
"""Support dictionary-style attribute access"""
|
|
if hasattr(self, key):
|
|
return getattr(self, key)
|
|
raise KeyError(f"'{key}' is not a valid attribute of InputBatch")
|
|
|
|
def __setitem__(self, key, value):
|
|
"""Support dictionary-style attribute setting, overwrite if exists, add if not exists"""
|
|
setattr(self, key, value)
|
|
|
|
def __contains__(self, key):
|
|
"""Support 'in' operator to check if attribute exists"""
|
|
return hasattr(self, key)
|
|
|
|
def update(self, values: dict):
|
|
"""Batch update attributes, similar to dict's update method"""
|
|
for key, value in values.items():
|
|
setattr(self, key, value)
|
|
|
|
def get(self, key, default=None):
|
|
if hasattr(self, key):
|
|
return getattr(self, key)
|
|
return default
|
|
|
|
def pop(self, key, default=None):
|
|
"""
|
|
Pop an attribute, similar to dict's pop method
|
|
|
|
Args:
|
|
key: Name of the attribute to pop
|
|
default: Default value to return if attribute does not exist
|
|
|
|
Returns:
|
|
Popped attribute value, or default if attribute doesn't exist and default is provided
|
|
"""
|
|
if hasattr(self, key):
|
|
value = getattr(self, key)
|
|
delattr(self, key)
|
|
return value
|
|
elif default is not None:
|
|
return default
|
|
else:
|
|
raise KeyError(f"'{key}' is not a valid attribute of InputBatch")
|
|
|
|
def __delitem__(self, key):
|
|
"""
|
|
Delete an attribute using del operator
|
|
|
|
Args:
|
|
key: Name of the attribute to delete
|
|
|
|
Raises:
|
|
KeyError: If attribute does not exist
|
|
"""
|
|
if hasattr(self, key):
|
|
delattr(self, key)
|
|
else:
|
|
raise KeyError(f"'{key}' is not a valid attribute of InputBatch")
|
|
|
|
    def __init__(self, fd_config: FDConfig) -> None:
        """
        Initialize all share buffers for model inputs.

        Only lightweight bookkeeping and config references are stored here;
        the actual tensor buffers are allocated in ``init_share_inputs``.

        Args:
            fd_config: Global FastDeploy configuration bundle; sub-configs
                (model/cache/scheduler/speculative/parallel) are cached as
                attributes for convenient access.
        """
        # Bookkeeping for the currently scheduled requests.
        self.num_running_requests = 0
        self.running_requests_ids = []
        # Cache the config objects for direct attribute access.
        self.fd_config: FDConfig = fd_config
        self.model_config: ModelConfig = fd_config.model_config
        self.cache_config: CacheConfig = fd_config.cache_config
        self.scheduler_config = fd_config.scheduler_config
        self.speculative_config: SpeculativeConfig = fd_config.speculative_config
        # Speculative decoding is considered enabled whenever a method is set.
        self.speculative_decoding = self.speculative_config.method is not None
        self.enable_mm = self.model_config.enable_mm
        self.enable_expert_parallel = fd_config.parallel_config.enable_expert_parallel
        # Maps batch slot index -> request/batch id (see get_index_by_batch_id).
        self.index_to_batch_id = {}
        self.enable_pd_reorder = False
        # Qwen vl etc. do not support mm_max_tokens_per_item now
        if self.enable_mm and self.model_config.mm_max_tokens_per_item is None:
            # Fall back to the full context length when the multimodal
            # per-item token cap is unavailable.
            self.max_chunk_tokens = self.model_config.max_model_len
        else:
            self.max_chunk_tokens = self.fd_config.get_max_chunk_tokens(self.model_config.mm_max_tokens_per_item)
|
    def init_share_inputs(self):
        """
        Allocate every shared input buffer used during model execution.

        All per-request buffers are sized by ``max_num_seqs`` (one row per
        batch slot); sampling parameters are pre-filled with the model-config
        defaults so unused slots carry sane values. Device placement of the
        CPU-side mirrors depends on the current platform (MACA vs. others).
        """
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Full token history per slot; -1 marks unused positions.
        self.token_ids_all = paddle.full(
            [max_num_seqs, self.model_config.max_model_len],
            -1,
            dtype="int64",
        )
        # Current chunk of input tokens, padded with pad_token_id.
        self.input_ids = paddle.full(
            [max_num_seqs, self.max_chunk_tokens],
            self.model_config.pad_token_id,
            dtype="int64",
        )
        self.eos_token_id = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
        # Per-slot sampling parameters, initialized from model-config defaults.
        self.top_p = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
        self.top_k = paddle.full([max_num_seqs, 1], 0, dtype="int64")
        self.top_k_list = [0] * max_num_seqs
        self.min_p = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
        self.min_p_list = [0.0] * max_num_seqs
        self.temperature = paddle.full([max_num_seqs, 1], self.model_config.temperature, dtype="float32")
        self.penalty_score = paddle.full([max_num_seqs, 1], self.model_config.penalty_score, dtype="float32")
        self.frequency_score = paddle.full(
            [max_num_seqs, 1],
            self.model_config.frequency_score,
            dtype="float32",
        )
        self.presence_score = paddle.full([max_num_seqs, 1], self.model_config.presence_score, dtype="float32")
        self.temp_scaled_logprobs = paddle.full([max_num_seqs, 1], False, dtype="bool")
        self.top_p_normalized_logprobs = paddle.full([max_num_seqs, 1], False, dtype="bool")

        self.min_dec_len = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
        self.max_dec_len = paddle.full([max_num_seqs, 1], self.model_config.max_model_len, dtype="int64")
        # Sequence-length bookkeeping (encoder/decoder split per slot).
        self.seq_lens_this_time_buffer = paddle.full([max_num_seqs], 0, dtype="int32")
        self.seq_lens_this_time = paddle.full([max_num_seqs], 0, dtype="int32")
        self.seq_lens_encoder = paddle.full([max_num_seqs], 0, dtype="int32")
        self.seq_lens_decoder = paddle.full([max_num_seqs], 0, dtype="int32")
        self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.step_seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.prompt_lens = paddle.full([max_num_seqs, 1], 0, dtype="int64")
        self.step_idx = paddle.full([max_num_seqs, 1], 0, dtype="int64")
        # CPU-side mirrors: MACA has no pin_memory, so allocate directly on CPU.
        if current_platform.is_maca():
            self.not_need_stop = paddle.full([1], False, dtype="bool", device="cpu")
            self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64", device="cpu")
            self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32", device="cpu")
            self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool", device="cpu")
        else:
            self.not_need_stop = paddle.full([1], False, dtype="bool").cpu()
            self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory()
            self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
            self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory()
        self.not_need_stop_device = paddle.full([1], False, dtype="bool")
        # All slots start as stopped until a request is scheduled into them.
        self.stop_flags = paddle.full([max_num_seqs, 1], True, dtype="bool")

        self.bad_tokens = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
        self.bad_tokens_len = paddle.full([max_num_seqs], 1, dtype="int64")
        self.next_tokens = paddle.full([max_num_seqs, 1], -1, dtype="int64")
        self.is_block_step = paddle.full([max_num_seqs], False, dtype="bool")
        self.is_chunk_step = paddle.full([max_num_seqs], False, dtype="bool", device="cpu")
        # KV-cache block scheduling bookkeeping.
        self.encoder_block_lens = paddle.full([max_num_seqs], 0, dtype="int32")
        self.step_block_list = paddle.full([max_num_seqs], -1, dtype="int32")
        self.step_lens = paddle.full([1], 0, dtype="int32")
        self.recover_block_list = paddle.full([max_num_seqs], -1, dtype="int32")
        self.recover_lens = paddle.full([1], 0, dtype="int32")
        self.need_block_list = paddle.full([max_num_seqs], -1, dtype="int32")
        self.need_block_len = paddle.full([1], 0, dtype="int32")
        self.used_list_len = paddle.full([max_num_seqs], 0, dtype="int32")
        self.infer_seed = paddle.full([max_num_seqs, 1], 0, dtype="int64")
        self.first_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64")
        self.ori_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.system_lens = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.system_ids = paddle.full([max_num_seqs, 1], -1, dtype="int32")
        self.generated_modality = paddle.full([max_num_seqs], -1, dtype="int32")

        # Flattened (padding-free) views over the whole batch.
        self.ids_remove_padding = paddle.full(
            [max_num_seqs * self.max_chunk_tokens],
            0,
            dtype="int64",
        )
        self.batch_id_per_token = paddle.full([max_num_seqs * self.max_chunk_tokens, 1], 0, dtype="int32")
        self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32")
        self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32")

        # Declare AttentionBackend buffers (lazily created by the backend).
        self.decoder_batch_ids = None
        self.decoder_tile_ids_per_batch = None
        self.decoder_num_blocks_cpu = None  # Pinning Memory
        self.decoder_num_blocks_device = None
        self.decoder_chunk_size_device = None
        self.max_len_tensor_cpu = None  # CPU
        self.encoder_batch_ids = None
        self.encoder_tile_ids_per_batch = None
        self.encoder_num_blocks_x_cpu = None  # CPU
        self.kv_batch_ids = None
        self.kv_tile_ids_per_batch = None
        self.kv_num_blocks_x_cpu = None  # CPU

        # Initialize thinking related buffers
        self.enable_thinking = paddle.full(shape=[max_num_seqs, 1], fill_value=True, dtype="bool")
        self.max_think_lens = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32")
        self.max_reply_lens = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32")
        self.limit_think_status = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
        self.inject_token_ids = paddle.to_tensor(self.model_config.think_truncate_prompt_ids, dtype="int64").reshape(
            [-1, 1]
        )

        # NOTE(liuzichang): token after \n</think>\n\n must be <tool_call> 100973 or <response> 100975
        # It is a hard code to cover up model's performance
        # Detailed notes can be found in FastDeploy/custom_ops/gpu_ops/reasoning_phase_token_constraint.cu
        self.reasoning_status = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
        self.reasoning_allowed_tokens = paddle.to_tensor([100973, 100975], dtype="int64")

        # Initialize rotary position embedding (text-only models; the
        # multimodal rope buffer is allocated below when enable_mm is set).
        if not self.enable_mm:
            self.rope_emb = get_rope(
                rotary_dim=self.model_config.head_dim,
                position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
                base=self.model_config.rope_theta,
                model_config=self.model_config,
                partial_rotary_factor=self.model_config.partial_rotary_factor,
            )

        # Set block tables: ceil(max_model_len / block_size) plus the extra
        # encoder/decoder blocks reserved by the cache config.
        pre_max_block_num = (
            self.model_config.max_model_len + self.cache_config.block_size - 1
        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
        self.block_tables = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")

        # Initialize free list: block ids handed out from the top of the pool
        # downward; only the (1 - kv_cache_ratio) tail of the pool is free.
        free_list = list(
            range(
                self.cache_config.total_block_num - 1,
                int(self.cache_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
        self.free_list_len = len(free_list)
        self.free_list = paddle.to_tensor(free_list, dtype="int32")
        # free_list_len is immediately re-bound as a tensor; the int above is
        # only used to fill it.
        self.free_list_len = paddle.full([1], self.free_list_len, dtype="int32")

        # Initialize stop seqs
        self.stop_seqs_len = paddle.full([max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32")
        self.stop_seqs = paddle.full(
            [
                max_num_seqs,
                self.model_config.max_stop_seqs_num,
                self.model_config.stop_seqs_max_len,
            ],
            -1,
            dtype="int64",
        )
        self.req_ids = [""] * max_num_seqs
        self.entropy_list = [[] for _ in range(max_num_seqs)]
        # Speculative-decoding buffers (draft/accept token bookkeeping).
        if self.speculative_decoding:
            max_draft_token_num = self.speculative_config.num_speculative_tokens
            self.input_ids_cpu = paddle.full(
                shape=[max_num_seqs, self.model_config.max_model_len],
                fill_value=-1,
                dtype="int64",
                device="cpu",
            )
            self.accept_tokens = paddle.full(
                shape=[max_num_seqs, max_draft_token_num + 1],
                fill_value=0,
                dtype="int64",
            )
            self.accept_num = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
            self.draft_tokens = paddle.full(
                shape=[max_num_seqs, max_draft_token_num + 1],
                fill_value=0,
                dtype="int64",
            )

            self.actual_draft_token_num = paddle.full(
                shape=[max_num_seqs],
                fill_value=max_draft_token_num,
                dtype="int32",
            )
            if current_platform.is_cuda():
                self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
                self.batch_id_per_token_output = paddle.full(
                    shape=[max_num_seqs * (max_draft_token_num + 1)],
                    fill_value=0,
                    dtype="int32",
                )
            else:
                self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
                self.output_padding_offset = paddle.full(
                    shape=[max_num_seqs * (max_draft_token_num + 1)],
                    fill_value=0,
                    dtype="int32",
                )
            # For V1_KVCACHE_SCHEDULER
            self.step_draft_tokens = paddle.full(
                shape=[max_num_seqs, max_draft_token_num + 1],
                fill_value=-1,
                dtype="int64",
            )
            self.step_seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
            # For MTP Logprob
            self.draft_logits = paddle.full(
                [max_num_seqs * (self.speculative_config.num_speculative_tokens + 1), self.model_config.vocab_size],
                -1,
                dtype="float32",
            )
            self.cu_batch_token_offset = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32")
            # For mtp overlap
            self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
            self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
            self.accept_tokens_cpu = paddle.full(
                shape=[max_num_seqs, max_draft_token_num + 1],
                fill_value=0,
                dtype="int64",
            ).pin_memory()
            self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
        # Multimodal models get a per-slot precomputed rope table instead of
        # the shared get_rope() tensor above.
        if self.enable_mm:
            head_dim = self.model_config.head_dim
            if (
                "qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type
            ):  # neox style = True
                rope_head_dim = head_dim
            else:  # neox style = False
                rope_head_dim = head_dim // 2

            self.rope_emb = paddle.full(
                shape=[
                    max_num_seqs,
                    2,
                    1,
                    self.model_config.max_model_len,
                    1,
                    rope_head_dim,
                ],
                fill_value=0,
                dtype="float32",
            )
            self.image_features = None  # Built before the forward
            self.image_features_list = None

        # For logits processors
        self.logits_processors = build_logits_processors(self.fd_config)
        self.logits_processors_args = [{} for _ in range(max_num_seqs)]
        logger.info(f"Enabled logits processors: {self.logits_processors}")

        self.mask_rollback = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
        self.preempted_idx = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32", device="cpu")
        self.last_preempted_idx = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32", device="cpu")
|
def swap_states(self, i1, i2) -> None:
|
|
"""Swap the data at indices i1 and i2 for all array-like attributes"""
|
|
|
|
def swap_data(tensor, idx1, idx2):
|
|
"""Safely swap tensor slices using clone"""
|
|
temp = tensor[idx1].clone()
|
|
tensor[idx1] = tensor[idx2].clone()
|
|
tensor[idx2] = temp
|
|
|
|
self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1]
|
|
swap_data(self.token_ids_all, i1, i2)
|
|
swap_data(self.input_ids, i1, i2)
|
|
swap_data(self.top_p, i1, i2)
|
|
swap_data(self.top_k, i1, i2)
|
|
swap_data(self.min_p, i1, i2)
|
|
swap_data(self.temperature, i1, i2)
|
|
swap_data(self.penalty_score, i1, i2)
|
|
swap_data(self.frequency_score, i1, i2)
|
|
swap_data(self.presence_score, i1, i2)
|
|
swap_data(self.temp_scaled_logprobs, i1, i2)
|
|
swap_data(self.top_p_normalized_logprobs, i1, i2)
|
|
swap_data(self.min_dec_len, i1, i2)
|
|
swap_data(self.max_dec_len, i1, i2)
|
|
swap_data(self.seq_lens_this_time_buffer, i1, i2)
|
|
swap_data(self.seq_lens_this_time_cpu, i1, i2)
|
|
swap_data(self.seq_lens_encoder, i1, i2)
|
|
swap_data(self.seq_lens_decoder, i1, i2)
|
|
swap_data(self.step_seq_lens_encoder, i1, i2)
|
|
swap_data(self.step_seq_lens_decoder, i1, i2)
|
|
swap_data(self.prompt_lens, i1, i2)
|
|
swap_data(self.step_idx, i1, i2)
|
|
swap_data(self.sampled_token_ids, i1, i2)
|
|
swap_data(self.stop_flags, i1, i2)
|
|
# swap_data(self.recompute_token_num, i1, i2)
|
|
|
|
# # Swap list-based arrays (lists don't need clone)
|
|
self.top_k_list[i1], self.top_k_list[i2] = self.top_k_list[i2], self.top_k_list[i1]
|
|
self.min_p_list[i1], self.min_p_list[i2] = self.min_p_list[i2], self.min_p_list[i1]
|
|
|
|
# Swap 1D arrays
|
|
swap_data(self.bad_tokens, i1, i2)
|
|
swap_data(self.bad_tokens_len, i1, i2)
|
|
swap_data(self.next_tokens, i1, i2)
|
|
swap_data(self.is_block_step, i1, i2)
|
|
swap_data(self.is_block_step_cpu, i1, i2)
|
|
swap_data(self.is_chunk_step, i1, i2)
|
|
swap_data(self.encoder_block_lens, i1, i2)
|
|
swap_data(self.step_block_list, i1, i2)
|
|
swap_data(self.recover_block_list, i1, i2)
|
|
swap_data(self.need_block_list, i1, i2)
|
|
swap_data(self.used_list_len, i1, i2)
|
|
swap_data(self.infer_seed, i1, i2)
|
|
swap_data(self.first_token_ids, i1, i2)
|
|
swap_data(self.ori_seq_lens_encoder, i1, i2)
|
|
swap_data(self.system_lens, i1, i2)
|
|
swap_data(self.system_ids, i1, i2)
|
|
swap_data(self.enable_thinking, i1, i2)
|
|
swap_data(self.max_think_lens, i1, i2)
|
|
swap_data(self.limit_think_status, i1, i2)
|
|
|
|
# # Swap block tables
|
|
swap_data(self.block_tables, i1, i2)
|
|
|
|
# # Swap stop sequences
|
|
swap_data(self.stop_seqs_len, i1, i2)
|
|
swap_data(self.stop_seqs, i1, i2)
|
|
|
|
swap_data(self.preempted_idx, i1, i2)
|
|
swap_data(self.last_preempted_idx, i1, i2)
|
|
swap_data(self.reasoning_status, i1, i2)
|
|
|
|
# Swap speculative decoding buffers if enabled
|
|
if self.speculative_decoding:
|
|
swap_data(self.input_ids_cpu, i1, i2)
|
|
swap_data(self.accept_tokens, i1, i2)
|
|
swap_data(self.accept_num, i1, i2)
|
|
swap_data(self.draft_tokens, i1, i2)
|
|
swap_data(self.actual_draft_token_num, i1, i2)
|
|
if current_platform.is_cuda():
|
|
swap_data(self.cu_seqlens_q_output, i1, i2)
|
|
else:
|
|
swap_data(self.output_cum_offsets, i1, i2)
|
|
swap_data(self.step_draft_tokens, i1, i2)
|
|
swap_data(self.step_seq_lens_this_time, i1, i2)
|
|
swap_data(self.draft_logits, i1, i2)
|
|
swap_data(self.cu_batch_token_offset, i1, i2)
|
|
swap_data(self.seq_lens_decoder_cpu, i1, i2)
|
|
swap_data(self.prompt_lens_cpu, i1, i2)
|
|
swap_data(self.accept_tokens_cpu, i1, i2)
|
|
swap_data(self.accept_num_cpu, i1, i2)
|
|
|
|
if self.enable_mm:
|
|
if self.image_features_list is not None:
|
|
self.image_features_list[i1], self.image_features_list[i2] = (
|
|
self.image_features_list[i2],
|
|
self.image_features_list[i1],
|
|
)
|
|
swap_data(self.share_inputs["rope_emb"], i1, i2)
|
|
# Swap mask rollback
|
|
swap_data(self.mask_rollback, i1, i2)
|
|
|
|
def condense(self) -> None:
|
|
"""
|
|
Condense the input batch by keeping only the running requests and moving their data to the front.
|
|
Running requests are identified by self.running_requests_ids.
|
|
Also updates index_to_batch_id to remove mappings for non-running requests.
|
|
"""
|
|
# Get the indices of running requests from index_to_batch_id
|
|
running_indices = [
|
|
idx for idx, batch_id in self.index_to_batch_id.items() if batch_id in self.running_requests_ids
|
|
]
|
|
|
|
# Sort the indices to maintain order
|
|
running_indices.sort()
|
|
if self.num_running_requests == len(self.index_to_batch_id):
|
|
return
|
|
# Move data of running requests to the front
|
|
for new_idx, old_idx in enumerate(running_indices):
|
|
if new_idx != old_idx:
|
|
self.swap_states(new_idx, old_idx)
|
|
|
|
# Update index_to_batch_id mapping - only keep mappings for running requests
|
|
# After swap_states, the mapping has been updated, just remove non-running ones
|
|
keys_to_remove = [
|
|
key
|
|
for key in list(self.index_to_batch_id.keys())
|
|
if self.index_to_batch_id[key] not in self.running_requests_ids
|
|
]
|
|
for key in keys_to_remove:
|
|
del self.index_to_batch_id[key]
|
|
|
|
def get_index_by_batch_id(self, batch_id):
|
|
"""
|
|
Get the index corresponding to the given batch_id
|
|
|
|
Args:
|
|
batch_id: The batch_id to look up
|
|
|
|
Returns:
|
|
The index corresponding to the batch_id, or add new key if not found
|
|
"""
|
|
for index, bid in self.index_to_batch_id.items():
|
|
if bid == batch_id:
|
|
return index
|
|
if batch_id in self.index_to_batch_id:
|
|
# In PD reordering, some req_idx that are no longer used will be removed and
|
|
# the remaining requests will be re-sorted by index.
|
|
#
|
|
# If req_idx = 2 was removed in the previous step and request 12 later occupied
|
|
# slot 2 (i.e. {2: 12}), inserting a new request with req_id = 2 may overwrite
|
|
# the existing request (req_idx = 12), leading to incorrect behavior.
|
|
#
|
|
# To avoid index collision, we always assign a new slot using the current length
|
|
# as the new index, instead of reusing a previously freed req_idx.
|
|
self.index_to_batch_id[len(self.index_to_batch_id)] = batch_id
|
|
else:
|
|
self.index_to_batch_id[batch_id] = batch_id
|
|
return batch_id
|
|
|
|
    def reset_share_inputs(self):
        """
        Reset all paddle tensors to their initial state.

        This method clears the content of the shared input buffers while
        preserving their shapes and data types, mirroring the fill values
        used in ``init_share_inputs``. Failures are logged and swallowed so
        a reset problem never crashes the runner (best-effort semantics).
        """
        try:
            logger.info("Resetting share_inputs to initial state...")
            # Imported lazily; presumably avoids a circular import at module
            # load time — TODO confirm.
            from fastdeploy.utils import fill_paddle_tensor

            # Reset all paddle tensors to their initial fill values
            max_num_seqs = self.scheduler_config.max_num_seqs

            # Reset basic tensors to their default values
            fill_paddle_tensor(self, "token_ids_all", -1)
            fill_paddle_tensor(self, "input_ids", self.model_config.pad_token_id)
            fill_paddle_tensor(self, "eos_token_id", 0)
            fill_paddle_tensor(self, "top_p", self.model_config.top_p)
            fill_paddle_tensor(self, "top_k", 0)
            fill_paddle_tensor(self, "min_p", 0.0)
            fill_paddle_tensor(self, "temperature", self.model_config.temperature)
            fill_paddle_tensor(self, "penalty_score", self.model_config.penalty_score)
            fill_paddle_tensor(self, "frequency_score", self.model_config.frequency_score)
            fill_paddle_tensor(self, "presence_score", self.model_config.presence_score)
            fill_paddle_tensor(self, "temp_scaled_logprobs", False)
            fill_paddle_tensor(self, "top_p_normalized_logprobs", False)

            # Reset list variables (not paddle tensors)
            self.top_k_list = [0] * max_num_seqs
            self.min_p_list = [0.0] * max_num_seqs

            fill_paddle_tensor(self, "min_dec_len", self.model_config.min_length)
            fill_paddle_tensor(self, "max_dec_len", self.model_config.max_model_len)

            # Reset sequence length related buffers
            fill_paddle_tensor(self, "seq_lens_this_time_buffer", 0)
            fill_paddle_tensor(self, "seq_lens_this_time", 0)
            fill_paddle_tensor(self, "seq_lens_encoder", 0)
            fill_paddle_tensor(self, "seq_lens_decoder", 0)
            fill_paddle_tensor(self, "step_seq_lens_encoder", 0)
            fill_paddle_tensor(self, "step_seq_lens_decoder", 0)
            fill_paddle_tensor(self, "prompt_lens", 0)
            fill_paddle_tensor(self, "step_idx", 0)
            # fill_paddle_tensor(self, "not_need_stop", False)
            fill_paddle_tensor(self, "not_need_stop_device", False)
            fill_paddle_tensor(self, "sampled_token_ids", -1)
            fill_paddle_tensor(self, "stop_flags", True)

            fill_paddle_tensor(self, "bad_tokens", -1)
            fill_paddle_tensor(self, "bad_tokens_len", 1)
            fill_paddle_tensor(self, "next_tokens", -1)
            fill_paddle_tensor(self, "is_block_step", False)
            fill_paddle_tensor(self, "is_chunk_step", False)
            fill_paddle_tensor(self, "encoder_block_lens", 0)
            fill_paddle_tensor(self, "step_block_list", -1)
            fill_paddle_tensor(self, "step_lens", 0)
            fill_paddle_tensor(self, "recover_block_list", -1)
            fill_paddle_tensor(self, "recover_lens", 0)
            fill_paddle_tensor(self, "need_block_list", -1)
            fill_paddle_tensor(self, "need_block_len", 0)
            fill_paddle_tensor(self, "used_list_len", 0)
            fill_paddle_tensor(self, "infer_seed", 0)
            fill_paddle_tensor(self, "first_token_ids", -1)
            fill_paddle_tensor(self, "ori_seq_lens_encoder", 0)
            fill_paddle_tensor(self, "system_lens", 0)
            fill_paddle_tensor(self, "system_ids", -1)

            fill_paddle_tensor(self, "ids_remove_padding", 0)
            fill_paddle_tensor(self, "batch_id_per_token", 0)
            fill_paddle_tensor(self, "cu_seqlens_q", 0)
            fill_paddle_tensor(self, "cu_seqlens_k", 0)

            # Reset thinking related buffers
            fill_paddle_tensor(self, "enable_thinking", True)
            fill_paddle_tensor(self, "max_think_lens", -1)
            fill_paddle_tensor(self, "limit_think_status", 0)

            # Reset reasoning buffers
            fill_paddle_tensor(self, "reasoning_status", 0)
            # Reset reasoning allowed tokens (not using fill_paddle_tensor since it's a fixed tensor)
            self.reasoning_allowed_tokens = paddle.to_tensor([100973, 100975], dtype="int64")

            # Reset block tables
            fill_paddle_tensor(self, "block_tables", -1)

            # Reset free list (requires special handling); same reversed-range
            # construction as init_share_inputs.
            free_list = list(
                range(
                    self.cache_config.total_block_num - 1,
                    int(self.cache_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                    -1,
                )
            )
            self.free_list = paddle.to_tensor(free_list, dtype="int32")
            self.free_list_len = paddle.full([1], len(free_list), dtype="int32")

            # Reset stop sequences
            fill_paddle_tensor(self, "stop_seqs_len", 0)
            fill_paddle_tensor(self, "stop_seqs", -1)

            # Reset other list variables
            self.req_ids = [""] * max_num_seqs
            self.entropy_list = [[] for _ in range(max_num_seqs)]
            self.logits_processors_args = [{} for _ in range(max_num_seqs)]

            # Reset speculative decoding tensors if enabled
            if self.speculative_decoding:
                max_draft_token_num = self.speculative_config.num_speculative_tokens
                fill_paddle_tensor(self, "input_ids_cpu", -1)
                fill_paddle_tensor(self, "accept_tokens", 0)
                fill_paddle_tensor(self, "accept_num", 0)
                # NOTE(review): init_share_inputs fills draft_tokens with 0 but
                # the reset uses -1 — confirm which value is intended.
                fill_paddle_tensor(self, "draft_tokens", -1)
                fill_paddle_tensor(self, "actual_draft_token_num", max_draft_token_num)
                fill_paddle_tensor(self, "output_cum_offsets", 0)
                fill_paddle_tensor(self, "output_padding_offset", 0)
                fill_paddle_tensor(self, "step_draft_tokens", 0)
                fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
                fill_paddle_tensor(self, "draft_logits", -1)
                fill_paddle_tensor(self, "cu_batch_token_offset", 0)
                # for mtp overlap: pinned host mirrors are re-allocated rather
                # than filled in place.
                self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
                self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
                self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
                self.accept_tokens_cpu = paddle.full(
                    shape=[max_num_seqs, max_draft_token_num + 1],
                    fill_value=0,
                    dtype="int64",
                ).pin_memory()

            # Reset multimodal related tensors
            if self.enable_mm:
                head_dim = self.model_config.head_dim
                # neox style = True for qwen/paddleocr, else half head dim.
                if "qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type:
                    rope_head_dim = head_dim
                else:
                    rope_head_dim = head_dim // 2

                self.rope_emb = paddle.full(
                    shape=[
                        max_num_seqs,
                        2,
                        1,
                        self.model_config.max_model_len,
                        1,
                        rope_head_dim,
                    ],
                    fill_value=0,
                    dtype="float32",
                )
                self.image_features = None
                self.image_features_list = None
            else:
                # Reset non-multimodal rope_emb
                self.rope_emb = get_rope(
                    rotary_dim=self.model_config.head_dim,
                    position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
                    base=self.model_config.rope_theta,
                    model_config=self.model_config,
                    partial_rotary_factor=self.model_config.partial_rotary_factor,
                )

            # Reset other miscellaneous tensors
            fill_paddle_tensor(self, "mask_rollback", 0)
            fill_paddle_tensor(self, "preempted_idx", 0)

            logger.info("share_inputs reset completed")
        except Exception as e:
            # Best-effort: a failed reset is logged and skipped, never raised.
            logger.error(f"Resetting share inputs failed, skipping reset, error message is {e}")
|
|
|
|
class ProposerInputBatch(InputBatch):
|
|
    def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) -> None:
        """
        Initialize the draft-model (proposer) input batch.

        Stores config references and a handle to the target model's
        InputBatch so that ``init_share_inputs`` can clone/share its buffers.
        NOTE(review): ``super().__init__`` is deliberately not called here —
        the proposer sets only the attributes it needs; confirm that is
        intentional.

        Args:
            fd_config: Global FastDeploy configuration bundle.
            target_model_input_batch: The target model's InputBatch whose
                buffers the proposer mirrors.
        """
        self.enable_mm = fd_config.model_config.enable_mm
        self.num_model_steps = fd_config.speculative_config.num_model_steps
        # Slot-index -> batch-id mapping, independent of the target batch's.
        self.index_to_batch_id = {}
        self.target_model_input_batch = target_model_input_batch
        self.fd_config: FDConfig = fd_config
        self.scheduler_config = fd_config.scheduler_config
        self.model_config: ModelConfig = fd_config.model_config
        self.cache_config: CacheConfig = fd_config.cache_config
        self.speculative_config: SpeculativeConfig = fd_config.speculative_config
        self.enable_pd_reorder: bool = False
|
def init_share_inputs(self):
|
|
# share with targe model
|
|
self.enable_pd_reorder = getattr(self.target_model_input_batch, "enable_pd_reorder", False)
|
|
|
|
self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
|
|
self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
|
|
self.input_ids_cpu = paddle.full(
|
|
shape=[self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
|
|
fill_value=-1,
|
|
dtype="int64",
|
|
device="cpu",
|
|
)
|
|
self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"])
|
|
|
|
self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"])
|
|
self.seq_lens_decoder = paddle.clone(self.target_model_input_batch["seq_lens_decoder"])
|
|
self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"])
|
|
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
|
|
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
|
|
self.not_need_stop_device = paddle.to_tensor([False], dtype="bool")
|
|
if current_platform.is_cuda():
|
|
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
|
|
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
|
|
if "token_ids_all" in self.target_model_input_batch:
|
|
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
|
|
# TODO: delete pre_ids in mtp
|
|
self.pre_ids = paddle.full(
|
|
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
|
|
-1,
|
|
dtype="int64",
|
|
)
|
|
for bs_idx in range(self.scheduler_config.max_num_seqs):
|
|
prompt_len = self.target_model_input_batch["prompt_lens"][bs_idx]
|
|
pre_ids_len = self.model_config.max_model_len - prompt_len
|
|
self.pre_ids[bs_idx, :pre_ids_len] = self.target_model_input_batch["token_ids_all"][
|
|
bs_idx, prompt_len:
|
|
]
|
|
else:
|
|
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
|
|
self.token_ids_all = None
|
|
else:
|
|
self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"])
|
|
self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"])
|
|
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
|
|
self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
|
|
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
|
|
self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"])
|
|
self.cu_seqlens_k = paddle.clone(self.target_model_input_batch["cu_seqlens_k"])
|
|
|
|
self.target_hidden_states = paddle.full(
|
|
[
|
|
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens,
|
|
self.model_config.hidden_size,
|
|
],
|
|
0,
|
|
dtype="bfloat16",
|
|
)
|
|
|
|
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
|
|
|
|
self.rope_emb = get_rope(
|
|
rotary_dim=self.model_config.head_dim,
|
|
position_ids=tmp_position_ids,
|
|
base=self.model_config.rope_theta,
|
|
model_config=self.model_config,
|
|
partial_rotary_factor=self.model_config.partial_rotary_factor,
|
|
)
|
|
|
|
# self.caches = self.cache_kvs
|
|
# Inherit generation hyperparameters from the main model for consistency
|
|
self.prompt_lens = self.target_model_input_batch["prompt_lens"]
|
|
self.fake_prompt_lens = paddle.full([self.scheduler_config.max_num_seqs, 1], 0, dtype="int64")
|
|
self.top_p = self.target_model_input_batch["top_p"]
|
|
self.top_k = self.target_model_input_batch["top_k"]
|
|
self.temperature = self.target_model_input_batch["temperature"]
|
|
self.eos_token_id = self.target_model_input_batch["eos_token_id"]
|
|
self.penalty_score = self.target_model_input_batch["penalty_score"]
|
|
self.frequency_score = self.target_model_input_batch["frequency_score"]
|
|
self.presence_score = self.target_model_input_batch["presence_score"]
|
|
self.infer_seed = self.target_model_input_batch["infer_seed"]
|
|
|
|
self.max_dec_len = self.target_model_input_batch["max_dec_len"]
|
|
self.min_dec_len = self.target_model_input_batch["min_dec_len"]
|
|
|
|
self.bad_tokens = self.target_model_input_batch["bad_tokens"]
|
|
self.bad_tokens_len = self.target_model_input_batch["bad_tokens_len"]
|
|
|
|
# Integraad_tokens"]te the updated results in model forward
|
|
self.base_model_draft_tokens = self.target_model_input_batch["draft_tokens"]
|
|
self.substep = 0
|
|
|
|
# Declare AttentionBackend buffers
|
|
self.decoder_batch_ids = None
|
|
self.decoder_tile_ids_per_batch = None
|
|
self.decoder_num_blocks_cpu = None # Pinning Memory
|
|
self.decoder_num_blocks_device = None
|
|
self.decoder_chunk_size_device = None
|
|
self.max_len_tensor_cpu = None # CPU
|
|
self.encoder_batch_ids = None
|
|
self.encoder_tile_ids_per_batch = None
|
|
self.encoder_num_blocks_x_cpu = None # CPU
|
|
self.kv_batch_ids = None
|
|
self.kv_tile_ids_per_batch = None
|
|
self.kv_num_blocks_x_cpu = None # CPU
|
|
|
|
# Input tokens
|
|
self.draft_tokens = paddle.full(
|
|
shape=[self.scheduler_config.max_num_seqs, self.speculative_config.num_speculative_tokens + 1],
|
|
fill_value=-1,
|
|
dtype="int64",
|
|
)
|
|
|
|
self.encoder_block_lens = paddle.clone(self.target_model_input_batch["encoder_block_lens"])
|
|
self.free_list = list(
|
|
range(
|
|
self.cache_config.total_block_num - 1,
|
|
int(self.cache_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
|
|
-1,
|
|
)
|
|
)
|
|
self.free_list_len = len(self.free_list)
|
|
|
|
self.free_list = paddle.to_tensor(self.free_list, dtype="int32")
|
|
self.free_list_len = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32")
|
|
|
|
self.is_block_step = paddle.full(shape=[self.scheduler_config.max_num_seqs, 1], fill_value=False, dtype="bool")
|
|
self.batch_drop = paddle.full(shape=[self.scheduler_config.max_num_seqs, 1], fill_value=False, dtype="bool")
|
|
self.used_list_len = paddle.full(shape=[self.scheduler_config.max_num_seqs], fill_value=0, dtype="int32")
|
|
|
|
if self.num_model_steps > 1:
|
|
self.last_seq_lens_this_time = paddle.full_like(
|
|
self.target_model_input_batch["seq_lens_this_time"], fill_value=-1, dtype="int32"
|
|
)
|
|
self.input_ids_len = paddle.zeros(shape=[self.scheduler_config.max_num_seqs, 1], dtype="int64", device="cpu")
|
|
self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"]
|
|
self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"]
|
|
self.accept_num = self.target_model_input_batch["accept_num"]
|
|
self.accept_tokens = self.target_model_input_batch["accept_tokens"]
|
|
self.draft_logits = self.target_model_input_batch["draft_logits"]
|
|
self.first_token_hidden_states = paddle.full(
|
|
[self.scheduler_config.max_num_seqs, self.model_config.hidden_size], -1
|
|
)
|
|
self.batch_token_num = paddle.full(shape=[self.scheduler_config.max_num_seqs], fill_value=0, dtype="int32")
|
|
self.next_token_num = paddle.full(shape=[self.scheduler_config.max_num_seqs], fill_value=0, dtype="int32")
|
|
self.cu_batch_token_offset = paddle.full_like(
|
|
self.target_model_input_batch["cu_batch_token_offset"], fill_value=0, dtype="int32"
|
|
)
|
|
self.cu_next_token_offset = paddle.full(
|
|
shape=[self.scheduler_config.max_num_seqs + 1], fill_value=0, dtype="int32"
|
|
)
|
|
self.mask_rollback = paddle.full([self.scheduler_config.max_num_seqs, 1], 0, dtype="int32")
|
|
# NOTE(liuzichang): In speculative decoding, accepted tokens' KV cache is recomputed
|
|
# using the target model's hidden states.
|
|
self.recompute_token_num = paddle.full(
|
|
[self.scheduler_config.max_num_seqs, 1], self.num_model_steps - 1, dtype="int32"
|
|
)
|
|
# attn_mask
|
|
if self.enable_mm:
|
|
self.decode_states = paddle.full(
|
|
[self.scheduler_config.max_num_seqs, self.speculative_config.num_speculative_tokens + 1],
|
|
-1,
|
|
dtype="int32",
|
|
)
|
|
if self.fd_config.deploy_modality != DeployModality.TEXT:
|
|
self.attn_mask_offsets = paddle.full(
|
|
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
|
|
fill_value=-1,
|
|
dtype="int32",
|
|
)
|
|
self.attn_mask_offsets_full = paddle.full(
|
|
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
|
|
)
|
|
self.attn_mask_offsets_decoder = paddle.full(
|
|
[self.scheduler_config.max_num_seqs, 1], -1, dtype="int32"
|
|
)
|
|
|
|
def swap_states(self, i1, i2) -> None:
|
|
def swap_data(tensor, idx1, idx2):
|
|
"""Safely swap tensor slices using clone"""
|
|
temp = tensor[idx1].clone()
|
|
tensor[idx1] = tensor[idx2].clone()
|
|
tensor[idx2] = temp
|
|
|
|
self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1]
|
|
swap_data(self.block_tables, i1, i2)
|
|
swap_data(self.input_ids, i1, i2)
|
|
swap_data(self.input_ids_cpu, i1, i2)
|
|
swap_data(self.seq_lens_this_time_buffer, i1, i2)
|
|
swap_data(self.seq_lens_encoder, i1, i2)
|
|
swap_data(self.seq_lens_decoder, i1, i2)
|
|
swap_data(self.step_idx, i1, i2)
|
|
swap_data(self.pre_ids, i1, i2)
|
|
swap_data(self.encoder_block_lens, i1, i2)
|
|
swap_data(self.input_ids_len, i1, i2)
|
|
swap_data(self.mask_rollback, i1, i2)
|
|
swap_data(self.recompute_token_num, i1, i2)
|
|
if self.enable_mm and self.fd_config.deploy_modality != DeployModality.TEXT:
|
|
swap_data(self.attn_mask_offsets_full, i1, i2)
|
|
swap_data(self.attn_mask_offsets_decoder, i1, i2)
|
|
|
|
def reset_model_inputs(self) -> None:
|
|
"""
|
|
Reset all paddle tensors in self to their initial state.
|
|
This method clears the content of the model input buffers while preserving
|
|
their shapes and data types.
|
|
"""
|
|
try:
|
|
logger.info("Resetting model_inputs to initial state...")
|
|
from fastdeploy.utils import fill_paddle_tensor
|
|
|
|
# Reset all paddle tensors to their default values
|
|
# Clone the target model inputs to restore initial values
|
|
self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
|
|
self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
|
|
fill_paddle_tensor(self, "input_ids_cpu", -1)
|
|
# acceptance rate decline when reset seq_lens_this_time
|
|
# self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"])
|
|
|
|
self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"])
|
|
self.seq_lens_decoder = paddle.clone(self.target_model_input_batch["seq_lens_decoder"])
|
|
self.prompt_lens = self.target_model_input_batch["prompt_lens"]
|
|
self.fake_prompt_lens = paddle.full([self.scheduler_config.max_num_seqs, 1], 0, dtype="int64")
|
|
self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"])
|
|
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
|
|
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
|
|
self.index_to_batch_id = {}
|
|
if current_platform.is_cuda():
|
|
if "token_ids_all" in self.target_model_input_batch:
|
|
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
|
|
# TODO: delete pre_ids in mtp
|
|
self.pre_ids = paddle.full(
|
|
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
|
|
-1,
|
|
dtype="int64",
|
|
)
|
|
for bs_idx in range(self.scheduler_config.max_num_seqs):
|
|
prompt_len = self.target_model_input_batch["prompt_lens"][bs_idx]
|
|
pre_ids_len = self.model_config.max_model_len - prompt_len
|
|
self.pre_ids[bs_idx, :pre_ids_len] = self.target_model_input_batch["token_ids_all"][
|
|
bs_idx, prompt_len:
|
|
]
|
|
else:
|
|
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
|
|
self.token_ids_all = None
|
|
else:
|
|
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
|
|
self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
|
|
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
|
|
self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"])
|
|
self.cu_seqlens_k = paddle.clone(self.target_model_input_batch["cu_seqlens_k"])
|
|
|
|
# Reset target hidden states
|
|
fill_paddle_tensor(self, "target_hidden_states", 0)
|
|
|
|
# Reset rope embedding by recreating with default position_ids
|
|
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
|
|
self.rope_emb = get_rope(
|
|
rotary_dim=self.model_config.head_dim,
|
|
position_ids=tmp_position_ids,
|
|
base=self.model_config.rope_theta,
|
|
model_config=self.model_config,
|
|
partial_rotary_factor=self.model_config.partial_rotary_factor,
|
|
)
|
|
|
|
# Reset generation hyperparameters from the main model
|
|
self.top_p = self.target_model_input_batch["top_p"]
|
|
self.top_k = self.target_model_input_batch["top_k"]
|
|
self.temperature = self.target_model_input_batch["temperature"]
|
|
self.eos_token_id = self.target_model_input_batch["eos_token_id"]
|
|
self.penalty_score = self.target_model_input_batch["penalty_score"]
|
|
self.frequency_score = self.target_model_input_batch["frequency_score"]
|
|
self.presence_score = self.target_model_input_batch["presence_score"]
|
|
self.infer_seed = self.target_model_input_batch["infer_seed"]
|
|
|
|
self.max_dec_len = self.target_model_input_batch["max_dec_len"]
|
|
self.min_dec_len = self.target_model_input_batch["min_dec_len"]
|
|
|
|
self.bad_tokens = self.target_model_input_batch["bad_tokens"]
|
|
self.bad_tokens_len = self.target_model_input_batch["bad_tokens_len"]
|
|
|
|
# Reset speculative decoding specific tensors
|
|
self.base_model_draft_tokens = self.target_model_input_batch["draft_tokens"]
|
|
self.substep = 0
|
|
|
|
# Reset draft tokens
|
|
fill_paddle_tensor(self, "draft_tokens", -1)
|
|
|
|
# Reset encoder block lens
|
|
self.encoder_block_lens = paddle.clone(self.target_model_input_batch["encoder_block_lens"])
|
|
|
|
# Reset free list (recreate with current cache config)
|
|
free_list = list(
|
|
range(
|
|
self.cache_config.total_block_num - 1,
|
|
int(self.cache_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
|
|
-1,
|
|
)
|
|
)
|
|
self.free_list = paddle.to_tensor(free_list, dtype="int32")
|
|
self.free_list_len = paddle.full(shape=[1], fill_value=len(free_list), dtype="int32")
|
|
|
|
# Reset step and drop flags
|
|
fill_paddle_tensor(self, "is_block_step", False)
|
|
fill_paddle_tensor(self, "batch_drop", False)
|
|
fill_paddle_tensor(self, "used_list_len", 0)
|
|
|
|
# Reset last sequence lengths if applicable
|
|
if self.num_model_steps > 1:
|
|
fill_paddle_tensor(self, "last_seq_lens_this_time", -1)
|
|
|
|
# Reset input IDs length
|
|
fill_paddle_tensor(self, "input_ids_len", 0)
|
|
|
|
# Reset various scores and flags
|
|
self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"]
|
|
self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"]
|
|
self.accept_num = self.target_model_input_batch["accept_num"]
|
|
self.accept_tokens = self.target_model_input_batch["accept_tokens"]
|
|
self.draft_logits = self.target_model_input_batch["draft_logits"]
|
|
fill_paddle_tensor(self, "first_token_hidden_states", -1)
|
|
fill_paddle_tensor(self, "batch_token_num", 0)
|
|
fill_paddle_tensor(self, "next_token_num", 0)
|
|
fill_paddle_tensor(self, "cu_batch_token_offset", 0)
|
|
fill_paddle_tensor(self, "cu_next_token_offset", 0)
|
|
fill_paddle_tensor(self, "mask_rollback", 0)
|
|
fill_paddle_tensor(self, "recompute_token_num", self.num_model_steps - 1)
|
|
|
|
# Reset multimodal tensors if enabled
|
|
if self.enable_mm:
|
|
fill_paddle_tensor(self, "decode_states", -1)
|
|
if self.fd_config.deploy_modality != DeployModality.TEXT:
|
|
fill_paddle_tensor(self, "attn_mask_offsets", -1)
|
|
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
|
|
fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)
|
|
|
|
logger.info("model_inputs reset completed")
|
|
except Exception as e:
|
|
logger.error(f"Resetting model inputs failed, skipping reset, error message is {e}")
|
|
|
|
|
|
def reorder_split_prefill_and_decode_form_index_to_batch_id(input_batch: InputBatch, target_model_input_batch: dict):
    """
    Align the proposer's batch layout with the target model's slot mapping.

    For each key of ``target_model_input_batch``, per-request state in
    ``input_batch`` is swapped (via ``input_batch.swap_states``) until
    ``input_batch.index_to_batch_id[key]`` equals the target mapping's value.
    Entries absent from the target mapping are then dropped.
    """
    # Inverse view of the proposer mapping (value -> key), kept in sync as we swap.
    mtp_index_2_mtp_id = {v: k for k, v in input_batch.index_to_batch_id.items()}
    for target_model_id in target_model_input_batch:
        target_model_index = target_model_input_batch[target_model_id]
        if input_batch.index_to_batch_id[target_model_id] == target_model_index:
            continue  # Already in agreement; nothing to move.
        # The proposer slot currently holding the value the target model expects.
        mtp_id = mtp_index_2_mtp_id[target_model_index]
        v1 = input_batch.index_to_batch_id[target_model_id]
        v2 = input_batch.index_to_batch_id[mtp_id]
        input_batch.swap_states(target_model_id, mtp_id)
        # update mapping: after the swap, v1 lives at mtp_id and v2 at target_model_id.
        mtp_index_2_mtp_id[v1] = mtp_id
        mtp_index_2_mtp_id[v2] = target_model_id

    # Requests present in the proposer but no longer tracked by the target model.
    keys_to_remove = input_batch.index_to_batch_id.keys() - target_model_input_batch.keys()

    for key in keys_to_remove:
        del input_batch.index_to_batch_id[key]
        # Remove the matching entry from the inverse map. Deleting during
        # iteration is safe here because we break immediately and never
        # advance the iterator over the mutated dict.
        for k, v in mtp_index_2_mtp_id.items():
            if v == key:
                del mtp_index_2_mtp_id[k]
                break
|
|
|
|
|
|
def reorder_split_prefill_and_decode(input_batch: InputBatch):
    """
    Rearrange the running requests in place so that decode requests occupy
    the front of the batch and prefill requests the back.

    A request is treated as a decode request when its encoder length is zero.
    The partition is done with a classic two-pointer sweep over the first
    ``input_batch.num_running_requests`` slots, exchanging full per-request
    state through ``input_batch.swap_states``.

    Args:
        input_batch: Input batch data (modified in place).

    Returns:
        None: Modifies the input_batch data order in place
    """
    # Mask computed once up front; swapped slots are never revisited, so it
    # does not need to be refreshed after a swap.
    is_decode = input_batch.seq_lens_encoder == 0

    lo = 0                                       # grows the decode region from the front
    hi = input_batch.num_running_requests - 1    # grows the prefill region from the back

    while lo <= hi:
        if is_decode[lo]:
            # Front slot already holds a decode request — keep it.
            lo += 1
            continue
        if not is_decode[hi]:
            # Back slot already holds a prefill request — keep it.
            hi -= 1
            continue
        # Front holds prefill and back holds decode: exchange them.
        input_batch.swap_states(lo, hi)
        lo += 1
        hi -= 1
|
|
|
|
|
|
def _recover_tensor(recover_tensor, index_to_batch_id_list):
    """
    Restore the original batch order of a tensor's leading rows.

    Args:
        recover_tensor: paddle.Tensor whose first ``len(index_to_batch_id_list)``
            rows are in permuted order.
        index_to_batch_id_list: Gather indices mapping the original batch order
            to the tensor's current row positions.

    Returns:
        A new paddle.Tensor with the leading rows gathered back into original
        batch order; any trailing rows are copied through unchanged.
    """
    num_sorted = len(index_to_batch_id_list)

    # Pinned-memory tensors are first moved to plain CPU memory; the result
    # is then allocated on CPU as well (note: the returned tensor is not pinned).
    if isinstance(recover_tensor.place, paddle.CUDAPinnedPlace):
        recover_tensor = recover_tensor.cpu()
        restored = paddle.empty_like(recover_tensor, device="cpu")
    else:
        restored = paddle.empty_like(recover_tensor)

    # Gather the permuted prefix back into batch order.
    restored[:num_sorted] = recover_tensor[:num_sorted][index_to_batch_id_list]

    # Rows beyond the reordered prefix keep their existing contents.
    if num_sorted < restored.shape[0]:
        restored[num_sorted:] = recover_tensor[num_sorted:]
    return restored
|
|
|
|
|
|
def recover_batch_index_for_output(output_cls, index_to_batch_id, enable_pd_reorder, recover_list):
    """
    Reorder selected output fields according to the index_to_batch_id mapping.

    Args:
        output_cls: Model output object (or dict) holding the attributes/keys
            named in ``recover_list``.
        index_to_batch_id: Dict mapping indices to original batch IDs.
        enable_pd_reorder: Whether prefill/decode reordering is enabled; when
            False every field is passed through untouched.
        recover_list: Names of the attributes/keys to reorder.

    Returns:
        Dict mapping each name in ``recover_list`` to its (possibly reordered)
        value.
    """
    res_map = {}
    # Nothing to undo when the mapping is identity or reordering is disabled.
    is_not_swapped = all(i == v for i, v in index_to_batch_id.items()) or not enable_pd_reorder
    # Create a new tensor to store the reordered results
    if not is_not_swapped:
        # Keys sorted by original batch id: gather order for _recover_tensor.
        src_order = [k for k, v in sorted(index_to_batch_id.items(), key=lambda x: x[1])]
    for recover_name in recover_list:
        if isinstance(output_cls, dict):
            recover_tensor = output_cls[recover_name]
        else:
            recover_tensor = getattr(output_cls, recover_name)
        if is_not_swapped:
            res_map[recover_name] = recover_tensor
            continue

        if isinstance(recover_tensor, paddle.Tensor):
            # Create a new tensor to store the reordered results
            res_map[recover_name] = _recover_tensor(recover_tensor, src_order)
        elif isinstance(recover_tensor, list):
            real_recover_tensor = recover_tensor.copy()
            # NOTE(review): this iterates the dict's *keys* via enumerate, so
            # i2 is a key and i1 only its enumeration position — unlike the
            # tensor branch, the mapped batch ids (values) are never used.
            # Confirm whether `index_to_batch_id.items()` was intended here.
            for i1, i2 in enumerate(index_to_batch_id):
                real_recover_tensor[i1], real_recover_tensor[i2] = real_recover_tensor[i2], real_recover_tensor[i1]
            res_map[recover_name] = real_recover_tensor
        else:
            logger.info("Unsupported type of {}".format(recover_name))

    return res_map
|
|
|
|
|
|
def recover_batch_index_for_sampler_output(sampler_output, index_to_batch_id, enable_pd_reorder):
    """
    Reorder a sampler output back into original batch order, in place.

    Args:
        sampler_output: Sampler output object containing sampled_token_ids and,
            optionally, logprobs_tensors / token_num_per_batch /
            cu_batch_token_offset / logits.
        index_to_batch_id: Dict mapping indices to original batch IDs.
        enable_pd_reorder: Whether prefill/decode reordering is enabled.

    Returns:
        None. ``sampler_output`` is mutated in place.
    """
    # Fast path: reordering disabled or the mapping is already identity.
    if not enable_pd_reorder or all(i == v for i, v in index_to_batch_id.items()):
        return

    # Keys sorted by original batch id: gather order for _recover_tensor.
    src_order = [k for k, v in sorted(index_to_batch_id.items(), key=lambda x: x[1])]

    sampler_output.sampled_token_ids = _recover_tensor(sampler_output.sampled_token_ids, src_order)

    if sampler_output.logprobs_tensors is not None:
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_tensors.logprob_token_ids = _recover_tensor(logprobs_tensors.logprob_token_ids, src_order)
        logprobs_tensors.logprobs = _recover_tensor(logprobs_tensors.logprobs, src_order)
        # BUG FIX: this previously assigned to `sampled_token_ranks`, creating a
        # new attribute and leaving the field actually read above
        # (`selected_token_ranks`) un-reordered. Write back to the same field.
        logprobs_tensors.selected_token_ranks = _recover_tensor(logprobs_tensors.selected_token_ranks, src_order)

    if sampler_output.token_num_per_batch is not None:
        sampler_output.token_num_per_batch = _recover_tensor(sampler_output.token_num_per_batch, src_order)

    if sampler_output.cu_batch_token_offset is not None:
        sampler_output.cu_batch_token_offset = _recover_tensor(sampler_output.cu_batch_token_offset, src_order)

    if sampler_output.logits is not None:
        sampler_output.logits = _recover_tensor(sampler_output.logits, src_order)