mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix]fix handle 4 return values from noaux_tc_redundant op (#6384)
* fix: handle 4 return values from noaux_tc_redundant op The noaux_tc_redundant CUDA op is defined with 4 outputs in PD_BUILD_STATIC_OP: - output_tensor (scores) - topk_values - topk_indices - tokens_per_expert_stats_list_out (inplace updated) The Python code was only unpacking 3 values, causing: ValueError: too many values to unpack (expected 3) This fix correctly unpacks all 4 return values, ignoring the inplace updated tensor which is the same as the input tokens_per_expert_stats_list. Co-Authored-By: Claude (Claude Opus 4.5) <noreply@anthropic.com> * fix: make noaux_tc_redundant return 4 values to match OP definition The PD_BUILD_STATIC_OP defines 4 outputs but the function only returned 3, causing inconsistent behavior across different Paddle framework versions. This fix explicitly returns 4 values: - scores (inplace modified) - topk_values - topk_indices - tokens_per_expert_stats_list (inplace modified via atomicAdd) Co-Authored-By: Claude (Claude Opus 4.5) <noreply@anthropic.com> --------- Co-authored-by: Claude (Claude Opus 4.5) <noreply@anthropic.com>
This commit is contained in:
@@ -61,7 +61,8 @@ std::vector<paddle::Tensor> NoauxTcRedundant(
|
||||
redundant_ep_rank_num_plus_one,
|
||||
stream);
|
||||
|
||||
return {scores, topk_values, topk_indices};
|
||||
// Return 4 values to match PD_BUILD_STATIC_OP Outputs definition
|
||||
return {scores, topk_values, topk_indices, tokens_per_expert_stats_list};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> NoauxTcRedundantInferDtype(
|
||||
|
||||
Reference in New Issue
Block a user