[Model Runner] Prepare token count and move FA3 initialization into the graph (#6170)

* prepare for token num and put FA3 init in graph
This commit is contained in:
sunxin
2026-01-26 12:16:57 +08:00
committed by GitHub
parent 0966df78dc
commit adc69c15d0
10 changed files with 64 additions and 42 deletions
+2
View File
@@ -935,6 +935,7 @@ class MTPProposer(Proposer):
if self.model_inputs["not_need_stop"]:
self.model_inputs["substep"] = substep
# Remove padding
token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
(
ids_remove_padding,
batch_id_per_token,
@@ -943,6 +944,7 @@ class MTPProposer(Proposer):
output_cum_offsets,
output_padding_offset,
) = pre_process(
token_num_cpu,
self.model_inputs["input_ids"],
self.model_inputs["seq_lens_this_time"],
True,