[Metax][CI] e2e ci tests enable cuda graph (#6401)

MingkunZhang
2026-02-09 16:25:23 +08:00
committed by GitHub
parent fd56d85346
commit 268276e287
4 changed files with 11 additions and 12 deletions
@@ -911,6 +911,7 @@ def rebuild_padding(
     seq_lens_decoder,
     seq_lens_encoder,
     batch_id_per_token_output,
+    cu_seqlens_q_output,
     first_token_out,
     enable_logprob,
 )
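
Note: cu_seqlens_q_output presumably follows the usual cumulative-sequence-length convention of cu_seqlens_q (batch_size + 1 entries, starting at 0). A minimal sketch of building such a tensor; build_cu_seqlens is a hypothetical helper, not FastDeploy code:

import paddle

def build_cu_seqlens(seq_lens):
    # cu_seqlens[i] = total tokens in sequences 0..i-1; the leading zero
    # makes sequence i's token range [cu_seqlens[i], cu_seqlens[i + 1]).
    zero = paddle.zeros([1], dtype="int32")
    return paddle.concat([zero, paddle.cumsum(seq_lens.astype("int32"))])

# e.g. seq_lens = [3, 1, 5]  ->  cu_seqlens = [0, 3, 4, 9]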
+8 -10
@@ -1264,8 +1264,8 @@ class MetaxModelRunner(ModelRunnerBase):
             batch_id_per_token,
             cu_seqlens_q,
             cu_seqlens_k,
-            output_cum_offsets,
-            output_padding_offset,
+            cu_seqlens_q_output,
+            batch_id_per_token_output,
         ) = pre_process(
             token_num_cpu,
             self.share_inputs["input_ids"],
@@ -1284,8 +1284,8 @@ class MetaxModelRunner(ModelRunnerBase):
         # For speculative decoding
         if self.speculative_decoding:
-            self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
-            self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
+            self.share_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
+            self.share_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)
 
         # Initialize forward meta data
         self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run)
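
The renamed outputs are staged into persistent share_inputs buffers with a non-blocking copy_ (second positional argument False), rather than rebinding the dict entries; keeping tensors at fixed addresses is what lets a captured CUDA graph be replayed against them. A generic sketch of that pattern (buffer name and size here are illustrative):

import paddle

MAX_BATCH = 64  # illustrative capacity

# Allocate once; later steps copy into the buffer instead of replacing it,
# so kernels captured in a CUDA graph keep reading the same address.
share_inputs = {"cu_seqlens_q_output": paddle.zeros([MAX_BATCH + 1], dtype="int32")}

new_vals = paddle.arange(MAX_BATCH + 1, dtype="int32")
share_inputs["cu_seqlens_q_output"].copy_(new_vals, False)  # False = non-blocking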
@@ -1896,10 +1896,8 @@ class MetaxModelRunner(ModelRunnerBase):
                 self.share_inputs["seq_lens_this_time"],
                 self.share_inputs["seq_lens_decoder"],
                 self.share_inputs["seq_lens_encoder"],
-                (
-                    self.share_inputs["output_padding_offset"] if self.speculative_decoding else None
-                ),  # speculative decoding requires
-                self.model_config.max_model_len,
+                (self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
+                (self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
             )
             self._dummy_sampler_run(hidden_states, model_output, accept_all_drafts, reject_all_drafts)
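
This dummy run is what warms up and exercises the static buffers before CUDA graph capture. For reference, a generic capture-then-replay sketch using Paddle's CUDAGraph API; this is not FastDeploy's actual capture code, and it assumes a CUDA-capable device:

import paddle
from paddle.device.cuda.graphs import CUDAGraph

x = paddle.randn([8, 8])  # static input buffer, reused on every replay
linear = paddle.nn.Linear(8, 8)

linear(x)  # warm-up ("dummy run") before capture

graph = CUDAGraph()
graph.capture_begin()
y = linear(x)  # captured kernels read/write fixed buffers
graph.capture_end()

x.copy_(paddle.randn([8, 8]), False)  # refresh the input in place...
graph.replay()                        # ...and replay; y now holds new results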
@@ -2337,8 +2335,8 @@ class MetaxModelRunner(ModelRunnerBase):
             self.share_inputs["seq_lens_this_time"],
             self.share_inputs["seq_lens_decoder"],
             self.share_inputs["seq_lens_encoder"],
-            (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
-            self.model_config.max_model_len,
+            (self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
+            (self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
         )
 
         # 4. Compute logits, Sample
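
Both rebuild_padding call sites now pass batch_id_per_token_output and cu_seqlens_q_output (None outside speculative decoding) in place of output_padding_offset and max_model_len. Assuming rebuild_padding's job is to pull each sequence's output hidden state out of the flattened token stream, a rough pure-Paddle sketch of that gather:

import paddle

hidden = paddle.randn([9, 16])                # 9 tokens across 3 sequences
cu_seqlens = paddle.to_tensor([0, 3, 4, 9], dtype="int64")

# Assumed semantics: keep the last token of every sequence, producing one
# row per sequence regardless of how the batch was padded.
last_idx = cu_seqlens[1:] - 1                 # [2, 3, 8]
out = paddle.gather(hidden, last_idx)         # shape [3, 16]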
+1 -1
@@ -31,7 +31,7 @@ class TestErnie21B(unittest.TestCase):
             load_choices="default_v1",
             # enable_prefix_caching=False,
             disable_custom_all_reduce=True,
-            graph_optimization_config={"use_cudagraph": False, "graph_opt_level": 0},
+            # graph_optimization_config={"use_cudagraph": False, "graph_opt_level": 0},
         )
         cls.sampling_params = fastdeploy.SamplingParams(top_p=0.95, max_tokens=256, temperature=0.6)
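
Commenting out graph_optimization_config removes the explicit use_cudagraph=False override, so the test falls back to the framework default, which, per this PR's title, runs the e2e CI tests with CUDA graph enabled on Metax. If a test ever needs to pin the behavior instead of relying on the default, the same dict can be passed with the flag flipped; a sketch, where the fastdeploy.LLM constructor and model path are assumptions (the class name sits outside this hunk):

import fastdeploy

llm = fastdeploy.LLM(
    model="./ERNIE-4.5-21B",  # placeholder path; the real one is outside this hunk
    load_choices="default_v1",
    disable_custom_all_reduce=True,
    # Opt in explicitly rather than relying on the default:
    graph_optimization_config={"use_cudagraph": True, "graph_opt_level": 0},
)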
+1 -1
@@ -37,7 +37,7 @@ class TestErnie28BVL(unittest.TestCase):
             quantization="wint8",
             disable_custom_all_reduce=True,
             # enable_prefix_caching=False,
-            graph_optimization_config={"use_cudagraph": False, "graph_opt_level": 0},
+            # graph_optimization_config={"use_cudagraph": False, "graph_opt_level": 0},
             limit_mm_per_prompt={"image": 100},
             reasoning_parser="ernie-45-vl",
             load_choices="default_v1",