Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
+18 -16
View File
@@ -36,22 +36,24 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
auto place = scores_with_bias.place();
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
auto topk_indices =
paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
auto stream = scores_with_bias.stream();
invokeNoAuxTc<float, int64_t>(reinterpret_cast<float*>(scores.data<float>()),
reinterpret_cast<float*>(group_scores.data<float>()),
reinterpret_cast<float*>(topk_values.data<float>()),
reinterpret_cast<int64_t*>(topk_indices.data<int64_t>()),
reinterpret_cast<float*>(scores_with_bias.data<float>()),
num_tokens,
num_experts,
n_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
stream);
invokeNoAuxTc<float, int64_t>(
reinterpret_cast<float*>(scores.data<float>()),
reinterpret_cast<float*>(group_scores.data<float>()),
reinterpret_cast<float*>(topk_values.data<float>()),
reinterpret_cast<int64_t*>(topk_indices.data<int64_t>()),
reinterpret_cast<float*>(scores_with_bias.data<float>()),
num_tokens,
num_experts,
n_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
stream);
return {scores, topk_values, topk_indices};
}
@@ -64,9 +66,9 @@ std::vector<paddle::DataType> NoauxTcInferDtype(
std::vector<std::vector<int64_t>> NoauxTcInferShape(
const std::vector<int64_t>& scores_shape,
const std::vector<int64_t>& ,
const std::vector<int64_t>&,
const int topk) {
auto num_tokens = scores_shape[0];
auto num_tokens = scores_shape[0];
auto topk_values_shape = std::vector<int64_t>{num_tokens, topk};
auto topk_indices_shape = std::vector<int64_t>{num_tokens, topk};
return {scores_shape, topk_values_shape, topk_indices_shape};