[XPU][Graph Optimization] XPU Support CUDAGraph (#6152)

* Support CUDA Graph
This commit is contained in:
yinwei
2026-01-22 14:41:56 +08:00
committed by GitHub
parent 82057cb71f
commit 1e3c35496c
6 changed files with 143 additions and 16 deletions
+5 -3
View File
@@ -900,7 +900,7 @@ class GraphOptimizationConfig:
"""
self.sot_warmup_sizes: list[int] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64, 128]
""" Number of warmup runs for SOT warmup. """
self.use_cudagraph: bool = True
self.use_cudagraph: bool = False if paddle.is_compiled_with_xpu() else True
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
@@ -1850,9 +1850,11 @@ class FDConfig:
"Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
)
if not current_platform.is_cuda() and not current_platform.is_maca():
if not current_platform.is_cuda() and not current_platform.is_maca() and not current_platform.is_xpu():
self.graph_opt_config.use_cudagraph = False
logger.info("CUDAGraph currently only support on GPU!")
logger.info(
"Current Platform can not support CUDAGraph, CUDAGraph currently only support on GPU/XPU/Metax GPU !"
)
# adjust speculative config
if self.speculative_config is not None and self.speculative_config.method == "mtp":
+34 -1
View File
@@ -15,7 +15,7 @@
"""
import logging
from dataclasses import dataclass
from dataclasses import dataclass, fields
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, Optional
@@ -260,6 +260,39 @@ class XPUForwardMeta(ForwardMeta):
# for pd_disaggregation
kv_signal_sender: Optional[paddle.Tensor] = None
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
    """
    Copy attribute values from *other* into this XPUForwardMeta.

    Tensor-valued fields are updated in place via ``copy_`` when a tensor
    already exists at the destination; everything else — non-tensor values,
    or a tensor arriving where the destination is still ``None`` — is
    rebound by plain assignment.

    Args:
        other: metadata object to copy values from.
        skip_keys: field names to leave untouched (defaults to none).

    Returns:
        self, to allow call chaining.
    """
    excluded = set(skip_keys or [])
    # fields(self) enumerates every dataclass field declared on this class.
    for meta_field in fields(self):
        key = meta_field.name
        if key in excluded or not hasattr(other, key):
            continue
        incoming = getattr(other, key)
        current = getattr(self, key)
        if isinstance(incoming, paddle.Tensor) and isinstance(current, paddle.Tensor):
            # In-place copy keeps the existing destination buffer;
            # presumably required so graph-captured addresses stay
            # valid across steps — TODO confirm.
            current.copy_(incoming, False)
        else:
            # Non-tensor attribute, or no destination tensor to copy into:
            # just rebind the reference.
            setattr(self, key, incoming)
    return self
@dataclass
class DCUForwardMeta(ForwardMeta):
@@ -198,7 +198,7 @@ class CudaGraphPiecewiseBackend:
if output is not None:
output_buffer = paddle.zeros_like(output)
output._share_buffer_to(output_buffer)
output._clear
output._clear()
entry.output_buffers.append(output_buffer)
else:
entry.output_buffers.append(None)
@@ -105,6 +105,7 @@ def xpu_pre_process(
seq_lens_decoder: Optional[paddle.Tensor] = None,
is_profiling: bool = False,
forward_meta=None,
use_cudagraph=False,
) -> XPUForwardMeta:
""" """
max_len = input_ids.shape[1]
@@ -152,7 +153,6 @@ def xpu_pre_process(
cu_seqlens_k,
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
share_inputs["ids_remove_padding"] = None # set this after adjust batch
share_inputs["cum_offsets"] = cum_offsets
share_inputs["batch_id_per_token"] = batch_id_per_token
share_inputs["cu_seqlens_q"] = cu_seqlens_q
@@ -221,11 +221,18 @@ def xpu_pre_process(
adjusted_input = adjusted_input.squeeze(1)
share_inputs["ids_remove_padding"] = adjusted_input
share_inputs["ids_remove_padding"].copy_(adjusted_input, False)
xpu_forward_meta.ids_remove_padding = adjusted_input
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
xpu_forward_meta.is_profiling = is_profiling
return xpu_forward_meta
if use_cudagraph:
if forward_meta is None:
return xpu_forward_meta
else:
forward_meta.copy_from(xpu_forward_meta)
return forward_meta
else:
return xpu_forward_meta
def xpu_process_output(
+90 -6
View File
@@ -152,10 +152,12 @@ class XPUModelRunner(ModelRunnerBase):
# Lazy initialize kv cache after model loading
# self.kv_caches: list[paddle.Tensor] = []
# Cuda Graph
self.graph_opt_level = self.graph_opt_config.graph_opt_level
self.use_cudagraph = False
# CUDA Graph
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill
self.input_ids = paddle.zeros(self.scheduler_config.max_num_seqs, dtype="int32")
# Initialize share inputs
@@ -290,6 +292,22 @@ class XPUModelRunner(ModelRunnerBase):
else:
return 0
def only_prefill(self):
    """
    Check whether the current batch contains prefill requests only.

    Under expert parallelism with a "mixed" splitwise role, the decision is
    agreed on across ranks via ``all_gather_object`` so that every rank
    takes the same branch; otherwise the decision is purely local.

    Returns:
        bool: True when no decode request exists locally (and, in EP mixed
        mode, on every other rank as well).
    """
    ep_mixed_mode = (
        self.fd_config.parallel_config.use_ep
        and self.fd_config.scheduler_config.splitwise_role == "mixed"
    )
    if ep_mixed_mode:
        local_decode = self.exist_decode()
        gathered = []
        # Each rank contributes whether it is prefill-only; all must agree.
        paddle.distributed.all_gather_object(gathered, not local_decode)
        return all(gathered) and not local_decode
    # Non-EP / non-mixed path: local decision only.
    return not self.exist_decode()
def only_decode(self):
"""
Update Batch type for if_only_decode.
@@ -1094,7 +1112,30 @@ class XPUModelRunner(ModelRunnerBase):
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_profiling=is_dummy_run,
forward_meta=self.forward_meta,
)
if self.use_cudagraph:
# Update Batch type for cuda graph for only_decode_batch
if_only_decode = self.only_decode()
only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
# Update config about moe for better performance
# TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
if self.speculative_decoding:
self.proposer.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
# Update Batch type for cuda graph for only_prefill_batch
only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()
self.forward_meta.step_use_cudagraph = (
only_prefill_use_cudagraph
if self.cudagraph_only_prefill
else only_decode_use_cudagraph and self.forward_meta.ids_remove_padding.shape[0] > 0
)
# Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
@@ -1429,8 +1470,32 @@ class XPUModelRunner(ModelRunnerBase):
"""
Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
"""
logger.warn("XPU not support cuda graph currently")
pass
time_before_capture = time.perf_counter()
expected_decode_len = 1
capture_sizes = self.cudagraph_capture_sizes.copy()
try:
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=batch_size,
expected_decode_len=expected_decode_len,
in_capturing=True,
)
logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
except RuntimeError as e:
if "out of memory" in str(e):
raise RuntimeError(
"CUDA out of memory occurred when warming up CUDAGraph "
f"with the capture sizes {capture_sizes}. Please try "
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
"initializing the engine."
) from e
else:
raise e
time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
@sot_warmup_guard(True)
def sot_warmup(self) -> None:
@@ -1467,6 +1532,11 @@ class XPUModelRunner(ModelRunnerBase):
# 1. Prepare inputs of model and decoder.
self._prepare_inputs(is_dummy_run=is_dummy_run)
if is_dummy_run:
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
# 2. Padding inputs for cuda graph
self.padding_cudagraph_inputs()
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
@@ -1486,7 +1556,8 @@ class XPUModelRunner(ModelRunnerBase):
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)
if self.use_cudagraph:
model_output = model_output[: self.real_token_num]
hidden_states = xpu_process_output(
model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
)
@@ -1689,6 +1760,19 @@ class XPUModelRunner(ModelRunnerBase):
"""Stop decoding if the tensor meets the termination condition"""
return self.share_inputs["not_need_stop"][0]
def padding_cudagraph_inputs(self) -> None:
    """
    Prepare shared-input buffers for a padded-batch CUDA Graph replay.

    When cudagraph is enabled, re-point the forward meta at the shared
    ``seq_lens_this_time`` buffer and record the real (unpadded) token
    count so the model output can be trimmed back afterwards.
    """
    if not self.use_cudagraph:
        return
    # NOTE: the decode buffer has already been cleared in
    # init_attention_metadata; here we only rewire references.
    self.forward_meta.seq_lens_this_time = self.share_inputs["seq_lens_this_time"]
    # Number of valid tokens before padding — used later to slice the
    # model output (model_output[: self.real_token_num]).
    self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
def clear_cache(self):
"""Clear cached data from shared inputs and forward metadata"""
self.share_inputs.pop("caches", None)
+3 -2
View File
@@ -184,8 +184,9 @@ class XpuWorker(WorkerBase):
"""
Perform the warm-up and the graph optimization
"""
if self.model_runner.graph_opt_level >= 1:
self.model_runner.sot_warmup()
# Trigger cuda graph capture
if self.model_runner.use_cudagraph:
self.model_runner.capture_model()
def check_health(self) -> bool:
""" """