#pragma once #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #endif #include "paddle/phi/api/include/api.h" #include "paddle/phi/core/enforce.h" #include "moe/moe_wna16_marlin_utils/kernel.h" #include "moe/moe_wna16_marlin_utils/types.h" std::vector MoeWna16MarlinGemmApi( const paddle::Tensor& a, const paddle::optional& c_or_none, const paddle::Tensor& b_q_weight, const paddle::Tensor& b_scales, const paddle::optional& global_scale_or_none, const paddle::optional& b_zeros_or_none, const paddle::optional& g_idx_or_none, const paddle::optional& perm_or_none, const paddle::Tensor& workspace, const paddle::Tensor& sorted_token_ids, const paddle::Tensor& expert_ids, const paddle::Tensor& num_tokens_post_padded, const paddle::Tensor& topk_weights, int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, const std::string& b_q_type_str, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float);