[MetaxGPU] adapt to the latest fastdeploy on metax gpu (#3492)

This commit is contained in:
Kane2011
2025-08-25 17:44:20 +08:00
committed by GitHub
parent c13c904971
commit 2ae7ab28d2
8 changed files with 338 additions and 115 deletions
+88 -51
View File
@@ -23,8 +23,11 @@ import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.input.mm_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard,
sot_warmup_guard,
@@ -41,6 +44,7 @@ from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.ops.gpu import (
recover_decode_task,
set_value_by_flags_and_idx,
@@ -52,15 +56,7 @@ from fastdeploy.model_executor.pre_and_post_process import (
rebuild_padding,
step_cuda,
)
from fastdeploy.platforms import current_platform
if not current_platform.is_dcu():
from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy import envs
from fastdeploy.input.mm_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -130,7 +126,7 @@ class MetaxModelRunner(ModelRunnerBase):
shape=[self.parallel_config.max_num_seqs, 1],
fill_value=4,
dtype="int64",
)
).cpu()
self.restore_chunked_prefill_request = dict()
# Initialize attention Backend
@@ -164,6 +160,7 @@ class MetaxModelRunner(ModelRunnerBase):
if self.speculative_method == "ngram":
self.proposer = NgramProposer(self.fd_config)
elif self.speculative_method == "mtp":
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
self.proposer = MTPProposer(
self.fd_config,
self.get_model(),
@@ -193,21 +190,23 @@ class MetaxModelRunner(ModelRunnerBase):
return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key
def insert_tasks_v1(self, req_dicts: List[Request]):
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
"""
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
req_dict: A list of Request dict
num_running_requests: batch_size
"""
# NOTE(luotingdan): Lazy initialize kv cache
# Lazy initialize kv cache
if "caches" not in self.share_inputs:
self.initialize_kv_cache()
req_len = len(req_dicts)
has_prefill_task = False
has_decode_task = False
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
if request.task_type.value == RequestType.PREFILL.value: # prefill task
logger.debug(f"Handle prefill request {request} at idx {idx}")
prefill_start_index = request.prefill_start_index
prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index
@@ -253,6 +252,11 @@ class MetaxModelRunner(ModelRunnerBase):
)
input_ids = request.prompt_token_ids + request.output_token_ids
logger.debug(
f"Handle prefill request {request} at idx {idx}, "
f"{prefill_start_index=}, {prefill_end_index=}, "
f"need_prefilled_token_num={len(input_ids)}"
)
self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
input_ids[prefill_start_index:prefill_end_index]
)
@@ -264,7 +268,7 @@ class MetaxModelRunner(ModelRunnerBase):
)
self.share_inputs["stop_flags"][idx : idx + 1] = False
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
self.seq_lens_this_time_buffer[idx : idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
@@ -281,22 +285,27 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
if self.share_inputs["is_block_step"][idx]: # has tasks to continue to decode
has_decode_task = True
continue
else: # preempted task
logger.debug(f"Handle preempted request {request} at idx {idx}")
self.share_inputs["block_tables"][idx : idx + 1, :] = -1
self.share_inputs["stop_flags"][idx : idx + 1] = True
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0
self.seq_lens_this_time_buffer[idx : idx + 1] = 0
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["is_block_step"][idx : idx + 1] = False
continue
if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
request.eos_token_ids.append(request.eos_token_ids[0])
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
@@ -326,12 +335,15 @@ class MetaxModelRunner(ModelRunnerBase):
else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
if has_prefill_task:
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
def insert_prefill_inputs(self, req_dicts: List[Request]):
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
"""
Process inputs for prefill tasks and insert it to share_inputs buffer
req_dict: A list of Request dict
num_running_requests: batch_size
TODO(gongshaotian): Refactor this func
"""
@@ -365,7 +377,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
self.seq_lens_this_time_buffer[idx : idx + 1] = 1
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -377,7 +389,7 @@ class MetaxModelRunner(ModelRunnerBase):
request.draft_token_ids[0:num_prefill_send_token],
dtype="int64",
)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = num_prefill_send_token
self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
else:
self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["step_idx"][idx : idx + 1] = 0
@@ -412,7 +424,7 @@ class MetaxModelRunner(ModelRunnerBase):
)
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size
@@ -430,7 +442,7 @@ class MetaxModelRunner(ModelRunnerBase):
else:
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
self.seq_lens_this_time_buffer[idx : idx + 1] = length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -453,12 +465,13 @@ class MetaxModelRunner(ModelRunnerBase):
else:
return default_value
if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
request.eos_token_ids.append(request.eos_token_ids[0])
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
@@ -489,13 +502,15 @@ class MetaxModelRunner(ModelRunnerBase):
request.block_tables, dtype="int32"
)
if request.get("bad_words_token_ids") is not None:
if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
bad_words_len = len(request.get("bad_words_token_ids"))
if bad_words_len > 0:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
else:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len"))
@@ -514,8 +529,10 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["not_need_stop"][0] = True
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
if self.speculative_method in ["mtp"]:
self.proposer.insert_prefill_inputs(req_dicts)
self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
"""Set dummy prefill inputs to share_inputs"""
@@ -525,6 +542,12 @@ class MetaxModelRunner(ModelRunnerBase):
num_tokens // batch_size,
self.parallel_config.max_model_len - max_dec_len,
)
# When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan.
# Figure out the accurate buffer size of DeepEP.
if self.fd_config.parallel_config.enable_expert_parallel:
full_length = min(full_length, 32)
input_length = int(full_length * self.cache_config.kv_cache_ratio)
block_num = (
input_length + self.cache_config.block_size - 1
@@ -534,8 +557,10 @@ class MetaxModelRunner(ModelRunnerBase):
idx = i
self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
self.share_inputs["eos_token_id"][:] = np.array(
[2] * self.model_config.eos_tokens_lens, dtype="int64"
).reshape(-1, 1)
self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
@@ -553,6 +578,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
idx * block_num, (idx + 1) * block_num, 1
)
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
def _init_share_inputs(self, max_num_seqs: int):
"""
@@ -568,18 +594,20 @@ class MetaxModelRunner(ModelRunnerBase):
)
self.share_inputs["input_ids"] = paddle.full(
[max_num_seqs, self.parallel_config.max_model_len],
self.parallel_config.pad_token_id,
self.model_config.pad_token_id,
dtype="int64",
)
self.share_inputs["prompt_ids"] = paddle.full(
[max_num_seqs, self.parallel_config.max_model_len],
self.parallel_config.pad_token_id,
self.model_config.pad_token_id,
dtype="int64",
)
self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["top_k_list"] = [0] * max_num_seqs
self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
self.share_inputs["min_p_list"] = [0.0] * max_num_seqs
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype="float32"
)
@@ -603,7 +631,9 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["max_length"] = paddle.full(
[max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
)
self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32")
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
if self.fd_config.parallel_config.enable_expert_parallel:
self.share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -626,7 +656,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
self.share_inputs["need_block_len"] = paddle.full([1], 0, dtype="int32")
self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], 0, dtype="int32")
self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu()
self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -637,10 +667,11 @@ class MetaxModelRunner(ModelRunnerBase):
0,
dtype="int64",
)
self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["batch_id_per_token"] = paddle.full(
[max_num_seqs * self.parallel_config.max_model_len, 1], 0, dtype="int32"
)
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32")
# Declare AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = None
@@ -758,7 +789,6 @@ class MetaxModelRunner(ModelRunnerBase):
# Remove padding
(
ids_remove_padding,
cum_offsets,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
@@ -774,7 +804,6 @@ class MetaxModelRunner(ModelRunnerBase):
)
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
self.share_inputs["cum_offsets"].copy_(cum_offsets, False)
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -795,7 +824,10 @@ class MetaxModelRunner(ModelRunnerBase):
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
top_k_list=self.share_inputs["top_k_list"],
min_p=self.share_inputs["min_p"],
min_p_list=self.share_inputs["min_p_list"],
seed=self.share_inputs["infer_seed"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
prompt_ids=self.share_inputs["prompt_ids"],
@@ -933,7 +965,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["caches"] = list(cache_kvs.values())
for value in cache_kvs.values():
del value
paddle.device.cuda.empty_cache()
# paddle.device.empty_cache()
def initialize_attn_backend(self) -> None:
"""
@@ -1023,7 +1055,7 @@ class MetaxModelRunner(ModelRunnerBase):
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],
self.share_inputs["cu_seqlens_q"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"],
@@ -1247,6 +1279,7 @@ class MetaxModelRunner(ModelRunnerBase):
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
num_running_requests: int = None,
) -> Optional[ModelRunnerOutput]:
"""
The Entrance of model execute.
@@ -1255,6 +1288,7 @@ class MetaxModelRunner(ModelRunnerBase):
class at the server level, which is too granular for ModelRunner.
We plan to replace it with 'ModelForwardBatch'.
intermediate_tensors:
num_running_requests: batch_size
"""
# 1. Prepare inputs of model and sampler.
skip_idx_list = self._get_skip_idx(model_forward_batch)
@@ -1286,7 +1320,7 @@ class MetaxModelRunner(ModelRunnerBase):
)
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],
self.share_inputs["cu_seqlens_q"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"],
@@ -1356,8 +1390,8 @@ class MetaxModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)
@@ -1397,6 +1431,9 @@ class MetaxModelRunner(ModelRunnerBase):
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
self.seq_lens_this_time_buffer[:num_running_requests].copy_(
self.share_inputs["seq_lens_this_time"][:num_running_requests], False
)
return None
def _add_cache(self, model_forward_batch) -> None:
@@ -1528,7 +1565,7 @@ class MetaxModelRunner(ModelRunnerBase):
""" " Dynamic model loader use to clear parameters use for RL"""
self.dynamic_weight_manager.clear_parameters(pid)
self.clear_cache()
paddle.device.cuda.empty_cache()
# paddle.device.empty_cache()
self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
def update_parameters(self, pid):