[Graph Optimization] Add max_capture_shape_prefill && cudagraph_capture_sizes_prefill (#6148)

* Add max_capture_shape_dy2st parameter to YAML config

* split cudagraph capture size between decode and prefill

* rm if

* add default value
This commit is contained in:
Ryan
2026-01-22 21:37:18 +08:00
committed by GitHub
parent 8d27a523e7
commit 31c219d483
+40 -7
View File
@@ -905,6 +905,7 @@ class GraphOptimizationConfig:
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
self.cudagraph_capture_sizes: Optional[list[int]] = None
self.cudagraph_capture_sizes_prefill: list[int] = [1, 2, 4, 8]
""" Number of warmup runs for cudagraph. """
self.cudagraph_num_of_warmups: int = 2
"""Whether to copy input tensors for cudagraph.
@@ -942,7 +943,7 @@ class GraphOptimizationConfig:
""" Maximum CUDA Graph capture size for static graph mode.
Recommend 512 for small models (e.g., ERNIE45T 0.3B) and 128 for massive models (e.g., 300B).
"""
self.max_capture_shape_dy2st: int = 512
self.max_capture_shape_prefill: int = 512
# CINN Config ...
if args is not None:
@@ -952,13 +953,16 @@ class GraphOptimizationConfig:
self.check_legality_parameters()
def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
def init_with_cudagrpah_size(self, max_capture_size: int = 0, max_capture_shape_prefill: int = 0) -> None:
"""
Initialize cuda graph capture sizes and
pre-compute the mapping from batch size to padded graph size
"""
# Regular capture sizes
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
self.cudagraph_capture_sizes_prefill = [
size for size in self.cudagraph_capture_sizes_prefill if size <= max_capture_shape_prefill
]
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(
@@ -970,7 +974,11 @@ class GraphOptimizationConfig:
# Sort to make sure cudagraph capture sizes are in descending order
self.cudagraph_capture_sizes.sort(reverse=True)
self.cudagraph_capture_sizes_prefill.sort(reverse=True)
self.max_capture_size = self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
self.max_capture_size_prefill = (
self.cudagraph_capture_sizes_prefill[0] if self.cudagraph_capture_sizes_prefill else 0
)
# Pre-compute the mapping from shape to padded graph size
self.real_shape_to_captured_size = {}
@@ -982,7 +990,21 @@ class GraphOptimizationConfig:
self.real_shape_to_captured_size[bs] = end
self.real_shape_to_captured_size[self.max_capture_size] = self.max_capture_size
def _set_cudagraph_sizes(self, max_capture_size: int = 0, dec_token_per_query_per_step: int = 1):
self.real_shape_to_captured_size_prefill = {}
for end, start in zip(self.cudagraph_capture_sizes_prefill, self.cudagraph_capture_sizes_prefill[1:] + [0]):
for bs in range(start, end):
if bs == start:
self.real_shape_to_captured_size_prefill[bs] = start
else:
self.real_shape_to_captured_size_prefill[bs] = end
self.real_shape_to_captured_size_prefill[self.max_capture_size_prefill] = self.max_capture_size_prefill
def _set_cudagraph_sizes(
self,
max_capture_size: int = 0,
max_capture_shape_prefill: int = 0,
dec_token_per_query_per_step: int = 1,
):
"""
Calculate a series of candidate capture sizes,
and then extract a portion of them as the capture list for the CUDA graph based on user input.
@@ -996,14 +1018,21 @@ class GraphOptimizationConfig:
# Shape [256, 288, ... 992, 1024] * dec_token_per_query_per_step
draft_capture_sizes += [32 * i * dec_token_per_query_per_step for i in range(9, 33)]
draft_capture_sizes_prefill = draft_capture_sizes.copy()
draft_capture_sizes.append(max_capture_size)
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
draft_capture_sizes_prefill.append(max_capture_shape_prefill)
self.cudagraph_capture_sizes_prefill = sorted(draft_capture_sizes_prefill)
def filter_capture_size(self, tp_size: int = 1):
    """Keep only the capture sizes that are divisible by ``tp_size``.

    When TSP is used, capture size must be divisible by tp size. The same
    filter is applied to the decode list (``cudagraph_capture_sizes``) and
    to the prefill list (``cudagraph_capture_sizes_prefill``). Each
    attribute is rebound to a new list; the relative order of the
    surviving sizes is preserved.

    Args:
        tp_size: Divisor that every retained capture size must be a
            multiple of. Must be a positive integer.

    Raises:
        ValueError: If ``tp_size`` is less than 1.
    """
    if tp_size < 1:
        # Fail early with an explicit message instead of a ZeroDivisionError
        # raised from inside the comprehension (tp_size == 0) or a silently
        # accepted negative divisor.
        raise ValueError(f"tp_size must be a positive integer, but got {tp_size}")

    def _divisible_by_tp(sizes):
        # One shared predicate so the decode and prefill lists cannot drift apart.
        return [draft_size for draft_size in sizes if draft_size % tp_size == 0]

    self.cudagraph_capture_sizes = _divisible_by_tp(self.cudagraph_capture_sizes)
    self.cudagraph_capture_sizes_prefill = _divisible_by_tp(self.cudagraph_capture_sizes_prefill)
def to_json_string(self):
"""
@@ -1672,8 +1701,7 @@ class FDConfig:
else:
max_capture_shape = min(512, max_capture_shape)
if self.graph_opt_config.graph_opt_level > 0:
max_capture_shape = graph_opt_config.max_capture_shape_dy2st
max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill
if self.graph_opt_config.cudagraph_capture_sizes is None:
dec_token_per_query_per_step = (
@@ -1682,9 +1710,14 @@ class FDConfig:
else 1
)
self.graph_opt_config._set_cudagraph_sizes(
max_capture_size=max_capture_shape, dec_token_per_query_per_step=dec_token_per_query_per_step
max_capture_size=max_capture_shape,
max_capture_shape_prefill=max_capture_shape_prefill,
dec_token_per_query_per_step=dec_token_per_query_per_step,
)
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=max_capture_shape)
self.graph_opt_config.init_with_cudagrpah_size(
max_capture_size=max_capture_shape,
max_capture_shape_prefill=max_capture_shape_prefill,
)
self.tokenizer = tokenizer
self.ips = ips