mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger
|
||||
from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig
|
||||
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
|
||||
from fastdeploy.model_executor.logits_processor import build_logits_processors
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
class InputBatch:
|
||||
@@ -134,23 +135,29 @@ class InputBatch:
|
||||
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
if self.enable_expert_parallel:
|
||||
self.seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
|
||||
self.seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.step_seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.prompt_lens = paddle.full([max_num_seqs, 1], 0, dtype="int64")
|
||||
self.step_idx = paddle.full([max_num_seqs, 1], 0, dtype="int64")
|
||||
self.not_need_stop = paddle.full([1], False, dtype="bool").pin_memory()
|
||||
if current_platform.is_maca():
|
||||
self.not_need_stop = paddle.full([1], False, dtype="bool").cpu()
|
||||
self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").cpu()
|
||||
self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
|
||||
self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").cpu()
|
||||
else:
|
||||
self.not_need_stop = paddle.full([1], False, dtype="bool").pin_memory()
|
||||
self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory()
|
||||
self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
|
||||
self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory()
|
||||
self.not_need_stop_device = paddle.full([1], False, dtype="bool")
|
||||
self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory()
|
||||
self.stop_flags = paddle.full([max_num_seqs, 1], True, dtype="bool")
|
||||
|
||||
self.bad_tokens = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
|
||||
self.bad_tokens_len = paddle.full([max_num_seqs], 1, dtype="int64")
|
||||
self.next_tokens = paddle.full([max_num_seqs, 1], -1, dtype="int64")
|
||||
self.is_block_step = paddle.full([max_num_seqs], False, dtype="bool")
|
||||
self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory()
|
||||
self.is_chunk_step = paddle.full([max_num_seqs], False, dtype="bool").cpu()
|
||||
self.encoder_block_lens = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.step_block_list = paddle.full([max_num_seqs], -1, dtype="int32")
|
||||
|
||||
@@ -55,10 +55,9 @@ from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
|
||||
RoutingReplayManager,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
|
||||
from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
|
||||
from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d
|
||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
|
||||
from fastdeploy.model_executor.logits_processor import build_logits_processors
|
||||
from fastdeploy.model_executor.model_loader import get_model_loader
|
||||
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
|
||||
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
|
||||
@@ -81,6 +80,7 @@ from fastdeploy.model_executor.pre_and_post_process import (
|
||||
)
|
||||
from fastdeploy.output.pooler import PoolerOutput
|
||||
from fastdeploy.spec_decode import MTPProposer, NgramProposer
|
||||
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
|
||||
from fastdeploy.worker.model_runner_base import (
|
||||
DistributedOut,
|
||||
DistributedStatus,
|
||||
@@ -99,6 +99,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
local_rank: int,
|
||||
):
|
||||
super().__init__(fd_config=fd_config, device=device)
|
||||
self.MAX_INFER_SEED = 9223372036854775806
|
||||
self.enable_mm = self.model_config.enable_mm
|
||||
self.rank = rank
|
||||
self.local_rank = local_rank
|
||||
@@ -173,12 +174,12 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
|
||||
self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill
|
||||
|
||||
# Initialize share inputs
|
||||
self._init_share_inputs(self.scheduler_config.max_num_seqs)
|
||||
# Initialize input batch
|
||||
self.share_inputs = InputBatch(self.fd_config)
|
||||
self.share_inputs.init_share_inputs()
|
||||
increment_value = (
|
||||
4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4
|
||||
)
|
||||
|
||||
self.infer_seed_increment = paddle.full(
|
||||
shape=[self.scheduler_config.max_num_seqs, 1],
|
||||
fill_value=increment_value,
|
||||
@@ -354,7 +355,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
if self.speculative_method == "ngram":
|
||||
self.proposer = NgramProposer(self.fd_config)
|
||||
elif self.speculative_method == "mtp":
|
||||
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
|
||||
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
|
||||
self.proposer = MTPProposer(
|
||||
self.fd_config,
|
||||
self.get_model(),
|
||||
@@ -398,8 +399,9 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
"""
|
||||
if not self.enable_mm:
|
||||
return
|
||||
|
||||
self.share_inputs["image_features"] = None
|
||||
self.share_inputs["image_features_list"] = [-1] * self.scheduler_config.max_num_seqs
|
||||
img_index = 0
|
||||
req_idx_img_index_map = {}
|
||||
multi_vision_inputs = {
|
||||
"images_lst": [],
|
||||
"grid_thw_lst": [],
|
||||
@@ -407,6 +409,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
"cu_seqlens": [0],
|
||||
"encoder_cache_info": [],
|
||||
"feature_position_list": [],
|
||||
"grid_thw_lst_batches": [],
|
||||
"feature_position_list_batches": [],
|
||||
}
|
||||
rope_3d_position_ids = {
|
||||
"position_ids_idx": [],
|
||||
@@ -426,7 +430,9 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self.encoder_cache.pop(mm_hash, None)
|
||||
|
||||
position_ids = request.multimodal_inputs["position_ids"]
|
||||
rope_3d_position_ids["position_ids_idx"].append(request.idx)
|
||||
idx = self.share_inputs.get_index_by_batch_id(request.idx)
|
||||
rope_3d_position_ids["position_ids_idx"].append(idx)
|
||||
req_idx_img_index_map[idx] = -1
|
||||
rope_3d_position_ids["position_ids_lst"].append(position_ids)
|
||||
rope_3d_position_ids["position_ids_offset"].append(
|
||||
position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
|
||||
@@ -438,6 +444,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
|
||||
|
||||
if request.with_image:
|
||||
req_idx_img_index_map[idx] = img_index
|
||||
img_index = img_index + 1
|
||||
inputs = request.multimodal_inputs
|
||||
if self.encoder_cache is not None:
|
||||
if envs.FD_ENABLE_MAX_PREFILL:
|
||||
@@ -460,21 +468,24 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
f"request {request.request_id} start process encoder info, image_start_idx: {image_start_idx} "
|
||||
f"grid_thw_list: {grid_thw_list}, feature_positions: {feature_positions}, mm_hashes_list: {mm_hashes_list}"
|
||||
)
|
||||
encoder_cache_info_per_req = []
|
||||
grid_thw_lst_per_req = []
|
||||
for i, mm_hash in enumerate(mm_hashes_list):
|
||||
image_offset = np.prod(grid_thw_list[i])
|
||||
logger.debug(
|
||||
f"run idx {i} with mm_hash {mm_hash} image_offset: {image_offset} grid_thw: {grid_thw_list[i]}"
|
||||
)
|
||||
if mm_hash in self.encoder_cache:
|
||||
multi_vision_inputs["encoder_cache_info"].append((mm_hash, feature_positions[i], True))
|
||||
encoder_cache_info_per_req.append((mm_hash, feature_positions[i], True))
|
||||
continue
|
||||
|
||||
multi_vision_inputs["encoder_cache_info"].append((mm_hash, feature_positions[i], False))
|
||||
encoder_cache_info_per_req.append((mm_hash, feature_positions[i], False))
|
||||
if envs.FD_ENABLE_MAX_PREFILL:
|
||||
multi_vision_inputs["images_lst"].append(
|
||||
inputs["images"][image_start_idx : image_start_idx + image_offset].to(self.device)
|
||||
)
|
||||
multi_vision_inputs["grid_thw_lst"].append(paddle.to_tensor(grid_thw_list[i]))
|
||||
grid_thw_lst_per_req.append(paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64))
|
||||
multi_vision_inputs["cu_seqlens"].append(vit_seqlen_list[i])
|
||||
multi_vision_inputs["vit_position_ids_lst"].append(vit_position_ids_list[i])
|
||||
else:
|
||||
@@ -487,7 +498,10 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
multi_vision_inputs["grid_thw_lst"].append(
|
||||
paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64)
|
||||
)
|
||||
grid_thw_lst_per_req.append(paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64))
|
||||
image_start_idx += image_offset
|
||||
multi_vision_inputs["grid_thw_lst_batches"].append(grid_thw_lst_per_req)
|
||||
multi_vision_inputs["encoder_cache_info"].append(encoder_cache_info_per_req)
|
||||
else:
|
||||
if envs.FD_ENABLE_MAX_PREFILL:
|
||||
multi_vision_inputs["images_lst"].append(
|
||||
@@ -496,6 +510,9 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
multi_vision_inputs["grid_thw_lst"].extend(
|
||||
paddle.to_tensor(inputs["grid_thw"][request.num_image_start : request.num_image_end])
|
||||
)
|
||||
multi_vision_inputs["grid_thw_lst_batches"].append(
|
||||
paddle.to_tensor(inputs["grid_thw"][request.num_image_start : request.num_image_end])
|
||||
)
|
||||
multi_vision_inputs["cu_seqlens"].extend(
|
||||
inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
|
||||
)
|
||||
@@ -515,7 +532,12 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
dtype=paddle.int64,
|
||||
)
|
||||
)
|
||||
|
||||
multi_vision_inputs["grid_thw_lst_batches"].append(
|
||||
paddle.to_tensor(
|
||||
inputs["grid_thw"][request.num_image_start : request.num_image_end],
|
||||
dtype=paddle.int64,
|
||||
)
|
||||
)
|
||||
multi_vision_inputs["feature_position_list"].extend(
|
||||
self._get_feature_positions(
|
||||
mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end],
|
||||
@@ -523,7 +545,13 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
prefill_end_index=request.prefill_end_index,
|
||||
)
|
||||
)
|
||||
|
||||
multi_vision_inputs["feature_position_list_batches"].append(
|
||||
self._get_feature_positions(
|
||||
mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end],
|
||||
prefill_start_index=request.prefill_start_index,
|
||||
prefill_end_index=request.prefill_end_index,
|
||||
)
|
||||
)
|
||||
if self.encoder_cache is not None:
|
||||
if len(multi_vision_inputs["images_lst"]) > 0 or len(multi_vision_inputs["encoder_cache_info"]) > 0:
|
||||
image_features_output = None
|
||||
@@ -531,50 +559,56 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
image_features_output = self.extract_vision_features(multi_vision_inputs)
|
||||
|
||||
logger.debug(f"encoder_cache_info: {multi_vision_inputs['encoder_cache_info']}")
|
||||
merge_image_features, feature_idx, thw_idx = [], 0, 0
|
||||
for mm_hash, feature_position, use_cache in multi_vision_inputs["encoder_cache_info"]:
|
||||
if use_cache:
|
||||
assert mm_hash in self.encoder_cache, f"{mm_hash} not in encoder cache"
|
||||
mm_feature = self.encoder_cache[mm_hash].to(self.device)
|
||||
else:
|
||||
assert (
|
||||
image_features_output is not None
|
||||
), f"image_features_output is None, images_lst length: {len(multi_vision_inputs['images_lst'])}"
|
||||
grid_thw = multi_vision_inputs["grid_thw_lst"][thw_idx]
|
||||
mm_token_lenght = inputs["mm_num_token_func"](grid_thw=grid_thw)
|
||||
mm_feature = image_features_output[feature_idx : feature_idx + mm_token_lenght]
|
||||
feature_idx = 0
|
||||
image_features_list = []
|
||||
for index, encoder_cache_info in enumerate(multi_vision_inputs["encoder_cache_info"]):
|
||||
merge_image_features, thw_idx = [], 0
|
||||
for mm_hash, feature_position, use_cache in encoder_cache_info:
|
||||
if use_cache:
|
||||
assert mm_hash in self.encoder_cache, f"{mm_hash} not in encoder cache"
|
||||
mm_feature = self.encoder_cache[mm_hash].to(self.device)
|
||||
else:
|
||||
assert (
|
||||
image_features_output is not None
|
||||
), f"image_features_output is None, images_lst length: {len(multi_vision_inputs['images_lst'])}"
|
||||
grid_thw = multi_vision_inputs["grid_thw_lst_batches"][index][thw_idx]
|
||||
mm_token_lenght = inputs["mm_num_token_func"](grid_thw=grid_thw)
|
||||
mm_feature = image_features_output[feature_idx : feature_idx + mm_token_lenght]
|
||||
|
||||
# add feature to encoder cache
|
||||
self.encoder_cache[mm_hash] = mm_feature.detach().cpu()
|
||||
feature_idx += mm_token_lenght
|
||||
thw_idx += 1
|
||||
# add feature to encoder cache
|
||||
self.encoder_cache[mm_hash] = mm_feature.detach().cpu()
|
||||
feature_idx += mm_token_lenght
|
||||
thw_idx += 1
|
||||
|
||||
feature_start = feature_position.offset
|
||||
feature_end = feature_position.offset + feature_position.length
|
||||
merge_image_features.append(mm_feature[feature_start:feature_end])
|
||||
image_features_list.append(paddle.concat(merge_image_features, axis=0))
|
||||
for idx, index in req_idx_img_index_map.items():
|
||||
if index != -1:
|
||||
self.share_inputs["image_features_list"][idx] = image_features_list[index]
|
||||
elif len(multi_vision_inputs["images_lst"]) > 0:
|
||||
image_features_output = self.extract_vision_features(multi_vision_inputs)
|
||||
image_features_list = []
|
||||
feature_idx = 0
|
||||
for index, feature_position_item in enumerate(multi_vision_inputs["feature_position_list_batches"]):
|
||||
grid_thw_lst = multi_vision_inputs["grid_thw_lst_batches"][index]
|
||||
assert len(feature_position_item) == len(grid_thw_lst), f"{feature_position_item} != {grid_thw_lst}"
|
||||
merge_image_features, thw_idx = [], 0
|
||||
for feature_position in feature_position_item:
|
||||
grid_thw = grid_thw_lst[thw_idx]
|
||||
mm_token_lenght = inputs["mm_num_token_func"](grid_thw=grid_thw)
|
||||
mm_feature = image_features_output[feature_idx : feature_idx + mm_token_lenght]
|
||||
|
||||
feature_start = feature_position.offset
|
||||
feature_end = feature_position.offset + feature_position.length
|
||||
merge_image_features.append(mm_feature[feature_start:feature_end])
|
||||
|
||||
self.share_inputs["image_features"] = paddle.concat(merge_image_features, axis=0)
|
||||
logger.debug(
|
||||
f"merge_image_features length: {len(merge_image_features)}, features shape: {self.share_inputs['image_features'].shape}"
|
||||
)
|
||||
elif len(multi_vision_inputs["images_lst"]) > 0:
|
||||
assert len(multi_vision_inputs["feature_position_list"]) == len(
|
||||
multi_vision_inputs["grid_thw_lst"]
|
||||
), f"{multi_vision_inputs['feature_position_list']} != {multi_vision_inputs['grid_thw_lst']}"
|
||||
|
||||
merge_image_features, feature_idx, thw_idx = [], 0, 0
|
||||
image_features_output = self.extract_vision_features(multi_vision_inputs)
|
||||
for feature_position in multi_vision_inputs["feature_position_list"]:
|
||||
grid_thw = multi_vision_inputs["grid_thw_lst"][thw_idx]
|
||||
mm_token_lenght = inputs["mm_num_token_func"](grid_thw=grid_thw)
|
||||
mm_feature = image_features_output[feature_idx : feature_idx + mm_token_lenght]
|
||||
|
||||
feature_start = feature_position.offset
|
||||
feature_end = feature_position.offset + feature_position.length
|
||||
merge_image_features.append(mm_feature[feature_start:feature_end])
|
||||
feature_idx += mm_token_lenght
|
||||
thw_idx += 1
|
||||
self.share_inputs["image_features"] = paddle.concat(merge_image_features, axis=0)
|
||||
feature_idx += mm_token_lenght
|
||||
thw_idx += 1
|
||||
image_features_list.append(paddle.concat(merge_image_features, axis=0))
|
||||
for idx, index in req_idx_img_index_map.items():
|
||||
if index != -1:
|
||||
self.share_inputs["image_features_list"][idx] = image_features_list[index]
|
||||
|
||||
if len(rope_3d_position_ids["position_ids_idx"]) > 0:
|
||||
packed_position_ids = paddle.to_tensor(
|
||||
@@ -644,10 +678,12 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
has_decode_task = False
|
||||
|
||||
batch_pooling_params = []
|
||||
self.share_inputs["num_running_requests"] = num_running_requests
|
||||
self.share_inputs["running_requests_ids"] = range(num_running_requests)
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
# assert isinstance(request, Request)
|
||||
idx = request.idx
|
||||
idx = self.share_inputs.get_index_by_batch_id(request.idx)
|
||||
self.share_inputs["req_ids"][idx] = str(request.request_id)
|
||||
|
||||
if hasattr(request, "pooling_params") and request.pooling_params is not None:
|
||||
batch_pooling_params.append(request.pooling_params)
|
||||
@@ -715,7 +751,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
)
|
||||
self.share_inputs["stop_flags"][idx : idx + 1] = False
|
||||
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
|
||||
self.seq_lens_this_time_buffer[idx : idx + 1] = length
|
||||
self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
|
||||
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
|
||||
self.exist_prefill_flag = True
|
||||
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
|
||||
@@ -759,7 +795,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1
|
||||
self.share_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["stop_flags"][idx : idx + 1] = True
|
||||
self.seq_lens_this_time_buffer[idx : idx + 1] = 0
|
||||
self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0
|
||||
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
|
||||
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
|
||||
self.exist_prefill_flag = False
|
||||
@@ -833,7 +869,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self._process_mm_features(req_dicts)
|
||||
if has_prefill_task or has_decode_task:
|
||||
set_stop(self.share_inputs["not_need_stop"], True)
|
||||
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
|
||||
|
||||
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests]
|
||||
if self.speculative_method in ["mtp"]:
|
||||
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
|
||||
|
||||
@@ -852,7 +889,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
req_len = len(req_dicts)
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
idx = request.idx
|
||||
idx = self.share_inputs.get_index_by_batch_id(request.idx)
|
||||
length = len(request.prompt_token_ids)
|
||||
assert length > 0, "The prompt requested must not be empty."
|
||||
|
||||
@@ -875,7 +912,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
|
||||
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
|
||||
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
|
||||
self.seq_lens_this_time_buffer[idx : idx + 1] = 1
|
||||
self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 1
|
||||
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
|
||||
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
|
||||
self.share_inputs["prompt_lens"][idx : idx + 1] = length
|
||||
@@ -887,7 +924,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
request.draft_token_ids[0:num_prefill_send_token],
|
||||
dtype="int64",
|
||||
)
|
||||
self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
|
||||
self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = num_prefill_send_token
|
||||
if self.enable_mm:
|
||||
# Fix for V0 mode: Add position encoding for decode nodes in multimodal models
|
||||
# to prevent garbled output. Position_ids are transmitted from prefill nodes.
|
||||
@@ -936,7 +973,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
)
|
||||
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
|
||||
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
|
||||
self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
|
||||
self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = token_chunk_size
|
||||
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
|
||||
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
|
||||
self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size
|
||||
@@ -954,7 +991,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
else:
|
||||
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
|
||||
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
|
||||
self.seq_lens_this_time_buffer[idx : idx + 1] = length
|
||||
self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
|
||||
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
|
||||
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
|
||||
self.share_inputs["prompt_lens"][idx : idx + 1] = length
|
||||
@@ -1053,7 +1090,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
|
||||
set_stop(self.share_inputs["not_need_stop"], True)
|
||||
|
||||
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
|
||||
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests]
|
||||
|
||||
if self.speculative_method in ["mtp"]:
|
||||
self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)
|
||||
@@ -1157,7 +1194,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["eos_token_id"][:] = np.array(
|
||||
[2] * self.model_config.eos_tokens_lens, dtype="int64"
|
||||
).reshape(-1, 1)
|
||||
self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
|
||||
self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = input_length
|
||||
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
|
||||
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
|
||||
self.exist_prefill_flag = True
|
||||
@@ -1174,250 +1211,15 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
|
||||
idx * block_num, (idx + 1) * block_num, 1
|
||||
)
|
||||
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
|
||||
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
|
||||
|
||||
def _init_share_inputs(self, max_num_seqs: int):
|
||||
"""
|
||||
Initialize all share buffers for model inputs.
|
||||
"""
|
||||
self.MAX_INFER_SEED = 9223372036854775806
|
||||
self.share_inputs = {}
|
||||
|
||||
self.share_inputs["pre_ids"] = paddle.full(
|
||||
[max_num_seqs, self.model_config.max_model_len],
|
||||
-1,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["input_ids"] = paddle.full(
|
||||
[max_num_seqs, self.model_config.max_model_len],
|
||||
self.model_config.pad_token_id,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["prompt_ids"] = paddle.full(
|
||||
[max_num_seqs, self.model_config.max_model_len],
|
||||
self.model_config.pad_token_id,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
|
||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
|
||||
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
|
||||
self.share_inputs["top_k_list"] = [0] * max_num_seqs
|
||||
self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
|
||||
self.share_inputs["min_p_list"] = [0.0] * max_num_seqs
|
||||
self.share_inputs["temperature"] = paddle.full(
|
||||
[max_num_seqs, 1], self.model_config.temperature, dtype="float32"
|
||||
)
|
||||
self.share_inputs["penalty_score"] = paddle.full(
|
||||
[max_num_seqs, 1], self.model_config.penalty_score, dtype="float32"
|
||||
)
|
||||
self.share_inputs["frequency_score"] = paddle.full(
|
||||
[max_num_seqs, 1],
|
||||
self.model_config.frequency_score,
|
||||
dtype="float32",
|
||||
)
|
||||
self.share_inputs["presence_score"] = paddle.full(
|
||||
[max_num_seqs, 1], self.model_config.presence_score, dtype="float32"
|
||||
)
|
||||
self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
|
||||
self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
|
||||
|
||||
self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
|
||||
self.share_inputs["max_dec_len"] = paddle.full(
|
||||
[max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
|
||||
)
|
||||
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
if self.fd_config.parallel_config.enable_expert_parallel:
|
||||
self.share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
|
||||
self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
|
||||
self.share_inputs["not_need_stop"] = paddle.full([1], False, dtype="bool").cpu()
|
||||
self.share_inputs["not_need_stop_device"] = paddle.full([1], False, dtype="bool")
|
||||
self.share_inputs["sampled_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64").cpu()
|
||||
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
|
||||
|
||||
self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
|
||||
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
|
||||
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
|
||||
self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
|
||||
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
|
||||
self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")
|
||||
self.share_inputs["recover_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
|
||||
self.share_inputs["recover_lens"] = paddle.full([1], 0, dtype="int32")
|
||||
self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
|
||||
self.share_inputs["need_block_len"] = paddle.full([1], 0, dtype="int32")
|
||||
self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu()
|
||||
self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
|
||||
self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32")
|
||||
|
||||
self.share_inputs["ids_remove_padding"] = paddle.full(
|
||||
[max_num_seqs * self.model_config.max_model_len],
|
||||
0,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["batch_id_per_token"] = paddle.full(
|
||||
[max_num_seqs * self.model_config.max_model_len, 1], 0, dtype="int32"
|
||||
)
|
||||
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32")
|
||||
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32")
|
||||
|
||||
# Declare AttentionBackend buffers
|
||||
self.share_inputs["decoder_batch_ids"] = None
|
||||
self.share_inputs["decoder_tile_ids_per_batch"] = None
|
||||
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
|
||||
self.share_inputs["decoder_num_blocks_device"] = None
|
||||
self.share_inputs["decoder_chunk_size_device"] = None
|
||||
self.share_inputs["max_len_tensor_cpu"] = None # CPU
|
||||
self.share_inputs["encoder_batch_ids"] = None
|
||||
self.share_inputs["encoder_tile_ids_per_batch"] = None
|
||||
self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU
|
||||
self.share_inputs["kv_batch_ids"] = None
|
||||
self.share_inputs["kv_tile_ids_per_batch"] = None
|
||||
self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU
|
||||
|
||||
# Initialize thinking related buffers
|
||||
self.share_inputs["enable_thinking"] = paddle.full(shape=[max_num_seqs, 1], fill_value=True, dtype="bool")
|
||||
self.share_inputs["max_think_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32")
|
||||
self.share_inputs["limit_think_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||
|
||||
# NOTE(liuzichang): token after \n</think>\n\n must be <tool_call> 100973 or <response> 100975
|
||||
# It is a hard code to cover up model's performance
|
||||
# Detailed notes can be found in FastDeploy/custom_ops/gpu_ops/reasoning_phase_token_constraint.cu
|
||||
self.share_inputs["reasoning_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||
self.share_inputs["reasoning_allowed_tokens"] = paddle.to_tensor([100973, 100975], dtype="int64")
|
||||
|
||||
# Initialize rotary position embedding
|
||||
if not self.enable_mm:
|
||||
self.share_inputs["rope_emb"] = get_rope(
|
||||
rotary_dim=self.model_config.head_dim,
|
||||
position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
|
||||
base=self.model_config.rope_theta,
|
||||
model_config=self.model_config,
|
||||
partial_rotary_factor=self.model_config.partial_rotary_factor,
|
||||
)
|
||||
|
||||
# Set block tables
|
||||
pre_max_block_num = (
|
||||
self.model_config.max_model_len + self.cache_config.block_size - 1
|
||||
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
|
||||
self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")
|
||||
|
||||
# Initialize free list
|
||||
free_list = list(
|
||||
range(
|
||||
self.cache_config.total_block_num - 1,
|
||||
int(self.cache_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
|
||||
-1,
|
||||
)
|
||||
)
|
||||
self.free_list_len = len(free_list)
|
||||
self.share_inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32")
|
||||
self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32")
|
||||
|
||||
# Initialize stop seqs
|
||||
self.share_inputs["stop_seqs_len"] = paddle.full(
|
||||
[max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32"
|
||||
)
|
||||
self.share_inputs["stop_seqs"] = paddle.full(
|
||||
[
|
||||
max_num_seqs,
|
||||
self.model_config.max_stop_seqs_num,
|
||||
self.model_config.stop_seqs_max_len,
|
||||
],
|
||||
-1,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["req_ids"] = [""] * max_num_seqs
|
||||
self.share_inputs["entropy_list"] = [[] for _ in range(max_num_seqs)]
|
||||
|
||||
if self.speculative_decoding:
|
||||
max_draft_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.share_inputs["input_ids_cpu"] = paddle.full(
|
||||
shape=[max_num_seqs, self.model_config.max_model_len],
|
||||
fill_value=1,
|
||||
dtype="int64",
|
||||
).cpu()
|
||||
self.share_inputs["accept_tokens"] = paddle.full(
|
||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||
fill_value=0,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["accept_num"] = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
|
||||
self.share_inputs["draft_tokens"] = paddle.full(
|
||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||
fill_value=0,
|
||||
dtype="int64",
|
||||
)
|
||||
|
||||
self.share_inputs["actual_draft_token_num"] = paddle.full(
|
||||
shape=[max_num_seqs],
|
||||
fill_value=max_draft_token_num,
|
||||
dtype="int32",
|
||||
)
|
||||
self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||
self.share_inputs["output_padding_offset"] = paddle.full(
|
||||
shape=[max_num_seqs * (max_draft_token_num + 1)],
|
||||
fill_value=0,
|
||||
dtype="int32",
|
||||
)
|
||||
# For V1_KVCACHE_SCHEDULER
|
||||
self.share_inputs["step_draft_tokens"] = paddle.full(
|
||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||
fill_value=0,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
# For MTP Logprob
|
||||
self.share_inputs["draft_logits"] = paddle.full(
|
||||
[max_num_seqs * (self.speculative_config.num_speculative_tokens + 1), self.model_config.vocab_size],
|
||||
-1,
|
||||
dtype="float32",
|
||||
)
|
||||
self.share_inputs["cu_batch_token_offset"] = paddle.full(
|
||||
shape=[max_num_seqs + 1], fill_value=0, dtype="int32"
|
||||
)
|
||||
|
||||
if self.enable_mm:
|
||||
head_dim = self.model_config.head_dim
|
||||
if (
|
||||
"qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type
|
||||
): # neox style = True
|
||||
rope_head_dim = head_dim
|
||||
else: # neox style = False
|
||||
rope_head_dim = head_dim // 2
|
||||
|
||||
self.share_inputs["rope_emb"] = paddle.full(
|
||||
shape=[
|
||||
max_num_seqs,
|
||||
2,
|
||||
1,
|
||||
self.model_config.max_model_len,
|
||||
1,
|
||||
rope_head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype="float32",
|
||||
)
|
||||
self.share_inputs["image_features"] = None
|
||||
|
||||
# For logits processors
|
||||
self.share_inputs["logits_processors"] = build_logits_processors(self.fd_config)
|
||||
self.share_inputs["logits_processors_args"] = [{} for _ in range(max_num_seqs)]
|
||||
logger.info(f"Enabled logits processors: {self.share_inputs['logits_processors']}")
|
||||
|
||||
self.share_inputs["mask_rollback"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||
self.share_inputs["preempted_idx"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32").cpu()
|
||||
|
||||
def _prepare_inputs(self, is_dummy_or_profile_run=False) -> None:
|
||||
def _prepare_inputs(self, is_dummy_or_profile_run: bool = False) -> None:
|
||||
"""Prepare the model inputs"""
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if self.enable_mm and self.share_inputs["image_features_list"] is not None:
|
||||
tensor_feats = [t for t in self.share_inputs["image_features_list"] if isinstance(t, paddle.Tensor)]
|
||||
if tensor_feats:
|
||||
self.share_inputs["image_features"] = paddle.concat(tensor_feats, axis=0)
|
||||
recover_decode_task(
|
||||
self.share_inputs["stop_flags"],
|
||||
self.share_inputs["seq_lens_this_time"],
|
||||
@@ -1473,7 +1275,6 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["seq_lens_encoder"],
|
||||
self.share_inputs["seq_lens_decoder"],
|
||||
)
|
||||
|
||||
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
|
||||
# NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions
|
||||
self.share_inputs["batch_id_per_token"][:] = -1
|
||||
@@ -1520,6 +1321,22 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
share_inputs=self.share_inputs,
|
||||
)
|
||||
|
||||
def _process_reorder(self) -> None:
|
||||
if self.attn_backends and getattr(self.attn_backends[0], "enable_ids_reorder", False):
|
||||
if (
|
||||
self.enable_mm
|
||||
and not envs.ENABLE_V1_KVCACHE_SCHEDULER
|
||||
and self.share_inputs["image_features_list"] is not None
|
||||
):
|
||||
logger.info("Multimodal models skip reordering if v1 scheduling is not enabled.")
|
||||
else:
|
||||
self.share_inputs.enable_pd_reorder = True
|
||||
self.share_inputs.condense()
|
||||
reorder_split_prefill_and_decode(input_batch=self.share_inputs)
|
||||
if self.speculative_decoding:
|
||||
if self.speculative_method == "mtp":
|
||||
self.proposer.reorder_inputs()
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""load or download model"""
|
||||
logger.info(f"Starting to load model {self.model_config.architectures[0]}")
|
||||
@@ -1873,6 +1690,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
stop_seqs_len=self.share_inputs["stop_seqs_len"],
|
||||
min_tokens=self.share_inputs["min_dec_len"],
|
||||
prompt_lens=self.share_inputs["prompt_lens"],
|
||||
index_to_batch_id=self.share_inputs["index_to_batch_id"],
|
||||
enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False),
|
||||
)
|
||||
|
||||
post_process(
|
||||
@@ -1976,6 +1795,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
min_tokens=self.share_inputs["min_dec_len"],
|
||||
prompt_lens=self.share_inputs["prompt_lens"],
|
||||
mask_rollback=self.share_inputs["mask_rollback"],
|
||||
index_to_batch_id=self.share_inputs["index_to_batch_id"],
|
||||
enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False),
|
||||
)
|
||||
|
||||
post_process(
|
||||
@@ -2045,7 +1866,6 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
while True:
|
||||
# 1. Initialize forward meta and attention meta data
|
||||
self._prepare_inputs(is_dummy_or_profile_run=True)
|
||||
|
||||
# 2. Padding inputs for cuda graph
|
||||
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
|
||||
self.padding_cudagraph_inputs()
|
||||
@@ -2120,7 +1940,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self.restore_chunked_prefill_request[task.request_id] = task
|
||||
|
||||
for id, task in list(self.restore_chunked_prefill_request.items()):
|
||||
idx = task.idx
|
||||
idx = self.share_inputs.get_index_by_batch_id(task.idx)
|
||||
logger.debug(f"{task.request_id} chunked prefill {task.chunk_idx}/{len(task.prefill_chunk_info)}")
|
||||
if not self.enable_mm:
|
||||
start_idx = sum(task.prefill_chunk_info[: task.chunk_idx])
|
||||
@@ -2316,8 +2136,9 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
|
||||
prefill_done_idxs = []
|
||||
for idx in range(0, num_running_requests):
|
||||
if self.share_inputs["step_idx"][idx] == 0:
|
||||
prefill_done_idxs.append(idx)
|
||||
batch_id = self.share_inputs.get_index_by_batch_id(idx)
|
||||
if self.share_inputs["step_idx"][batch_id] == 0:
|
||||
prefill_done_idxs.append(batch_id)
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
if model_forward_batch is None:
|
||||
@@ -2329,8 +2150,9 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
# in chunk prefill
|
||||
if self.cache_config.enable_chunked_prefill:
|
||||
if hasattr(task, "prefill_end_index") and hasattr(task, "prompt_token_ids"):
|
||||
if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in prefill_done_idxs:
|
||||
prefill_done_idxs.remove(task.idx)
|
||||
task_idx = self.share_inputs.get_index_by_batch_id(task.idx)
|
||||
if len(task.prompt_token_ids) > task.prefill_end_index and task_idx in prefill_done_idxs:
|
||||
prefill_done_idxs.remove(task_idx)
|
||||
|
||||
return prefill_done_idxs
|
||||
|
||||
@@ -2343,12 +2165,13 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
self.restore_chunked_prefill_request[task.request_id] = task
|
||||
|
||||
for id, task in list(self.restore_chunked_prefill_request.items()):
|
||||
task_idx = self.share_inputs.get_index_by_batch_id(task.idx)
|
||||
# unfinished, remove
|
||||
if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs:
|
||||
prefill_done_idxs.remove(task.idx)
|
||||
if task.chunk_idx < len(task.prefill_chunk_info) and task_idx in prefill_done_idxs:
|
||||
prefill_done_idxs.remove(task_idx)
|
||||
# finished, add
|
||||
if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs:
|
||||
prefill_done_idxs.append(task.idx)
|
||||
if task.chunk_idx == len(task.prefill_chunk_info) and task_idx not in prefill_done_idxs:
|
||||
prefill_done_idxs.append(task_idx)
|
||||
|
||||
return prefill_done_idxs
|
||||
|
||||
@@ -2411,6 +2234,9 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
# 1. Prepare inputs of model and sampler.
|
||||
p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
|
||||
|
||||
# Reorder inputs to split prefill and decode tokens
|
||||
self._process_reorder()
|
||||
|
||||
self._prepare_inputs()
|
||||
self.sampler.pre_process(p_done_idxs)
|
||||
|
||||
@@ -2454,7 +2280,6 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
) -> None:
|
||||
|
||||
prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
|
||||
|
||||
if self.is_pooling_model:
|
||||
pooler_output = self._pool(model_output, num_running_requests)
|
||||
|
||||
@@ -2486,6 +2311,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
stop_seqs_len=self.share_inputs["stop_seqs_len"],
|
||||
min_tokens=self.share_inputs["min_dec_len"],
|
||||
prompt_lens=self.share_inputs["prompt_lens"],
|
||||
index_to_batch_id=self.share_inputs["index_to_batch_id"],
|
||||
enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False),
|
||||
)
|
||||
|
||||
post_process(
|
||||
@@ -2611,6 +2438,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
prompt_lens=self.share_inputs["prompt_lens"],
|
||||
mask_rollback=self.share_inputs["mask_rollback"],
|
||||
prompt_logprobs_list=prompt_logprobs_list,
|
||||
index_to_batch_id=self.share_inputs["index_to_batch_id"],
|
||||
enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False),
|
||||
)
|
||||
|
||||
if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
|
||||
@@ -2632,6 +2461,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
line_break_id=self.model_config.line_break_id,
|
||||
enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
|
||||
)
|
||||
|
||||
if self.guided_backend is not None and sampler_output is not None:
|
||||
self.sampler.post_process(sampler_output.sampled_token_ids)
|
||||
|
||||
@@ -3197,6 +3027,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
start_idx = request.prefill_start_index
|
||||
start_tok = start_idx + 1
|
||||
num_remaining_tokens = num_prompt_tokens - start_tok
|
||||
batch_id = self.share_inputs.get_index_by_batch_id(request.idx)
|
||||
if num_tokens <= num_remaining_tokens:
|
||||
# This is a chunk, more tokens remain.
|
||||
# In the == case, there are no more prompt logprobs to produce
|
||||
@@ -3207,13 +3038,13 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
# This is the last chunk of prompt tokens to return.
|
||||
num_logits = num_remaining_tokens
|
||||
completed_prefill_reqs.append(request)
|
||||
prompt_logprobs_list[request.idx] = logprobs_tensors
|
||||
prompt_logprobs_list[batch_id] = logprobs_tensors
|
||||
if num_logits <= 0:
|
||||
# This can happen for the final chunk if we prefilled exactly
|
||||
# (num_prompt_tokens - 1) tokens for this request in the prior
|
||||
# step. There are no more prompt logprobs to produce.
|
||||
continue
|
||||
offset = self.share_inputs["cu_seqlens_q"][request.idx]
|
||||
offset = self.share_inputs["cu_seqlens_q"][batch_id]
|
||||
prompt_hidden_states = hidden_states[offset : offset + num_logits]
|
||||
logits = self.model.compute_logits(prompt_hidden_states)
|
||||
prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits]
|
||||
|
||||
Reference in New Issue
Block a user