[Feature] Support mtp overlap schedule (#7001)

This commit is contained in:
sunxin
2026-04-01 14:24:26 +08:00
committed by GitHub
parent c6f0c5c3a6
commit c29e86fc9d
23 changed files with 215 additions and 138 deletions
@@ -123,7 +123,7 @@ def gather_logprobs(
indices = token_ids
top_logprobs = token_logprobs
return LogprobsTensors(indices, top_logprobs, token_ranks)
return LogprobsTensors(indices.cpu(), top_logprobs.cpu(), token_ranks.cpu())
def build_output_logprobs(
@@ -1041,7 +1041,7 @@ class SpeculativeSampler(nn.Layer):
)
sampler_output.logprobs_tensors = logprobs_tensors
if cu_batch_token_offset is not None:
sampler_output.cu_batch_token_offset = cu_batch_token_offset
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
return sampler_output
def forward_xpu(