Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 00:17:25 +08:00)
[Metax][CI] e2e ci tests enable cuda graph (#6401)
This commit is contained in:
@@ -911,6 +911,7 @@ def rebuild_padding(
|
|||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
batch_id_per_token_output,
|
batch_id_per_token_output,
|
||||||
|
cu_seqlens_q_output,
|
||||||
first_token_out,
|
first_token_out,
|
||||||
enable_logprob,
|
enable_logprob,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1264,8 +1264,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
|||||||
batch_id_per_token,
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
cu_seqlens_k,
|
cu_seqlens_k,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
) = pre_process(
|
) = pre_process(
|
||||||
token_num_cpu,
|
token_num_cpu,
|
||||||
self.share_inputs["input_ids"],
|
self.share_inputs["input_ids"],
|
||||||
@@ -1284,8 +1284,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
# For speculative decoding
|
# For speculative decoding
|
||||||
if self.speculative_decoding:
|
if self.speculative_decoding:
|
||||||
self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
|
self.share_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
|
||||||
self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
|
self.share_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)
|
||||||
|
|
||||||
# Initialize forward meta data
|
# Initialize forward meta data
|
||||||
self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run)
|
self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run)
|
||||||
@@ -1896,10 +1896,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs["seq_lens_this_time"],
|
self.share_inputs["seq_lens_this_time"],
|
||||||
self.share_inputs["seq_lens_decoder"],
|
self.share_inputs["seq_lens_decoder"],
|
||||||
self.share_inputs["seq_lens_encoder"],
|
self.share_inputs["seq_lens_encoder"],
|
||||||
(
|
(self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
|
||||||
self.share_inputs["output_padding_offset"] if self.speculative_decoding else None
|
(self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
|
||||||
), # speculative decoding requires
|
|
||||||
self.model_config.max_model_len,
|
|
||||||
)
|
)
|
||||||
self._dummy_sampler_run(hidden_states, model_output, accept_all_drafts, reject_all_drafts)
|
self._dummy_sampler_run(hidden_states, model_output, accept_all_drafts, reject_all_drafts)
|
||||||
|
|
||||||
@@ -2337,8 +2335,8 @@ class MetaxModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs["seq_lens_this_time"],
|
self.share_inputs["seq_lens_this_time"],
|
||||||
self.share_inputs["seq_lens_decoder"],
|
self.share_inputs["seq_lens_decoder"],
|
||||||
self.share_inputs["seq_lens_encoder"],
|
self.share_inputs["seq_lens_encoder"],
|
||||||
(self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
|
(self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
|
||||||
self.model_config.max_model_len,
|
(self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Compute logits, Sample
|
# 4. Compute logits, Sample
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class TestErnie21B(unittest.TestCase):
|
|||||||
load_choices="default_v1",
|
load_choices="default_v1",
|
||||||
# enable_prefix_caching=False,
|
# enable_prefix_caching=False,
|
||||||
disable_custom_all_reduce=True,
|
disable_custom_all_reduce=True,
|
||||||
graph_optimization_config={"use_cudagraph": False, "graph_opt_level": 0},
|
# graph_optimization_config={"use_cudagraph": False, "graph_opt_level": 0},
|
||||||
)
|
)
|
||||||
|
|
||||||
cls.sampling_params = fastdeploy.SamplingParams(top_p=0.95, max_tokens=256, temperature=0.6)
|
cls.sampling_params = fastdeploy.SamplingParams(top_p=0.95, max_tokens=256, temperature=0.6)
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class TestErnie28BVL(unittest.TestCase):
|
|||||||
quantization="wint8",
|
quantization="wint8",
|
||||||
disable_custom_all_reduce=True,
|
disable_custom_all_reduce=True,
|
||||||
# enable_prefix_caching=False,
|
# enable_prefix_caching=False,
|
||||||
graph_optimization_config={"use_cudagraph": False, "graph_opt_level": 0},
|
# graph_optimization_config={"use_cudagraph": False, "graph_opt_level": 0},
|
||||||
limit_mm_per_prompt={"image": 100},
|
limit_mm_per_prompt={"image": 100},
|
||||||
reasoning_parser="ernie-45-vl",
|
reasoning_parser="ernie-45-vl",
|
||||||
load_choices="default_v1",
|
load_choices="default_v1",
|
||||||
|
|||||||
Reference in New Issue
Block a user