mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU][Graph Optimization] XPU Support CUDAGraph (#6152)
* support cuda graph
This commit is contained in:
@@ -900,7 +900,7 @@ class GraphOptimizationConfig:
|
||||
"""
|
||||
self.sot_warmup_sizes: list[int] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64, 128]
|
||||
""" Number of warmup runs for SOT warmup. """
|
||||
self.use_cudagraph: bool = True
|
||||
self.use_cudagraph: bool = False if paddle.is_compiled_with_xpu() else True
|
||||
"""Sizes to capture cudagraph.
|
||||
- None (default): capture sizes are inferred from llm config.
|
||||
- list[int]: capture sizes are specified as given."""
|
||||
@@ -1850,9 +1850,11 @@ class FDConfig:
|
||||
"Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
|
||||
)
|
||||
|
||||
if not current_platform.is_cuda() and not current_platform.is_maca():
|
||||
if not current_platform.is_cuda() and not current_platform.is_maca() and not current_platform.is_xpu():
|
||||
self.graph_opt_config.use_cudagraph = False
|
||||
logger.info("CUDAGraph currently only support on GPU!")
|
||||
logger.info(
|
||||
"Current Platform can not support CUDAGraph, CUDAGraph currently only support on GPU/XPU/Metax GPU !"
|
||||
)
|
||||
|
||||
# adjust speculative config
|
||||
if self.speculative_config is not None and self.speculative_config.method == "mtp":
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
@@ -260,6 +260,39 @@ class XPUForwardMeta(ForwardMeta):
|
||||
# for pd_disaggregation
|
||||
kv_signal_sender: Optional[paddle.Tensor] = None
|
||||
|
||||
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
    """
    Synchronize attributes from another XPUForwardMeta object.

    Iterates over the dataclass fields of ``self`` and, for each one, takes the
    same-named attribute from ``other``.

    Args:
        other: Object to read values from. Fields that ``other`` does not
            expose are left untouched on ``self``.
        skip_keys: Field names that must not be synchronized. ``None`` means
            no field is skipped.

    Returns:
        ``self``, after the synchronization.
    """
    excluded = [] if skip_keys is None else skip_keys
    # Sentinel so that "attribute absent on other" is not confused with
    # "attribute present but None".
    missing = object()

    # Use fields(self) to ensure all fields of the current class are obtained
    for field in fields(self):
        name = field.name
        if name in excluded:
            continue

        src_val = getattr(other, name, missing)
        if src_val is missing:
            continue

        dst_val = getattr(self, name)
        if isinstance(src_val, paddle.Tensor) and isinstance(dst_val, paddle.Tensor):
            # Both sides already hold a Tensor: overwrite the destination in
            # place instead of rebinding the attribute.
            # NOTE(review): the second argument is presumably paddle's
            # `blocking` flag -- confirm against the Tensor.copy_ documentation.
            dst_val.copy_(src_val, False)
        else:
            # Either the source is a non-Tensor value (str, int, bool, None, ...)
            # or the destination is not a Tensor yet, so an in-place copy is
            # not feasible: assign the reference directly.
            setattr(self, name, src_val)
    return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class DCUForwardMeta(ForwardMeta):
|
||||
|
||||
@@ -198,7 +198,7 @@ class CudaGraphPiecewiseBackend:
|
||||
if output is not None:
|
||||
output_buffer = paddle.zeros_like(output)
|
||||
output._share_buffer_to(output_buffer)
|
||||
output._clear
|
||||
output._clear()
|
||||
entry.output_buffers.append(output_buffer)
|
||||
else:
|
||||
entry.output_buffers.append(None)
|
||||
|
||||
@@ -105,6 +105,7 @@ def xpu_pre_process(
|
||||
seq_lens_decoder: Optional[paddle.Tensor] = None,
|
||||
is_profiling: bool = False,
|
||||
forward_meta=None,
|
||||
use_cudagraph=False,
|
||||
) -> XPUForwardMeta:
|
||||
""" """
|
||||
max_len = input_ids.shape[1]
|
||||
@@ -152,7 +153,6 @@ def xpu_pre_process(
|
||||
cu_seqlens_k,
|
||||
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
|
||||
|
||||
share_inputs["ids_remove_padding"] = None # set this after adjust batch
|
||||
share_inputs["cum_offsets"] = cum_offsets
|
||||
share_inputs["batch_id_per_token"] = batch_id_per_token
|
||||
share_inputs["cu_seqlens_q"] = cu_seqlens_q
|
||||
@@ -221,11 +221,18 @@ def xpu_pre_process(
|
||||
|
||||
adjusted_input = adjusted_input.squeeze(1)
|
||||
|
||||
share_inputs["ids_remove_padding"] = adjusted_input
|
||||
share_inputs["ids_remove_padding"].copy_(adjusted_input, False)
|
||||
xpu_forward_meta.ids_remove_padding = adjusted_input
|
||||
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
|
||||
xpu_forward_meta.is_profiling = is_profiling
|
||||
return xpu_forward_meta
|
||||
if use_cudagraph:
|
||||
if forward_meta is None:
|
||||
return xpu_forward_meta
|
||||
else:
|
||||
forward_meta.copy_from(xpu_forward_meta)
|
||||
return forward_meta
|
||||
else:
|
||||
return xpu_forward_meta
|
||||
|
||||
|
||||
def xpu_process_output(
|
||||
|
||||
@@ -152,10 +152,12 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
# Lazy initialize kv cache after model loading
|
||||
# self.kv_caches: list[paddle.Tensor] = []
|
||||
|
||||
# Cuda Graph
|
||||
self.graph_opt_level = self.graph_opt_config.graph_opt_level
|
||||
self.use_cudagraph = False
|
||||
# CUDA Graph
|
||||
self.use_cudagraph = self.graph_opt_config.use_cudagraph
|
||||
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
|
||||
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
|
||||
self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill
|
||||
|
||||
self.input_ids = paddle.zeros(self.scheduler_config.max_num_seqs, dtype="int32")
|
||||
|
||||
# Initialize share inputs
|
||||
@@ -290,6 +292,22 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
else:
|
||||
return 0
|
||||
|
||||
def only_prefill(self):
    """
    Check whether the current step is prefill only.

    The local answer is ``not self.exist_decode()``. When expert parallelism
    is enabled (``parallel_config.use_ep``) and the scheduler's
    ``splitwise_role`` is ``"mixed"``, the local flag is also exchanged through
    ``paddle.distributed.all_gather_object`` and every gathered flag must be
    true as well.

    Returns:
        bool: True only if no decode request exists locally and, in the
        EP + mixed case, all gathered flags are true.
    """
    needs_sync = (
        self.fd_config.parallel_config.use_ep
        and self.fd_config.scheduler_config.splitwise_role == "mixed"
    )
    # Query once and reuse the result for both the gather and the local check.
    decode_exists = self.exist_decode()

    if needs_sync:
        only_prefill_batch_list = []
        # NOTE(review): presumably fills the list with one flag per rank --
        # confirm against the paddle.distributed documentation.
        paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
        if not all(only_prefill_batch_list):
            return False

    return not decode_exists
|
||||
|
||||
def only_decode(self):
|
||||
"""
|
||||
Update Batch type for if_only_decode.
|
||||
@@ -1094,7 +1112,30 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
|
||||
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
|
||||
is_profiling=is_dummy_run,
|
||||
forward_meta=self.forward_meta,
|
||||
)
|
||||
|
||||
if self.use_cudagraph:
|
||||
# Update Batch type for cuda graph for only_decode_batch
|
||||
if_only_decode = self.only_decode()
|
||||
|
||||
only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
|
||||
# Update config about moe for better performance
|
||||
# TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
|
||||
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
|
||||
self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
|
||||
if self.speculative_decoding:
|
||||
self.proposer.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
|
||||
|
||||
# Update Batch type for cuda graph for only_prefill_batch
|
||||
only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()
|
||||
|
||||
self.forward_meta.step_use_cudagraph = (
|
||||
only_prefill_use_cudagraph
|
||||
if self.cudagraph_only_prefill
|
||||
else only_decode_use_cudagraph and self.forward_meta.ids_remove_padding.shape[0] > 0
|
||||
)
|
||||
|
||||
# Update bad tokens len
|
||||
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
|
||||
|
||||
@@ -1429,8 +1470,32 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
"""
|
||||
Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
|
||||
"""
|
||||
logger.warn("XPU not support cuda graph currently")
|
||||
pass
|
||||
time_before_capture = time.perf_counter()
|
||||
expected_decode_len = 1
|
||||
capture_sizes = self.cudagraph_capture_sizes.copy()
|
||||
|
||||
try:
|
||||
for batch_size in sorted(capture_sizes, reverse=True):
|
||||
self._dummy_run(
|
||||
num_tokens=self.scheduler_config.max_num_batched_tokens,
|
||||
batch_size=batch_size,
|
||||
expected_decode_len=expected_decode_len,
|
||||
in_capturing=True,
|
||||
)
|
||||
logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
|
||||
except RuntimeError as e:
|
||||
if "out of memory" in str(e):
|
||||
raise RuntimeError(
|
||||
"CUDA out of memory occurred when warming up CUDAGraph "
|
||||
f"with the capture sizes {capture_sizes}. Please try "
|
||||
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
||||
"initializing the engine."
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
time_after_capture = time.perf_counter()
|
||||
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
|
||||
|
||||
@sot_warmup_guard(True)
|
||||
def sot_warmup(self) -> None:
|
||||
@@ -1467,6 +1532,11 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
# 1. Prepare inputs of model and decoder.
|
||||
self._prepare_inputs(is_dummy_run=is_dummy_run)
|
||||
|
||||
if is_dummy_run:
|
||||
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
|
||||
# 2. Padding inputs for cuda graph
|
||||
self.padding_cudagraph_inputs()
|
||||
|
||||
# NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
|
||||
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
||||
# when there is data on other runner, the current runner is required to execute part of the model.
|
||||
@@ -1486,7 +1556,8 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
ids_remove_padding=self.share_inputs["ids_remove_padding"],
|
||||
forward_meta=self.forward_meta,
|
||||
)
|
||||
|
||||
if self.use_cudagraph:
|
||||
model_output = model_output[: self.real_token_num]
|
||||
hidden_states = xpu_process_output(
|
||||
model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
|
||||
)
|
||||
@@ -1689,6 +1760,19 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
"""Stop decoding if the tensor meets the termination condition"""
|
||||
return self.share_inputs["not_need_stop"][0]
|
||||
|
||||
def padding_cudagraph_inputs(self) -> None:
    """
    Prepare the forward metadata for running with CUDA graph.

    When ``self.use_cudagraph`` is set, this points
    ``forward_meta.seq_lens_this_time`` at the shared
    ``share_inputs["seq_lens_this_time"]`` entry and records the first
    dimension of ``forward_meta.ids_remove_padding`` in
    ``self.real_token_num``. When CUDA graph is disabled it does nothing.

    In FastDeploy, almost all input tensors have a buffer, so no explicit
    cleaning is done here.
    """
    # In init_attention_metadata, the decode buffer has already been cleared
    if not self.use_cudagraph:
        return

    # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
    self.forward_meta.seq_lens_this_time = self.share_inputs["seq_lens_this_time"]
    # Record the token count of ids_remove_padding at this point.
    self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear cached data from shared inputs and forward metadata"""
|
||||
self.share_inputs.pop("caches", None)
|
||||
|
||||
@@ -184,8 +184,9 @@ class XpuWorker(WorkerBase):
|
||||
"""
|
||||
Perform the warm-up and the graph optimization
|
||||
"""
|
||||
if self.model_runner.graph_opt_level >= 1:
|
||||
self.model_runner.sot_warmup()
|
||||
# Trigger cuda graph capture
|
||||
if self.model_runner.use_cudagraph:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
def check_health(self) -> bool:
|
||||
""" """
|
||||
|
||||
Reference in New Issue
Block a user