[XPU] xpu support ep4tp4 (#5773)

* [XPU] xpu support ep4tp4

* Add commands to check multiprocessing and fastdeploy processes (see the sketch after the commit metadata below)

---------

Co-authored-by: ddchenhao66 <dhaochen@163.com>
Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com>
ddchenhao66 authored 2025-12-29 11:27:01 +08:00, committed by GitHub
parent 91a2b13676
commit 56a9ecccb2
3 changed files with 469 additions and 103 deletions
+116 -103
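Editor's note: only this file's diff is reproduced below; the process-check commands mentioned in the commit message live in the other two changed files. A rough sketch of what such a check can look like, using standard `ps -ef` output on Linux. The function name and the match patterns are illustrative, not taken from the commit:

import subprocess

def find_matching_processes(patterns=("multiprocessing", "fastdeploy")):
    """Return `ps -ef` lines whose command line mentions any of the patterns."""
    # `ps -ef` lists every process with its full command line on Linux.
    result = subprocess.run(["ps", "-ef"], capture_output=True, text=True, check=True)
    return [line for line in result.stdout.splitlines() if any(p in line for p in patterns)]

if __name__ == "__main__":
    for line in find_matching_processes():
        print(line)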
@@ -18,6 +18,7 @@ import os
 import queue
 import random
 import time
+from contextlib import contextmanager
 from threading import Thread
 from typing import List, Optional
@@ -67,6 +68,21 @@ from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunn
 logger = get_logger("xpu_model_runner", "xpu_model_runner.log")
 
 
+@contextmanager
+def kv_signal_sender_context_manager(pd_disaggregation_mode):
+    sender = None
+    try:
+        sender = (
+            create_kv_signal_sender()
+            if pd_disaggregation_mode == "per_chunk" or pd_disaggregation_mode == "per_query"
+            else None
+        )
+        yield sender
+    finally:
+        if sender is not None:
+            destroy_kv_signal_sender(sender)
+
+
 class XPUModelRunner(ModelRunnerBase):
     """ """
@@ -1359,115 +1375,112 @@
         """
         # 0. set debug level
         # self._set_debug_level(0x1, model_forward_batch, is_dummy_run)
-        if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
-            self.kv_signal_sender = create_kv_signal_sender()
-        # 1. Prepare inputs of model and decoder.
-        self._prepare_inputs(is_dummy_run=is_dummy_run)
-        # NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
-        # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
-        # when there is data on other runners, the current runner is required to execute part of the model.
-        if not self.not_need_stop() and not is_dummy_run:
-            self._execute_empty_input(self.forward_meta)
-            return None
+        with kv_signal_sender_context_manager(self.pd_disaggregation_mode) as sender:
+            self.kv_signal_sender = sender
+            # 1. Prepare inputs of model and decoder.
+            self._prepare_inputs(is_dummy_run=is_dummy_run)
+            # NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
+            # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
+            # when there is data on other runners, the current runner is required to execute part of the model.
+            if not self.not_need_stop() and not is_dummy_run:
+                self._execute_empty_input(self.forward_meta)
+                return None
 
-        # 2. Padding inputs for cuda graph
+            # 2. Padding inputs for cuda graph
 
-        # 3. Execute model
-        if self.enable_mm:
-            model_output = self.model(
-                self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta
-            )
-        else:
-            model_output = self.model(
-                ids_remove_padding=self.share_inputs["ids_remove_padding"],
-                forward_meta=self.forward_meta,
-            )
+            # 3. Execute model
+            if self.enable_mm:
+                model_output = self.model(
+                    self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta
+                )
+            else:
+                model_output = self.model(
+                    ids_remove_padding=self.share_inputs["ids_remove_padding"],
+                    forward_meta=self.forward_meta,
+                )
 
-        hidden_states = xpu_process_output(
-            model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
-        )
-        # 4. Compute logits, Sample
-        logits = self.model.compute_logits(hidden_states)
-        sampler_output = None
-        if not self.speculative_decoding:
-            sampler_output = self.sampler(logits, self.sampling_metadata)
-        else:
-            self.sampler(
-                logits,
-                self.sampling_metadata,
-                self.model_config.max_model_len,
-                self.share_inputs,
-            )
+            hidden_states = xpu_process_output(
+                model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
+            )
+            # 4. Compute logits, Sample
+            logits = self.model.compute_logits(hidden_states)
+            sampler_output = None
+            if not self.speculative_decoding:
+                sampler_output = self.sampler(logits, self.sampling_metadata)
+            else:
+                self.sampler(
+                    logits,
+                    self.sampling_metadata,
+                    self.model_config.max_model_len,
+                    self.share_inputs,
+                )
 
-        # 5. Speculative decode
-        # 6. Post Process
-        prompt_logprobs_list = None
-        if not self.speculative_decoding:
-            prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
-        model_output_data = ModelOutputData(
-            next_tokens=self.share_inputs["next_tokens"],
-            stop_flags=self.share_inputs["stop_flags"],
-            step_idx=self.share_inputs["step_idx"],
-            max_dec_len=self.share_inputs["max_dec_len"],
-            pre_ids=self.share_inputs["pre_ids"],
-            seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
-            eos_token_id=self.share_inputs["eos_token_id"],
-            not_need_stop=self.share_inputs["not_need_stop"],
-            input_ids=self.share_inputs["input_ids"],
-            stop_nums=self.share_inputs["stop_nums"],
-            seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
-            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
-            is_block_step=self.share_inputs["is_block_step"],
-            # Speculative decoding
-            full_hidden_states=model_output if self.speculative_decoding else None,
-            msg_queue_id=self.parallel_config.msg_queue_id,
-            mp_rank=self.local_rank,
-            use_ep=self.parallel_config.use_ep,
-            draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
-            actual_draft_token_num=(
-                self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
-            ),
-            accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
-            accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            stop_token_ids=self.share_inputs["stop_seqs"],
-            stop_seqs_len=self.share_inputs["stop_seqs_len"],
-            min_tokens=self.share_inputs["min_dec_len"],
-            prompt_logprobs_list=prompt_logprobs_list,
-        )
-        if self.speculative_decoding:
-            # base model post process
-            xpu_post_process_specualate(model_output_data, False, is_dummy_run)
-        else:
-            xpu_post_process_normal(
-                sampler_output=sampler_output,
-                model_output=model_output_data,
-                share_inputs=self.share_inputs,
-                block_size=self.cache_config.block_size,
-                skip_save_output=is_dummy_run,
-                async_output_queue=self.async_output_queue,
-                think_end_id=self.model_config.think_end_id,
-                line_break_id=self.model_config.line_break_id,
-            )
-        # draft model propose
-        if self.speculative_method == "mtp":
-            self.proposer.run(full_hidden_states=model_output)
-        # 7. Update 'infer_seed' and step_xpu()
-        self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
-        self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
-        step_xpu(
-            self.share_inputs,
-            self.cache_config.block_size,
-            self.cache_config.enc_dec_block_num,
-            self.speculative_decoding,
-            self.speculative_config.num_speculative_tokens,
-        )
+            # 5. Speculative decode
+            # 6. Post Process
+            prompt_logprobs_list = None
+            if not self.speculative_decoding:
+                prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
+            model_output_data = ModelOutputData(
+                next_tokens=self.share_inputs["next_tokens"],
+                stop_flags=self.share_inputs["stop_flags"],
+                step_idx=self.share_inputs["step_idx"],
+                max_dec_len=self.share_inputs["max_dec_len"],
+                pre_ids=self.share_inputs["pre_ids"],
+                seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
+                eos_token_id=self.share_inputs["eos_token_id"],
+                not_need_stop=self.share_inputs["not_need_stop"],
+                input_ids=self.share_inputs["input_ids"],
+                stop_nums=self.share_inputs["stop_nums"],
+                seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
+                seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
+                is_block_step=self.share_inputs["is_block_step"],
+                # Speculative decoding
+                full_hidden_states=model_output if self.speculative_decoding else None,
+                msg_queue_id=self.parallel_config.msg_queue_id,
+                mp_rank=self.local_rank,
+                use_ep=self.parallel_config.use_ep,
+                draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
+                actual_draft_token_num=(
+                    self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
+                ),
+                accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
+                accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
+                stop_token_ids=self.share_inputs["stop_seqs"],
+                stop_seqs_len=self.share_inputs["stop_seqs_len"],
+                min_tokens=self.share_inputs["min_dec_len"],
+                prompt_logprobs_list=prompt_logprobs_list,
+            )
+            if self.speculative_decoding:
+                # base model post process
+                xpu_post_process_specualate(model_output_data, False, is_dummy_run)
+            else:
+                xpu_post_process_normal(
+                    sampler_output=sampler_output,
+                    model_output=model_output_data,
+                    share_inputs=self.share_inputs,
+                    block_size=self.cache_config.block_size,
+                    skip_save_output=is_dummy_run,
+                    async_output_queue=self.async_output_queue,
+                    think_end_id=self.model_config.think_end_id,
+                    line_break_id=self.model_config.line_break_id,
+                )
+            # draft model propose
+            if self.speculative_method == "mtp":
+                self.proposer.run(full_hidden_states=model_output)
+            # 7. Update 'infer_seed' and step_xpu()
+            self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
+            self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
+            step_xpu(
+                self.share_inputs,
+                self.cache_config.block_size,
+                self.cache_config.enc_dec_block_num,
+                self.speculative_decoding,
+                self.speculative_config.num_speculative_tokens,
+            )
 
-        if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
-            destroy_kv_signal_sender(self.kv_signal_sender)
         return None
 
     def _execute_empty_input(self, forward_meta) -> None:
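The hunk above is behavior-preserving on the happy path, but it also fixes a leak: previously, the early return None taken by an idle worker (the EP case described in the NOTE) and any exception raised during the forward pass both skipped destroy_kv_signal_sender. Inside a with block, the context manager's finally clause runs on every exit path, including early returns. A minimal illustration with a hypothetical cleanup manager, not code from the commit:

from contextlib import contextmanager

@contextmanager
def cleanup():
    try:
        yield
    finally:
        print("cleaned up")  # runs on normal exit, early return, or exception

def execute(idle):
    with cleanup():
        if idle:
            return None  # early return still triggers the finally clause
        return "result"

execute(idle=True)   # prints: cleaned up
execute(idle=False)  # prints: cleaned up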