mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Feature] Support mtp overlap schedule (#7001)
This commit is contained in:
@@ -123,7 +123,7 @@ def gather_logprobs(
|
||||
indices = token_ids
|
||||
top_logprobs = token_logprobs
|
||||
|
||||
return LogprobsTensors(indices, top_logprobs, token_ranks)
|
||||
return LogprobsTensors(indices.cpu(), top_logprobs.cpu(), token_ranks.cpu())
|
||||
|
||||
|
||||
def build_output_logprobs(
|
||||
|
||||
@@ -1041,7 +1041,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
)
|
||||
sampler_output.logprobs_tensors = logprobs_tensors
|
||||
if cu_batch_token_offset is not None:
|
||||
sampler_output.cu_batch_token_offset = cu_batch_token_offset
|
||||
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
|
||||
return sampler_output
|
||||
|
||||
def forward_xpu(
|
||||
|
||||
Reference in New Issue
Block a user