[Model Runner] Prepare token count and move FA3 initialization into the graph (#6170)

* Prepare the token count and move FA3 initialization into the graph
This commit is contained in:
sunxin
2026-01-26 12:16:57 +08:00
committed by GitHub
parent 0966df78dc
commit adc69c15d0
10 changed files with 64 additions and 42 deletions
+2
View File
@@ -1352,6 +1352,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.max_logprobs = None
# Remove padding
token_num_cpu = self.share_inputs["seq_lens_this_time"].numpy().sum().item()
(
ids_remove_padding,
batch_id_per_token,
@@ -1360,6 +1361,7 @@ class MetaxModelRunner(ModelRunnerBase):
output_cum_offsets,
output_padding_offset,
) = pre_process(
token_num_cpu,
self.share_inputs["input_ids"],
self.share_inputs["seq_lens_this_time"],
self.speculative_decoding,