From 5c6105f4a27b1b3dd8f452b26682b6fe0029eeb1 Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 27 Oct 2025 10:50:08 +0800 Subject: [PATCH] [XPU] bind some OPs for VL model with pybind (#4522) --- custom_ops/gpu_ops/cpp_extensions.cc | 1533 ++++++++++------- custom_ops/gpu_ops/update_inputs.cu | 158 +- custom_ops/gpu_ops/update_inputs_beam.cu | 58 +- custom_ops/gpu_ops/update_inputs_v1.cu | 262 +-- custom_ops/xpu_ops/src/ops/pybind/pybind.cc | 772 +++++++-- custom_ops/xpu_ops/src/ops/update_inputs.cc | 34 +- .../xpu_ops/src/ops/update_inputs_v1.cc | 56 +- .../layers/backends/xpu/moe/fused_moe.py | 3 +- 8 files changed, 1789 insertions(+), 1087 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 45c1882057..a040fdecec 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -18,14 +18,14 @@ namespace py = pybind11; // 自定义异常类,用于处理CUDA错误 class CudaError : public std::exception { -public: + public: explicit CudaError(cudaError_t error) : error_(error) {} - const char *what() const noexcept override { + const char* what() const noexcept override { return cudaGetErrorString(error_); } -private: + private: cudaError_t error_; }; @@ -39,145 +39,180 @@ void check_cuda_error(cudaError_t error) { // 封装cudaHostAlloc的Python函数 uintptr_t cuda_host_alloc(size_t size, unsigned int flags = cudaHostAllocDefault) { - void *ptr = nullptr; + void* ptr = nullptr; check_cuda_error(cudaHostAlloc(&ptr, size, flags)); return reinterpret_cast(ptr); } // 封装cudaFreeHost的Python函数 void cuda_host_free(uintptr_t ptr) { - check_cuda_error(cudaFreeHost(reinterpret_cast(ptr))); + check_cuda_error(cudaFreeHost(reinterpret_cast(ptr))); } std::vector AppendAttention( - const paddle::Tensor &qkv, const paddle::Tensor &key_cache, - const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids, - const paddle::Tensor &encoder_tile_ids_per_batch, - const paddle::Tensor &encoder_num_blocks, - const paddle::Tensor &kv_batch_ids, - const paddle::Tensor &kv_tile_ids_per_batch, - const paddle::Tensor &kv_num_blocks, - const paddle::Tensor &decoder_batch_ids, - const paddle::Tensor &decoder_tile_ids_per_batch, - const paddle::Tensor &decoder_num_blocks_cpu, - const paddle::Tensor &set_max_lengths, - const paddle::optional &rotary_embs, - const paddle::optional &attn_mask, - const paddle::optional &qkv_bias, - const paddle::optional &qkv_out_scales, - const paddle::optional &cache_k_quant_scales, - const paddle::optional &cache_v_quant_scales, - const paddle::optional &cache_k_dequant_scales, - const paddle::optional &cache_v_dequant_scales, - const paddle::optional &cache_k_zp, - const paddle::optional &cache_v_zp, - const paddle::optional &out_linear_shifts, - const paddle::optional &out_linear_smooths, - const paddle::optional &mask_offset, - const paddle::optional &kv_signal_data, + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks_cpu, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const paddle::optional& mask_offset, + const paddle::optional& kv_signal_data, const paddle::optional& q_norm_weight, const paddle::optional& k_norm_weight, const paddle::optional& sinks, const float rms_norm_eps, - const std::string &compute_dtype, const std::string &cache_quant_type_str, - const bool use_neox_rotary_style, const bool rope_3d, - const int max_input_length, const float quant_max_bound, - const float quant_min_bound, const float out_linear_in_scale, - const int encoder_block_shape_q, const int decoder_block_shape_q, - const int max_partition_size, const int encoder_max_partition_size, - const int speculate_max_draft_token_num, const bool causal, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, const bool speculate_decoder, const int sliding_window); std::vector AppendAttentionWithOutput( - const paddle::Tensor &qkv, const paddle::Tensor &key_cache, - const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids, - const paddle::Tensor &encoder_tile_ids_per_batch, - const paddle::Tensor &encoder_num_blocks, - const paddle::Tensor &kv_batch_ids, - const paddle::Tensor &kv_tile_ids_per_batch, - const paddle::Tensor &kv_num_blocks, - const paddle::Tensor &decoder_batch_ids, - const paddle::Tensor &decoder_tile_ids_per_batch, - const paddle::Tensor &decoder_num_blocks_cpu, - const paddle::Tensor &set_max_lengths, - paddle::Tensor &fmha_out, - const paddle::optional &rotary_embs, - const paddle::optional &attn_mask, - const paddle::optional &qkv_bias, - const paddle::optional &qkv_out_scales, - const paddle::optional &cache_k_quant_scales, - const paddle::optional &cache_v_quant_scales, - const paddle::optional &cache_k_dequant_scales, - const paddle::optional &cache_v_dequant_scales, - const paddle::optional &cache_k_zp, - const paddle::optional &cache_v_zp, - const paddle::optional &out_linear_shifts, - const paddle::optional &out_linear_smooths, - const paddle::optional &mask_offset, - const paddle::optional &kv_signal_data, + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks_cpu, + const paddle::Tensor& set_max_lengths, + paddle::Tensor& fmha_out, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const paddle::optional& mask_offset, + const paddle::optional& kv_signal_data, const paddle::optional& q_norm_weight, const paddle::optional& k_norm_weight, const paddle::optional& sinks, const float rms_norm_eps, - const std::string &compute_dtype, const std::string &cache_quant_type_str, - const bool use_neox_rotary_style, const bool rope_3d, - const int max_input_length, const float quant_max_bound, - const float quant_min_bound, const float out_linear_in_scale, - const int encoder_block_shape_q, const int decoder_block_shape_q, - const int max_partition_size, const int encoder_max_partition_size, - const int speculate_max_draft_token_num, const bool causal, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, const bool speculate_decoder, const int sliding_window); std::vector GQARopeWriteCacheKernel( - const paddle::Tensor &qkv, const paddle::Tensor &key_cache, - const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &cu_seqlens_k, const paddle::Tensor &rotary_embs, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &block_tables, const paddle::Tensor &kv_batch_ids, - const paddle::Tensor &kv_tile_ids, const paddle::Tensor &kv_num_blocks, - const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids, - const paddle::Tensor &cache_num_blocks, - const paddle::optional &cache_k_quant_scales, - const paddle::optional &cache_v_quant_scales, - const paddle::optional &cache_k_dequant_scales, - const paddle::optional &cache_v_dequant_scales, - const paddle::optional &cache_k_zp, - const paddle::optional &cache_v_zp, - const paddle::optional &kv_signal_data, - const int kv_token_num, const int max_seq_len, - const std::string &cache_quant_type, + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& cu_seqlens_k, + const paddle::Tensor& rotary_embs, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& block_tables, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& cache_batch_ids, + const paddle::Tensor& cache_tile_ids, + const paddle::Tensor& cache_num_blocks, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const int kv_token_num, + const int max_seq_len, + const std::string& cache_quant_type, const bool rope_3d); -std::vector -PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, - const int max_dec_len, const int block_size); +std::vector PreCacheLenConcat( + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const int max_dec_len, + const int block_size); paddle::Tensor FusedExpertMoeFunc( - const paddle::Tensor &input, const paddle::Tensor &gate_weight, - const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight, - const paddle::optional &up_gate_proj_bias, - const paddle::optional &up_gate_proj_scale, - const paddle::optional &down_proj_bias, - const paddle::optional &down_proj_scale, - const std::string &quant_method, const int moe_topk, - const bool norm_topk_prob, const bool group_moe); + const paddle::Tensor& input, + const paddle::Tensor& gate_weight, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_bias, + const paddle::optional& down_proj_scale, + const std::string& quant_method, + const int moe_topk, + const bool norm_topk_prob, + const bool group_moe); std::vector MacheteMMKernel( - paddle::Tensor const& A, paddle::Tensor const& B, + paddle::Tensor const& A, + paddle::Tensor const& B, paddle::optional const& maybe_group_scales, paddle::optional const& maybe_group_zeros, paddle::optional const& maybe_channel_scales, @@ -188,65 +223,80 @@ std::vector MacheteMMKernel( std::string const& maybe_schedule); std::vector MachetePrepackBKernel( - paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str, + paddle::Tensor const& B, + std::string const& a_type_str, + std::string const& b_type_str, std::string const& maybe_group_scales_type_str); std::vector MacheteSupportedSchedules( std::string const& a_type_str, std::string const& b_type_str); std::vector MoeExpertDispatch( - const paddle::Tensor &input, const paddle::Tensor &gating_output, - const paddle::optional &gating_correction_bias, - const paddle::optional &w4a8_in_scale, const int moe_topk, - const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode); + const paddle::Tensor& input, + const paddle::Tensor& gating_output, + const paddle::optional& gating_correction_bias, + const paddle::optional& w4a8_in_scale, + const int moe_topk, + const bool group_moe, + const std::string& moe_quant_type, + const bool topk_only_mode); -std::vector -MoETopKSelectKernel(const paddle::Tensor &gating_logits, - const paddle::optional &bias, - const int moe_topk, const bool apply_norm_weight, - const bool enable_softmax_top_k_fused); +std::vector MoETopKSelectKernel( + const paddle::Tensor& gating_logits, + const paddle::optional& bias, + const int moe_topk, + const bool apply_norm_weight, + const bool enable_softmax_top_k_fused); -std::vector -MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits, - const paddle::Tensor &expert_id_to_ep_rank_array, - const paddle::Tensor &expert_in_rank_num_list, - paddle::Tensor &tokens_per_expert_stats_list, - const paddle::optional &bias, - const int moe_topk, const bool apply_norm_weight, - const bool enable_softmax_top_k_fused, - const int redundant_ep_rank_num_plus_one); +std::vector MoERedundantTopKSelectKernel( + const paddle::Tensor& gating_logits, + const paddle::Tensor& expert_id_to_ep_rank_array, + const paddle::Tensor& expert_in_rank_num_list, + paddle::Tensor& tokens_per_expert_stats_list, + const paddle::optional& bias, + const int moe_topk, + const bool apply_norm_weight, + const bool enable_softmax_top_k_fused, + const int redundant_ep_rank_num_plus_one); -std::vector -EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids, - const paddle::Tensor &topk_weights, - const paddle::optional &up_gate_proj_in_scale, - const std::vector &token_nums_per_expert, - const int token_nums_this_rank, - const std::string &moe_quant_type); +std::vector EPMoeExpertDispatch( + const paddle::Tensor& input, + const paddle::Tensor& topk_ids, + const paddle::Tensor& topk_weights, + const paddle::optional& up_gate_proj_in_scale, + const std::vector& token_nums_per_expert, + const int token_nums_this_rank, + const std::string& moe_quant_type); std::vector EPMoeExpertDispatchFP8( - const paddle::Tensor &input, const paddle::Tensor &scale, - const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights, - const paddle::Tensor &token_nums_per_expert, - const paddle::Tensor &token_nums_per_expert_padded, - const bool use_in_ep, const int token_nums_this_rank_padded); + const paddle::Tensor& input, + const paddle::Tensor& scale, + const paddle::Tensor& topk_ids, + const paddle::Tensor& topk_weights, + const paddle::Tensor& token_nums_per_expert, + const paddle::Tensor& token_nums_per_expert_padded, + const bool use_in_ep, + const int token_nums_this_rank_padded); -std::vector PerTokenQuant(paddle::Tensor &input, +std::vector PerTokenQuant(paddle::Tensor& input, const int block_size); -std::vector PerTokenQuantPadding(paddle::Tensor &input, +std::vector PerTokenQuantPadding(paddle::Tensor& input, const int block_size); -std::vector -MaskedPerTokenQuant(paddle::Tensor &input, paddle::Tensor &recv_expert_count, - const int block_size); +std::vector MaskedPerTokenQuant( + paddle::Tensor& input, + paddle::Tensor& recv_expert_count, + const int block_size); std::vector EPMoeExpertCombine( - const paddle::Tensor &ffn_out, const paddle::Tensor &expert_scales_float, - const paddle::Tensor &permute_indices_per_token, - const paddle::Tensor &top_k_indices, - const paddle::optional &down_proj_bias, - const bool norm_topk_prob, const float routed_scaling_factor); + const paddle::Tensor& ffn_out, + const paddle::Tensor& expert_scales_float, + const paddle::Tensor& permute_indices_per_token, + const paddle::Tensor& top_k_indices, + const paddle::optional& down_proj_bias, + const bool norm_topk_prob, + const float routed_scaling_factor); -std::vector> GetExpertTokenNum(const paddle::Tensor &topk_ids, +std::vector> GetExpertTokenNum(const paddle::Tensor& topk_ids, const int num_experts); paddle::Tensor MoeExpertFFNFunc( @@ -282,158 +332,168 @@ paddle::Tensor MoeExpertFFNWint2Func( const bool used_in_ep_low_latency); paddle::Tensor MoeExpertReduceFunc( - const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight, - const paddle::Tensor &permute_indices_per_token, - const paddle::Tensor &top_k_indices, - const paddle::optional &down_proj_bias, - const bool norm_topk_prob, const float routed_scaling_factor); + const paddle::Tensor& ffn_out, + const paddle::Tensor& top_k_weight, + const paddle::Tensor& permute_indices_per_token, + const paddle::Tensor& top_k_indices, + const paddle::optional& down_proj_bias, + const bool norm_topk_prob, + const float routed_scaling_factor); -void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor, - const paddle::Tensor &seq_lens_this_time_tensor, - const paddle::Tensor &seq_lens_decoder_tensor, - const int rank, const int num_layers); +void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor, + const paddle::Tensor& seq_lens_this_time_tensor, + const paddle::Tensor& seq_lens_decoder_tensor, + const int rank, + const int num_layers); -void GetOutputKVSignal(const paddle::Tensor &x, int64_t rank_id, +void GetOutputKVSignal(const paddle::Tensor& x, + int64_t rank_id, bool wait_flag); -paddle::Tensor DequantInt8Func(const paddle::Tensor &input, - const paddle::Tensor &out_scale, +paddle::Tensor DequantInt8Func(const paddle::Tensor& input, + const paddle::Tensor& out_scale, std::string dtype); -paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id, +paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, + const int device_id, const bool keep_pd_step_flag); -paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata, +paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor& kv_signal_metadata, const int layer_id); void GetBlockShapeAndSplitKVBlock( - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, - paddle::Tensor &decoder_batch_ids, // Inplace - paddle::Tensor &decoder_tile_ids_per_batch, // Inplace - paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory - paddle::Tensor &decoder_num_blocks_device, // Inplace - paddle::Tensor &decoder_chunk_size_device, // Inplace - paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory - paddle::Tensor &encoder_batch_ids, // Inplace - paddle::Tensor &encoder_tile_ids_per_batch, // Inplace - paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory - paddle::Tensor &kv_batch_ids, // Inplace - paddle::Tensor &kv_tile_ids_per_batch, // Inplace - paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& decoder_batch_ids, // Inplace + paddle::Tensor& decoder_tile_ids_per_batch, // Inplace + paddle::Tensor& decoder_num_blocks_cpu, // Inplace, Pinned Memory + paddle::Tensor& decoder_num_blocks_device, // Inplace + paddle::Tensor& decoder_chunk_size_device, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, Pinned Memory + paddle::Tensor& encoder_batch_ids, // Inplace + paddle::Tensor& encoder_tile_ids_per_batch, // Inplace + paddle::Tensor& encoder_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor& kv_batch_ids, // Inplace + paddle::Tensor& kv_tile_ids_per_batch, // Inplace + paddle::Tensor& kv_num_blocks_x_cpu, // Inplace, Pinned Memory const int encoder_block_shape_q, const int decoder_block_shape_q, const int group_size, const int block_size, const int decoder_step_token_num); -std::vector GetPaddingOffset(const paddle::Tensor &input_ids, - const paddle::Tensor &cum_offsets, - const paddle::Tensor &token_num, - const paddle::Tensor &seq_len); +std::vector GetPaddingOffset(const paddle::Tensor& input_ids, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& token_num, + const paddle::Tensor& seq_len); -void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, - const paddle::Tensor &input_ids, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_idx, - const paddle::Tensor &stop_flags); +void SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all, + const paddle::Tensor& input_ids, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_flags); paddle::Tensor RebuildPaddingFunc( - const paddle::Tensor &tmp_out, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] - const paddle::Tensor &seq_len_this_time, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_encoder, - const paddle::optional &output_padding_offset, - const paddle::optional &first_token_out, + const paddle::Tensor& tmp_out, // [token_num, dim_embed] + const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& seq_len_this_time, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::optional& output_padding_offset, + const paddle::optional& first_token_out, int max_input_length, bool enable_logprob); -void GetStopFlagsMulti(const paddle::Tensor &topk_ids, - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens, - const paddle::Tensor &end_ids, - const paddle::Tensor &next_tokens, - const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, - const paddle::Tensor &stop_seqs, - const paddle::Tensor &stop_seqs_len, +void GetStopFlagsMulti(const paddle::Tensor& topk_ids, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens, + const paddle::Tensor& end_ids, + const paddle::Tensor& next_tokens, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_seqs, + const paddle::Tensor& stop_seqs_len, const bool beam_search); +void UpdateInputs(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // only on cpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step); -void UpdateInputes(const paddle::Tensor &stop_flags, - const paddle::Tensor ¬_need_stop, // only on cpu - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &input_ids, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const paddle::Tensor &is_block_step); +void UpdateInputsV1(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // only on cpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& topk_ids, + const paddle::Tensor& input_ids, + const paddle::Tensor& block_tables, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step, + const int block_size); -void UpdateInputesV1(const paddle::Tensor &stop_flags, - const paddle::Tensor ¬_need_stop, // only on cpu - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_seq_lens_decoder, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &topk_ids, - const paddle::Tensor &input_ids, - const paddle::Tensor &block_tables, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const paddle::Tensor &is_block_step, - const int block_size); +void RecoverDecodeTask( + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& block_tables, + const paddle::Tensor& is_block_step, + const paddle::optional& draft_tokens, + const paddle::optional& step_draft_tokens, + const paddle::optional& step_seq_lens_this_time, + const int block_size, + const int max_draft_tokens); -void RecoverDecodeTask(const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_seq_lens_decoder, - const paddle::Tensor &block_tables, - const paddle::Tensor &is_block_step, - const paddle::optional &draft_tokens, - const paddle::optional &step_draft_tokens, - const paddle::optional &step_seq_lens_this_time, - const int block_size, - const int max_draft_tokens); - -paddle::Tensor -GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor, - const paddle::Tensor &token_nums_per_expert); +paddle::Tensor GroupSwigluWithMasked( + const paddle::Tensor& fc1_out_tensor, + const paddle::Tensor& token_nums_per_expert); std::vector ExtractTextTokenOutput( - const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index, - const paddle::Tensor &mm_token_num_len, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &hidden_states); + const paddle::Tensor& max_seq_len, + const paddle::Tensor& max_seq_len_index, + const paddle::Tensor& mm_token_num_len, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& hidden_states); -std::vector MoEDeepGEMMPermute(const paddle::Tensor &x, - const paddle::Tensor &topk_idx, +std::vector MoEDeepGEMMPermute(const paddle::Tensor& x, + const paddle::Tensor& topk_idx, const int num_experts, const int max_tokens_per_expert); std::vector MoEDeepGEMMDePermute( - const paddle::Tensor - &ffn_out, // [num_experts, max_tokens_per_expert, hidden] - const paddle::Tensor &permute_indices_per_token, // [token_num, topk}] - const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights); + const paddle::Tensor& + ffn_out, // [num_experts, max_tokens_per_expert, hidden] + const paddle::Tensor& permute_indices_per_token, // [token_num, topk}] + const paddle::Tensor& topk_idx, + const paddle::Tensor& topk_weights); -void TextImageIndexOut(const paddle::Tensor &token_type_ids, - paddle::Tensor &text_input, - paddle::Tensor &image_input); +void TextImageIndexOut(const paddle::Tensor& token_type_ids, + paddle::Tensor& text_input, + paddle::Tensor& image_input); -void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input, - paddle::Tensor &image_input, - paddle::Tensor &token_type_ids, - paddle::Tensor &text_index, - paddle::Tensor &image_index, const bool is_scatter); +void TextImageGatherScatter(paddle::Tensor& input, + paddle::Tensor& text_input, + paddle::Tensor& image_input, + paddle::Tensor& token_type_ids, + paddle::Tensor& text_index, + paddle::Tensor& image_index, + const bool is_scatter); -paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids, +paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor& topk_ids, int64_t num_experts); void GetPositionIdsAndMaskEncoderBatch( const paddle::Tensor& seq_lens_encoder, @@ -455,7 +515,7 @@ std::vector DecodeMLAWriteCacheKernel( const int max_seq_len, const bool speculate_decoder); - std::vector PrefillMLAWriteCacheKernel( +std::vector PrefillMLAWriteCacheKernel( const paddle::Tensor& kv_nope, const paddle::Tensor& kv_pe, const paddle::Tensor& kv_cache, @@ -467,7 +527,6 @@ std::vector DecodeMLAWriteCacheKernel( const std::string& cache_quant_type_str, const int max_seq_len); - void FusedRotaryPositionEncoding( paddle::Tensor& query, // [num_tokens, num_heads, head_size] or // [num_tokens, num_heads * head_size] @@ -520,9 +579,10 @@ std::vector MultiHeadLatentAttention( const bool causal, const bool speculate_decoder); - -std::vector tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M); - +std::vector tritonmoe_preprocess_kernel( + const paddle::Tensor& topk_ids, + int64_t num_experts, + int64_t GEMM_BLOCK_SIZE_M); std::vector MoeWna16MarlinGemmApi( const paddle::Tensor& a, @@ -550,37 +610,42 @@ std::vector MoeWna16MarlinGemmApi( bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float); -void CutlassScaledMm(paddle::Tensor &c, paddle::Tensor const &a, - paddle::Tensor const &b, paddle::Tensor const &a_scales, - paddle::Tensor const &b_scales, - paddle::optional const &bias); +void CutlassScaledMm(paddle::Tensor& c, + paddle::Tensor const& a, + paddle::Tensor const& b, + paddle::Tensor const& a_scales, + paddle::Tensor const& b_scales, + paddle::optional const& bias); -void CutlassScaledMmAzp(paddle::Tensor& c, paddle::Tensor const& a, - paddle::Tensor const& b, - paddle::Tensor const& a_scales, - paddle::Tensor const& b_scales, - paddle::Tensor const& azp_adj, - paddle::optional const& azp, - paddle::optional const& bias); +void CutlassScaledMmAzp(paddle::Tensor& c, + paddle::Tensor const& a, + paddle::Tensor const& b, + paddle::Tensor const& a_scales, + paddle::Tensor const& b_scales, + paddle::Tensor const& azp_adj, + paddle::optional const& azp, + paddle::optional const& bias); -void StaticScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input, - paddle::Tensor const &scale); +void StaticScaledFp8Quant(paddle::Tensor& out, + paddle::Tensor const& input, + paddle::Tensor const& scale); -void DynamicScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input, - paddle::Tensor &scale); +void DynamicScaledFp8Quant(paddle::Tensor& out, + paddle::Tensor const& input, + paddle::Tensor& scale); -void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, - paddle::Tensor const &input, - paddle::Tensor &scales, float scale_ub); +void DynamicPerTokenScaledFp8Quant(paddle::Tensor& out, + paddle::Tensor const& input, + paddle::Tensor& scales, + float scale_ub); -std::vector NoauxTc( - paddle::Tensor& scores, - paddle::Tensor& scores_with_bias, - int n_group, - int topk_group, - int topk, - bool renormalize, - float routed_scaling_factor); +std::vector NoauxTc(paddle::Tensor& scores, + paddle::Tensor& scores_with_bias, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor); #ifdef ENABLE_FP8 paddle::Tensor cutlass_fp8_fp8_half_gemm_func( @@ -593,24 +658,27 @@ paddle::Tensor cutlass_fp8_fp8_half_gemm_func( std::string output_dtype, std::string activation_type); -paddle::Tensor MoeFusedHadamardQuantFp8Func( - const paddle::Tensor &input, - const paddle::Tensor &scale, - const paddle::Tensor &topk_ids, - const int top_k, - const int intermediate_size, - const bool tiled); +paddle::Tensor MoeFusedHadamardQuantFp8Func(const paddle::Tensor& input, + const paddle::Tensor& scale, + const paddle::Tensor& topk_ids, + const int top_k, + const int intermediate_size, + const bool tiled); -paddle::Tensor FusedHadamardQuantFp8Func( - const paddle::Tensor &input, - const float scale); +paddle::Tensor FusedHadamardQuantFp8Func(const paddle::Tensor& input, + const float scale); #endif int64_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, - paddle::Tensor& rank_data, int64_t rank, bool full_nvlink); + paddle::Tensor& rank_data, + int64_t rank, + bool full_nvlink); -void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa, - int64_t reg_buffer, int64_t reg_buffer_sz_bytes); +void all_reduce(paddle::Tensor& inp, + paddle::Tensor& out, + int64_t _fa, + int64_t reg_buffer, + int64_t reg_buffer_sz_bytes); void dispose(int64_t _fa); @@ -618,7 +686,8 @@ int64_t meta_size(); void register_buffer(int64_t _fa, const std::vector& fake_ipc_ptrs); -std::tuple, std::vector> get_graph_buffer_ipc_meta(int64_t _fa); +std::tuple, std::vector> +get_graph_buffer_ipc_meta(int64_t _fa); void register_graph_buffers(int64_t _fa, const std::vector>& handles, @@ -653,67 +722,74 @@ std::vector SpeculateGetOutputPaddingOffset( const paddle::Tensor& seq_lens_output, const int max_seq_len); +void SpecTokenPenaltyMultiScores(const paddle::Tensor& pre_ids, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_scores, + const paddle::Tensor& presence_scores, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& output_padding_offset, + const paddle::Tensor& output_cum_offsets, + const int max_seq_len); -void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids, - const paddle::Tensor &logits, - const paddle::Tensor &penalty_scores, - const paddle::Tensor &frequency_scores, - const paddle::Tensor &presence_scores, - const paddle::Tensor &temperatures, - const paddle::Tensor &bad_tokens, - const paddle::Tensor &cur_len, - const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &output_padding_offset, - const paddle::Tensor &output_cum_offsets, - const int max_seq_len); +void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens, + const paddle::Tensor& stop_seqs, + const paddle::Tensor& stop_seqs_len, + const paddle::Tensor& end_ids); -void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens, - const paddle::Tensor &accept_num, - const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens, - const paddle::Tensor &stop_seqs, - const paddle::Tensor &stop_seqs_len, - const paddle::Tensor &end_ids); +void SpeculateVerify(const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& verify_tokens, + const paddle::Tensor& verify_scores, + const paddle::Tensor& max_dec_len, + const paddle::Tensor& end_tokens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& output_cum_offsets, + const paddle::Tensor& actual_candidate_len, + const paddle::Tensor& actual_draft_token_nums, + const paddle::Tensor& topp, + int max_seq_len, + int verify_window, + bool enable_topp, + bool benchmark_mode, + bool accept_all_drafts); +void SpeculateUpdate(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& actual_draft_token_nums, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& is_block_step, + const paddle::Tensor& stop_nums); -void SpeculateVerify( - const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num, - const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores, - const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens, - const paddle::Tensor &is_block_step, - const paddle::Tensor &output_cum_offsets, - const paddle::Tensor &actual_candidate_len, - const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp, - int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts); - -void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor ¬_need_stop, - const paddle::Tensor &draft_tokens, - const paddle::Tensor &actual_draft_token_nums, - const paddle::Tensor &accept_tokens, - const paddle::Tensor &accept_num, - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &is_block_step, - const paddle::Tensor &stop_nums); - -void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, - const paddle::Tensor &accept_tokens, - const paddle::Tensor &accept_num, - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_idx); +void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx); void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, @@ -724,55 +800,52 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens, bool save_each_rank, bool skip_prefill); - void SpeculateClearAcceptNums(const paddle::Tensor& accept_num, const paddle::Tensor& seq_lens_decoder); -void SpeculateScheduleCache(const paddle::Tensor &draft_tokens, - const paddle::Tensor &block_tables, - const paddle::Tensor &stop_flags, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_seq_lens_decoder, - const paddle::Tensor &step_draft_tokens, - const paddle::Tensor &step_seq_lens_this_time, - const paddle::Tensor &accept_num, - const paddle::Tensor &accept_tokens, - const paddle::Tensor &is_block_step, - const paddle::Tensor ¬_need_stop, - const paddle::Tensor &stop_nums, +void SpeculateScheduleCache(const paddle::Tensor& draft_tokens, + const paddle::Tensor& block_tables, + const paddle::Tensor& stop_flags, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& step_draft_tokens, + const paddle::Tensor& step_seq_lens_this_time, + const paddle::Tensor& accept_num, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& stop_nums, const int block_size, const int max_draft_tokens); -void NgramMatch(const paddle::Tensor &input_ids, - const paddle::Tensor &input_ids_len, - const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, - const paddle::Tensor &draft_token_num, - const paddle::Tensor &draft_tokens, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &max_dec_len, - const int max_ngram_size, - const int max_draft_tokens); - - -void HybridMtpNgram(const paddle::Tensor &input_ids, - const paddle::Tensor &input_ids_len, - const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, - const paddle::Tensor &draft_token_num, - const paddle::Tensor &draft_tokens, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &max_dec_len, - const int max_ngram_size, - const int min_ngram_size, - const int max_draft_tokens); +void NgramMatch(const paddle::Tensor& input_ids, + const paddle::Tensor& input_ids_len, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& draft_token_num, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& max_dec_len, + const int max_ngram_size, + const int max_draft_tokens); +void HybridMtpNgram(const paddle::Tensor& input_ids, + const paddle::Tensor& input_ids_len, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& draft_token_num, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& max_dec_len, + const int max_ngram_size, + const int min_ngram_size, + const int max_draft_tokens); // MTP void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens, @@ -780,7 +853,6 @@ void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens, const paddle::Tensor& base_model_seq_lens_encoder, const paddle::Tensor& base_model_stop_flags); - void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& input_ids, const paddle::Tensor& stop_flags, @@ -806,7 +878,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const bool splitwise_prefill, const bool kvcache_scheduler_v1); - void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, const paddle::Tensor& draft_tokens, const paddle::Tensor& pre_ids, @@ -823,105 +894,103 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, const int max_seq_len, const int substep); - - std::vector EagleGetHiddenStates( - const paddle::Tensor& input, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& stop_flags, - const paddle::Tensor& accept_nums, - const paddle::Tensor& base_model_seq_lens_this_time, - const paddle::Tensor& base_model_seq_lens_encoder, - const int actual_draft_token_num); + const paddle::Tensor& input, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& stop_flags, + const paddle::Tensor& accept_nums, + const paddle::Tensor& base_model_seq_lens_this_time, + const paddle::Tensor& base_model_seq_lens_encoder, + const int actual_draft_token_num); std::vector EagleGetSelfHiddenStates( - const paddle::Tensor& input, - const paddle::Tensor& last_seq_lens_this_time, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& step_idx); + const paddle::Tensor& input, + const paddle::Tensor& last_seq_lens_this_time, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& step_idx); void MTPStepPaddle( - const paddle::Tensor &base_model_stop_flags, - const paddle::Tensor &stop_flags, - const paddle::Tensor &batch_drop, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] - const paddle::Tensor &encoder_block_lens, - const paddle::Tensor &used_list_len, - const paddle::Tensor &free_list, - const paddle::Tensor &free_list_len, + const paddle::Tensor& base_model_stop_flags, + const paddle::Tensor& stop_flags, + const paddle::Tensor& batch_drop, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor& encoder_block_lens, + const paddle::Tensor& used_list_len, + const paddle::Tensor& free_list, + const paddle::Tensor& free_list_len, const int block_size, const int max_draft_tokens); void SpeculateStepPaddle( - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &ori_seq_lens_encoder, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] - const paddle::Tensor &encoder_block_lens, - const paddle::Tensor &is_block_step, - const paddle::Tensor &step_block_list, - const paddle::Tensor &step_lens, - const paddle::Tensor &recover_block_list, - const paddle::Tensor &recover_lens, - const paddle::Tensor &need_block_list, - const paddle::Tensor &need_block_len, - const paddle::Tensor &used_list_len, - const paddle::Tensor &free_list, - const paddle::Tensor &free_list_len, - const paddle::Tensor &input_ids, - const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, - const paddle::Tensor &next_tokens, - const paddle::Tensor &first_token_ids, - const paddle::Tensor &accept_num, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& ori_seq_lens_encoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor& encoder_block_lens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& step_block_list, + const paddle::Tensor& step_lens, + const paddle::Tensor& recover_block_list, + const paddle::Tensor& recover_lens, + const paddle::Tensor& need_block_list, + const paddle::Tensor& need_block_len, + const paddle::Tensor& used_list_len, + const paddle::Tensor& free_list, + const paddle::Tensor& free_list_len, + const paddle::Tensor& input_ids, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& next_tokens, + const paddle::Tensor& first_token_ids, + const paddle::Tensor& accept_num, const int block_size, const int encoder_decoder_block_num, const int max_draft_tokens); -void MergePrefillDecodeOutput( - const paddle::Tensor &encoder_res, - const paddle::Tensor &decoder_res, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &cu_seq_q, - const int head_num, - const int head_dim, - const int max_token); +void MergePrefillDecodeOutput(const paddle::Tensor& encoder_res, + const paddle::Tensor& decoder_res, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cu_seq_q, + const int head_num, + const int head_dim, + const int max_token); -std::vector TopPSamplingReject(const paddle::Tensor &probs, - const paddle::Tensor &top_p, - const paddle::optional &top_k, - int64_t seed); +std::vector TopPSamplingReject( + const paddle::Tensor& probs, + const paddle::Tensor& top_p, + const paddle::optional& top_k, + int64_t seed); -std::vector TopKRenorm(const paddle::Tensor &probs, - const paddle::Tensor &top_k); +std::vector TopKRenorm(const paddle::Tensor& probs, + const paddle::Tensor& top_k); -std::vector MinPSamplingFromProbs(const paddle::Tensor &probs, - const paddle::Tensor &min_p); +std::vector MinPSamplingFromProbs(const paddle::Tensor& probs, + const paddle::Tensor& min_p); void SaveOutMmsgStatic(const paddle::Tensor& x, const paddle::Tensor& not_need_stop, int64_t rank_id, bool save_each_rank); -void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens, - const paddle::Tensor &max_think_lens, - const paddle::Tensor &step_idx, - const paddle::Tensor &limit_think_status, - const int64_t think_end_id); +void LimitThinkingContentLengthV1(const paddle::Tensor& next_tokens, + const paddle::Tensor& max_think_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& limit_think_status, + const int64_t think_end_id); -void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens, - const paddle::Tensor &max_think_lens, - const paddle::Tensor &step_idx, - const paddle::Tensor &limit_think_status, +void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens, + const paddle::Tensor& max_think_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& limit_think_status, const int64_t think_end_id, const int64_t line_break_id); @@ -944,46 +1013,52 @@ void SpeculateLimitThinkingContentLengthV2( const int64_t think_end_id, const int64_t line_break_id); -void SpeculateGetLogits(const paddle::Tensor &draft_logits, - const paddle::Tensor &next_token_num, - const paddle::Tensor &batch_token_num, - const paddle::Tensor &cu_next_token_offset, - const paddle::Tensor &cu_batch_token_offset, - const paddle::Tensor &logits, - const paddle::Tensor &first_token_logits, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder); +void SpeculateGetLogits(const paddle::Tensor& draft_logits, + const paddle::Tensor& next_token_num, + const paddle::Tensor& batch_token_num, + const paddle::Tensor& cu_next_token_offset, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& logits, + const paddle::Tensor& first_token_logits, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder); -void SpeculateInsertFirstToken(const paddle::Tensor &token_ids, - const paddle::Tensor &accept_tokens, - const paddle::Tensor &next_tokens, - const paddle::Tensor &cu_next_token_offset, - const paddle::Tensor &cu_batch_token_offset, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder); +void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& next_tokens, + const paddle::Tensor& cu_next_token_offset, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder); -void SpeculateGetTargetLogits(const paddle::Tensor &target_logits, - const paddle::Tensor &logits, - const paddle::Tensor &cu_batch_token_offset, - const paddle::Tensor &ori_cu_batch_token_offset, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &accept_num); +void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& ori_cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num); PYBIND11_MODULE(fastdeploy_ops, m) { - - m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"), - py::arg("num_experts"), "get expert token num"); + m.def("get_expert_token_num", + &GetExpertTokenNum, + py::arg("topk_ids"), + py::arg("num_experts"), + "get expert token num"); /** * moe/fused_moe/moe_redundant_topk_select.cu * moe_redundant_topk_select */ - m.def("moe_redundant_topk_select", &MoERedundantTopKSelectKernel, - py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"), + m.def("moe_redundant_topk_select", + &MoERedundantTopKSelectKernel, + py::arg("gating_logits"), + py::arg("expert_id_to_ep_rank_array"), py::arg("expert_in_rank_num_list"), - py::arg("tokens_per_expert_stats_list"), py::arg("bias"), - py::arg("moe_topk"), py::arg("apply_norm_weight"), + py::arg("tokens_per_expert_stats_list"), + py::arg("bias"), + py::arg("moe_topk"), + py::arg("apply_norm_weight"), py::arg("enable_softmax_top_k_fused"), py::arg("redundant_ep_rank_num_plus_one"), "moe export RedundantTopKSelect function"); @@ -992,49 +1067,62 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * open_shm_and_get_meta_signal.cc * InitKVSignalPerQuery */ - m.def("init_kv_signal_per_query", &InitKVSignalPerQuery, + m.def("init_kv_signal_per_query", + &InitKVSignalPerQuery, py::arg("seq_lens_encoder_tensor"), py::arg("seq_lens_this_time_tensor"), - py::arg("seq_lens_decoder_tensor"), py::arg("rank"), - py::arg("num_layers"), "init_kv_signal_per_query function"); + py::arg("seq_lens_decoder_tensor"), + py::arg("rank"), + py::arg("num_layers"), + "init_kv_signal_per_query function"); /** * GetOutputKVSignal */ - m.def("get_output_kv_signal", &GetOutputKVSignal, py::arg("x"), - py::arg("rank_id"), py::arg("wait_flag"), + m.def("get_output_kv_signal", + &GetOutputKVSignal, + py::arg("x"), + py::arg("rank_id"), + py::arg("wait_flag"), "get_output_kv_signal function"); m.def("moe_deepgemm_permute", &MoEDeepGEMMPermute, "MoEDeepGEMMPermute"); - m.def("moe_deepgemm_depermute", &MoEDeepGEMMDePermute, - "MoEDeepGEMMDePermute"); + m.def( + "moe_deepgemm_depermute", &MoEDeepGEMMDePermute, "MoEDeepGEMMDePermute"); /** * alloc_cache_pinned.cc * cuda_host_alloc * cuda_host_free */ - m.def("cuda_host_alloc", &cuda_host_alloc, "Allocate pinned memory", - py::arg("size"), py::arg("flags") = cudaHostAllocDefault); - m.def("cuda_host_free", &cuda_host_free, "Free pinned memory", - py::arg("ptr")); + m.def("cuda_host_alloc", + &cuda_host_alloc, + "Allocate pinned memory", + py::arg("size"), + py::arg("flags") = cudaHostAllocDefault); + m.def( + "cuda_host_free", &cuda_host_free, "Free pinned memory", py::arg("ptr")); py::register_exception(m, "CudaError"); /** * append_attention.cu * append_attention */ m.def("append_attention", &AppendAttention, "append attention function"); - m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function"); + m.def("append_attention_with_output", + &AppendAttentionWithOutput, + "append attention with output function"); /** * gqa_rope_write_cache.cu * gqa_rope_write_cache */ - m.def("gqa_rope_write_cache", &GQARopeWriteCacheKernel, + m.def("gqa_rope_write_cache", + &GQARopeWriteCacheKernel, "gqa rope write cache function"); /** * pre_cache_len_concat.cu * pre_cache_len_concat */ - m.def("pre_cache_len_concat", &PreCacheLenConcat, + m.def("pre_cache_len_concat", + &PreCacheLenConcat, "pre_cache len concat function"); /** * moe/fused_moe/fused_moe.cu @@ -1052,66 +1140,108 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * moe/fused_moe/moe_dispatch.cu * moe_expert_dispatch */ - m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"), - py::arg("gating_output"), py::arg("gating_correction_bias"), - py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"), - py::arg("moe_quant_type"), py::arg("topk_only_mode"), "moe export dispatch function"); + m.def("moe_expert_dispatch", + &MoeExpertDispatch, + py::arg("input"), + py::arg("gating_output"), + py::arg("gating_correction_bias"), + py::arg("w4a8_in_scale"), + py::arg("moe_topk"), + py::arg("group_moe"), + py::arg("moe_quant_type"), + py::arg("topk_only_mode"), + "moe export dispatch function"); /** * moe/fused_moe/ep_moe_prefill_func.cu * ep_moe_dispatch */ - m.def("ep_moe_expert_dispatch", &EPMoeExpertDispatch, py::arg("input"), - py::arg("topk_ids"), py::arg("topk_weights"), py::arg("up_gate_proj_in_scale"), - py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"), - py::arg("moe_quant_type"), "ep moe export dispatch function"); + m.def("ep_moe_expert_dispatch", + &EPMoeExpertDispatch, + py::arg("input"), + py::arg("topk_ids"), + py::arg("topk_weights"), + py::arg("up_gate_proj_in_scale"), + py::arg("token_nums_per_expert"), + py::arg("token_nums_this_rank"), + py::arg("moe_quant_type"), + "ep moe export dispatch function"); m.def("ep_moe_expert_dispatch_fp8", &EPMoeExpertDispatchFP8); - m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"), - py::arg("expert_scales_float"), py::arg("permute_indices_per_token"), - py::arg("top_k_indices"), py::arg("down_proj_bias"), - py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"), + m.def("ep_moe_expert_combine", + &EPMoeExpertCombine, + py::arg("ffn_out"), + py::arg("expert_scales_float"), + py::arg("permute_indices_per_token"), + py::arg("top_k_indices"), + py::arg("down_proj_bias"), + py::arg("norm_topk_prob"), + py::arg("routed_scaling_factor"), "ep moe export combine function"); - m.def("per_token_quant", &PerTokenQuant, py::arg("input"), - py::arg("block_size"), "per token per block quant"); + m.def("per_token_quant", + &PerTokenQuant, + py::arg("input"), + py::arg("block_size"), + "per token per block quant"); - m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"), + m.def("per_token_quant_padding", + &PerTokenQuantPadding, + py::arg("input"), py::arg("block_size"), "per token per block quant and padding transpose scale"); - m.def("masked_per_token_quant", &MaskedPerTokenQuant, py::arg("input"), - py::arg("recv_expert_count"), py::arg("block_size"), + m.def("masked_per_token_quant", + &MaskedPerTokenQuant, + py::arg("input"), + py::arg("recv_expert_count"), + py::arg("block_size"), "per token per block quant"); #ifdef ENABLE_MACHETE /*machete/machete_mm.cu * machete_mm */ - m.def("machete_mm", &MacheteMMKernel, py::arg("A"), py::arg("B"), py::arg("maybe_group_scale"), - py::arg("maybe_group_zeros"), py::arg("maybe_channel_scales"), py::arg("maybe_token_scales"), - py::arg("b_type_str"), py::arg("maybe_out_type_str"), py::arg("maybe_group_size"), + m.def("machete_mm", + &MacheteMMKernel, + py::arg("A"), + py::arg("B"), + py::arg("maybe_group_scale"), + py::arg("maybe_group_zeros"), + py::arg("maybe_channel_scales"), + py::arg("maybe_token_scales"), + py::arg("b_type_str"), + py::arg("maybe_out_type_str"), + py::arg("maybe_group_size"), py::arg("maybe_schedule"), "machete mm function"); /*machete/machete_prepack_B.cu * machete_prepack_B */ - m.def("machete_prepack_B", &MachetePrepackBKernel, "machete prepacked B function"); + m.def("machete_prepack_B", + &MachetePrepackBKernel, + "machete prepacked B function"); /*machete/machete_supported_schedules.cu * machete_supported_schedules */ - m.def("machete_supported_schedules", &MacheteSupportedSchedules, "machete supported schedules function"); + m.def("machete_supported_schedules", + &MacheteSupportedSchedules, + "machete supported schedules function"); #endif /** * moe/fused_moe/moe_topk_select.cu * moe_topk_select */ - m.def("moe_topk_select", &MoETopKSelectKernel, py::arg("gating_logits"), - py::arg("bias"), py::arg("moe_topk"), py::arg("apply_norm_weight"), + m.def("moe_topk_select", + &MoETopKSelectKernel, + py::arg("gating_logits"), + py::arg("bias"), + py::arg("moe_topk"), + py::arg("apply_norm_weight"), py::arg("enable_softmax_top_k_fused"), "moe export TopKSelect function"); @@ -1125,16 +1255,23 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * moe/fused_moe/moe_expert_ffn_wint2.cu * moe_expert_ffn_wint2 */ - m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function"); + m.def("moe_expert_ffn_wint2", + &MoeExpertFFNWint2Func, + "moe export ffn wint2 function"); /** * moe/fused_moe/moe_expert_reduce.cu * moe_expert_reduce */ - m.def("moe_expert_reduce", &MoeExpertReduceFunc, py::arg("ffn_out"), - py::arg("top_k_weight"), py::arg("permute_indices_per_token"), - py::arg("top_k_indices"), py::arg("down_proj_bias"), - py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"), + m.def("moe_expert_reduce", + &MoeExpertReduceFunc, + py::arg("ffn_out"), + py::arg("top_k_weight"), + py::arg("permute_indices_per_token"), + py::arg("top_k_indices"), + py::arg("down_proj_bias"), + py::arg("norm_topk_prob"), + py::arg("routed_scaling_factor"), "moe export reduce function"); /** @@ -1147,14 +1284,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * init_signal_layerwise.cc * init_signal_layerwise */ - m.def("init_signal_layerwise", &InitSignalLayerwiseFunc, + m.def("init_signal_layerwise", + &InitSignalLayerwiseFunc, "init_signal_layerwise function"); /** * open_shm_and_get_meta_signal.cc * open_shm_and_get_meta_signal */ - m.def("open_shm_and_get_meta_signal", &OpenShmAndGetMetaSignalFunc, + m.def("open_shm_and_get_meta_signal", + &OpenShmAndGetMetaSignalFunc, "open_shm_and_get_meta_signal function"); /** @@ -1162,7 +1301,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * get_block_shape_and_split_kv_block */ m.def("get_block_shape_and_split_kv_block", - &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block function"); + &GetBlockShapeAndSplitKVBlock, + "get_block_shape_and_split_kv_block function"); /** * get_padding_offset.cu @@ -1174,7 +1314,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * get_padding_offset.cu * get_padding_offset */ - m.def("set_value_by_flags_and_idx", &SetValueByFlagsAndIdx, + m.def("set_value_by_flags_and_idx", + &SetValueByFlagsAndIdx, "SetValueByFlagsAndIdx"); /** @@ -1187,50 +1328,77 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * stop_generation_multi_ends.cu * set_stop_value_multi_ends */ - m.def("set_stop_value_multi_ends", &GetStopFlagsMulti, + m.def("set_stop_value_multi_ends", + &GetStopFlagsMulti, "update_inputs function"); - /** * update_inputs.cu * update_inputs */ - m.def("update_inputs", &UpdateInputes, "update_inputs function"); + m.def("update_inputs", &UpdateInputs, "update_inputs function"); - /** + /** * update_inputs_v1.cu * update_inputs_v1 */ - m.def("update_inputs_v1", &UpdateInputesV1, "update inputs for scheduler v1 function"); + m.def("update_inputs_v1", + &UpdateInputsV1, + "update inputs for scheduler v1 function"); - /** + /** * recover_decode_task.cu * recover_decode_task */ - m.def("recover_decode_task", &RecoverDecodeTask, "recover decode task for scheduler v1 function"); + m.def("recover_decode_task", + &RecoverDecodeTask, + "recover decode task for scheduler v1 function"); - m.def("group_swiglu_with_masked", &GroupSwigluWithMasked, + m.def("group_swiglu_with_masked", + &GroupSwigluWithMasked, "group_swiglu_with_masked function"); - m.def("text_image_index_out", &TextImageIndexOut, + m.def("text_image_index_out", + &TextImageIndexOut, "text_image_index_out function"); - m.def("text_image_gather_scatter", &TextImageGatherScatter, + m.def("text_image_gather_scatter", + &TextImageGatherScatter, "text_image_gather_scatter function"); m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func); m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel); - m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi, - py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"), - py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"), - py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"), - py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"), - py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"), - py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"), - py::arg("use_fp32_reduce"), py::arg("is_zp_float")); + m.def("MoeWna16MarlinGemmApi", + &MoeWna16MarlinGemmApi, + py::arg("a"), + py::arg("c_or_none"), + py::arg("b_q_weight"), + py::arg("b_scales"), + py::arg("global_scale_or_none"), + py::arg("b_zeros_or_none"), + py::arg("g_idx_or_none"), + py::arg("perm_or_none"), + py::arg("workspace"), + py::arg("sorted_token_ids"), + py::arg("expert_ids"), + py::arg("num_tokens_post_padded"), + py::arg("topk_weights"), + py::arg("moe_block_size"), + py::arg("top_k"), + py::arg("mul_topk_weights"), + py::arg("is_ep"), + py::arg("b_q_type_str"), + py::arg("size_m"), + py::arg("size_n"), + py::arg("size_k"), + py::arg("is_k_full"), + py::arg("use_atomic_add"), + py::arg("use_fp32_reduce"), + py::arg("is_zp_float")); - m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch, + m.def("get_position_ids_and_mask_encoder_batch", + &GetPositionIdsAndMaskEncoderBatch, "get_position_ids_and_mask_encoder_batch function"); /** @@ -1239,7 +1407,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * cutlass_scaled_mm_azp */ m.def("cutlass_scaled_mm", &CutlassScaledMm, "cutlass_scaled_mm function"); - m.def("cutlass_scaled_mm_azp", &CutlassScaledMmAzp, "cutlass_scaled_mm_azp function"); + m.def("cutlass_scaled_mm_azp", + &CutlassScaledMmAzp, + "cutlass_scaled_mm_azp function"); /** * quantization/common.cu @@ -1247,39 +1417,76 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * dynamic_scaled_fp8_quant * dynamic_per_token_scaled_fp8_quant */ - m.def("static_scaled_fp8_quant", &StaticScaledFp8Quant, "static_scaled_fp8_quant function", - py::arg("out"), py::arg("input"), py::arg("scale")); + m.def("static_scaled_fp8_quant", + &StaticScaledFp8Quant, + "static_scaled_fp8_quant function", + py::arg("out"), + py::arg("input"), + py::arg("scale")); - m.def("dynamic_scaled_fp8_quant", &DynamicScaledFp8Quant, + m.def("dynamic_scaled_fp8_quant", + &DynamicScaledFp8Quant, "dynamic_scaled_fp8_quant function", - py::arg("out"), py::arg("input"), py::arg("scale")); + py::arg("out"), + py::arg("input"), + py::arg("scale")); - m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant, + m.def("dynamic_per_token_scaled_fp8_quant", + &DynamicPerTokenScaledFp8Quant, "dynamic_per_token_scaled_fp8_quant function", - py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub")); - m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function"); + py::arg("out"), + py::arg("input"), + py::arg("scales"), + py::arg("scale_ub")); + m.def("decode_mla_write_cache", + &DecodeMLAWriteCacheKernel, + "decode_mla_write_cache function"); - m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function"); + m.def("prefill_mla_write_cache", + &PrefillMLAWriteCacheKernel, + "prefill_mla_write_cache function"); - m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function"); + m.def("fused_rotary_position_encoding", + &FusedRotaryPositionEncoding, + "fused_rotary_position_encoding function"); - m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function"); + m.def("multi_head_latent_attention", + &MultiHeadLatentAttention, + "multi_head_latent_attention function"); - m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); + m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); #ifdef ENABLE_FP8 - m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func, - py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"), - py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"), - py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function"); - m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func, - py::arg("input"), py::arg("scale"), py::arg("topk_ids"), - py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function"); - m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func, - py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function"); + m.def("cutlass_fp8_fp8_half_gemm_fused", + &cutlass_fp8_fp8_half_gemm_func, + py::arg("x"), + py::arg("y"), + py::arg("bias"), + py::arg("transpose_x"), + py::arg("transpose_y"), + py::arg("scale"), + py::arg("output_dtype"), + py::arg("activation_type"), + "cutlass_fp8_fp8_half_gemm_fused function"); + m.def("moe_fused_hadamard_quant_fp8", + &MoeFusedHadamardQuantFp8Func, + py::arg("input"), + py::arg("scale"), + py::arg("topk_ids"), + py::arg("top_k"), + py::arg("intermediate_size"), + py::arg("tiled"), + "moe_fused_hadamard_quant_fp8 function"); + m.def("fused_hadamard_quant_fp8", + &FusedHadamardQuantFp8Func, + py::arg("input"), + py::arg("scale"), + "fused_hadamard_quant_fp8 function"); #endif - m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function"); + m.def("init_custom_all_reduce", + &init_custom_all_reduce, + "init all reduce class function"); m.def("all_reduce", &all_reduce, "all reduce function"); @@ -1289,9 +1496,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("register_buffer", ®ister_buffer, "register ipc buffer"); - m.def("register_graph_buffers", ®ister_graph_buffers, "register_graph_buffers"); + m.def("register_graph_buffers", + ®ister_graph_buffers, + "register_graph_buffers"); - m.def("allocate_shared_buffer_and_handle", &allocate_shared_buffer_and_handle, "allocate_shared_buffer_and_handle"); + m.def("allocate_shared_buffer_and_handle", + &allocate_shared_buffer_and_handle, + "allocate_shared_buffer_and_handle"); m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer"); @@ -1299,52 +1510,86 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("open_mem_handle", &open_mem_handle, "open_mem_handle"); - m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta"); + m.def("get_graph_buffer_ipc_meta", + &get_graph_buffer_ipc_meta, + "get_graph_buffer_ipc_meta"); // speculative decoding Kernel - m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset, "speculate_get_padding_offset function"); + m.def("speculate_get_padding_offset", + &SpeculateGetPaddingOffset, + "speculate_get_padding_offset function"); - m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput, "speculate_get_seq_lens_output function"); + m.def("speculate_get_seq_lens_output", + &SpeculateGetSeqLensOutput, + "speculate_get_seq_lens_output function"); - m.def("speculate_get_output_padding_offset",&SpeculateGetOutputPaddingOffset, "speculate_get_output_padding_offset function"); + m.def("speculate_get_output_padding_offset", + &SpeculateGetOutputPaddingOffset, + "speculate_get_output_padding_offset function"); - m.def("speculate_get_token_penalty_multi_scores",&SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function"); + m.def("speculate_get_token_penalty_multi_scores", + &SpecTokenPenaltyMultiScores, + "speculate_get_token_penalty_multi_scores function"); - m.def("speculate_set_stop_value_multi_seqs",&SpecGetStopFlagsMultiSeqs, "speculate_set_stop_value_multi_seqs function"); + m.def("speculate_set_stop_value_multi_seqs", + &SpecGetStopFlagsMultiSeqs, + "speculate_set_stop_value_multi_seqs function"); - m.def("speculate_verify",&SpeculateVerify, "speculate_verify function"); + m.def("speculate_verify", &SpeculateVerify, "speculate_verify function"); - m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel"); + m.def("speculate_update", &SpeculateUpdate, "Speculate Update Kernel"); - m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function"); + m.def("speculate_set_value_by_flags_and_idx", + &SpeculateSetValueByFlagsAndIdx, + "speculate_set_value_by_flags_and_idx function"); - m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic, "speculate_save_output function"); + m.def("speculate_save_output", + &SpeculateSaveWithOutputMsgStatic, + "speculate_save_output function"); - m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function"); + m.def("speculate_clear_accept_nums", + &SpeculateClearAcceptNums, + "speculate_clear_accept_nums function"); - m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function"); + m.def("speculate_schedule_cache", + &SpeculateScheduleCache, + "SpeculateScheduleCache function"); m.def("ngram_match", &NgramMatch, "ngram_match function"); m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function"); - m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function"); + m.def("draft_model_postprocess", + &DraftModelPostprocess, + "draft_model_postprocess function"); - m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function"); + m.def("draft_model_preprocess", + &DraftModelPreprocess, + "draft_model_preprocess function"); - m.def("draft_model_update",&DraftModelUpdate, "draft_model_update function"); + m.def("draft_model_update", &DraftModelUpdate, "draft_model_update function"); - m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function"); + m.def("eagle_get_hidden_states", + &EagleGetHiddenStates, + "eagle_get_hidden_states function"); - m.def("eagle_get_self_hidden_states", &EagleGetSelfHiddenStates, "eagle_get_self_hidden_states function"); + m.def("eagle_get_self_hidden_states", + &EagleGetSelfHiddenStates, + "eagle_get_self_hidden_states function"); - m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function"); + m.def("mtp_step_paddle", &MTPStepPaddle, "mtp_step_paddle function"); - m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function"); + m.def("speculate_step_paddle", + &SpeculateStepPaddle, + "speculate_step_paddle function"); - m.def("merge_prefill_decode_output", &MergePrefillDecodeOutput, "merge_prefill_decode_output function"); + m.def("merge_prefill_decode_output", + &MergePrefillDecodeOutput, + "merge_prefill_decode_output function"); - m.def("rejection_top_p_sampling", &TopPSamplingReject, "rejection_top_p_sampling function"); + m.def("rejection_top_p_sampling", + &TopPSamplingReject, + "rejection_top_p_sampling function"); m.def("top_k_renorm_probs", &TopKRenorm, "top_k_renorm_probs function"); @@ -1352,17 +1597,31 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("save_output", &SaveOutMmsgStatic, "save_output function"); - m.def("limit_thinking_content_length_v1", &LimitThinkingContentLengthV1, "limit_thinking_content_length_v1 function"); + m.def("limit_thinking_content_length_v1", + &LimitThinkingContentLengthV1, + "limit_thinking_content_length_v1 function"); - m.def("limit_thinking_content_length_v2", &LimitThinkingContentLengthV2, "limit_thinking_content_length_v2 function"); + m.def("limit_thinking_content_length_v2", + &LimitThinkingContentLengthV2, + "limit_thinking_content_length_v2 function"); - m.def("speculate_limit_thinking_content_length_v1", &SpeculateLimitThinkingContentLengthV1, "speculate limit thinking content length function"); + m.def("speculate_limit_thinking_content_length_v1", + &SpeculateLimitThinkingContentLengthV1, + "speculate limit thinking content length function"); - m.def("speculate_limit_thinking_content_length_v2", &SpeculateLimitThinkingContentLengthV2, "speculate limit thinking content length function"); + m.def("speculate_limit_thinking_content_length_v2", + &SpeculateLimitThinkingContentLengthV2, + "speculate limit thinking content length function"); - m.def("speculate_get_logits", &SpeculateGetLogits, "speculate_get_logits function"); + m.def("speculate_get_logits", + &SpeculateGetLogits, + "speculate_get_logits function"); - m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function"); + m.def("speculate_insert_first_token", + &SpeculateInsertFirstToken, + "speculate_insert_first_token function"); - m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function"); + m.def("speculate_get_target_logits", + &SpeculateGetTargetLogits, + "speculate_get_target_logits function"); } diff --git a/custom_ops/gpu_ops/update_inputs.cu b/custom_ops/gpu_ops/update_inputs.cu index c58aeb39c0..7bd303d1dd 100644 --- a/custom_ops/gpu_ops/update_inputs.cu +++ b/custom_ops/gpu_ops/update_inputs.cu @@ -15,93 +15,97 @@ #include "helper.h" template -__global__ void update_inputs_kernel(bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int64_t *input_ids, - const int64_t *stop_nums, - const bool *stop_flags, - const bool *is_block_step, - const int64_t *next_tokens, +__global__ void update_inputs_kernel(bool* not_need_stop, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* input_ids, + const int64_t* stop_nums, + const bool* stop_flags, + const bool* is_block_step, + const int64_t* next_tokens, const int bsz, const int max_bsz, const int input_ids_stride) { - int thread_idx = threadIdx.x; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; + int thread_idx = threadIdx.x; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; - bool stop_flag_now = false; - int64_t stop_flag_now_int = 0; - if (thread_idx < max_bsz) { - if (thread_idx < bsz) { - stop_flag_now = stop_flags[thread_idx]; - if (is_block_step[thread_idx]) { - stop_flag_now_int = 0; - } else { - stop_flag_now_int = static_cast(stop_flag_now); - } - } else { - stop_flag_now_int = 1; - } - } + bool stop_flag_now = false; + int64_t stop_flag_now_int = 0; + if (thread_idx < max_bsz) { if (thread_idx < bsz) { - const int seq_len_this_time = seq_lens_this_time[thread_idx]; - const int seq_len_encoder = seq_lens_encoder[thread_idx]; - const int seq_len_decoder = seq_lens_decoder[thread_idx]; - - seq_lens_decoder[thread_idx] = stop_flag_now ? - 0 : (seq_len_encoder > 0 ? - (seq_len_encoder + seq_len_decoder) : seq_len_decoder + 1); - - seq_lens_this_time[thread_idx] = stop_flag_now ? 0 : 1; - seq_lens_encoder[thread_idx] = 0; - int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride; - input_ids_now[0] = next_tokens[thread_idx]; - } - __syncthreads(); - int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int); - if (thread_idx == 0) { - not_need_stop[0] = stop_sum < stop_nums[0]; + stop_flag_now = stop_flags[thread_idx]; + if (is_block_step[thread_idx]) { + stop_flag_now_int = 0; + } else { + stop_flag_now_int = static_cast(stop_flag_now); + } + } else { + stop_flag_now_int = 1; } + } + if (thread_idx < bsz) { + const int seq_len_this_time = seq_lens_this_time[thread_idx]; + const int seq_len_encoder = seq_lens_encoder[thread_idx]; + const int seq_len_decoder = seq_lens_decoder[thread_idx]; + + seq_lens_decoder[thread_idx] = + stop_flag_now + ? 0 + : (seq_len_encoder > 0 ? (seq_len_encoder + seq_len_decoder) + : seq_len_decoder + 1); + + seq_lens_this_time[thread_idx] = stop_flag_now ? 0 : 1; + seq_lens_encoder[thread_idx] = 0; + int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride; + input_ids_now[0] = next_tokens[thread_idx]; + } + __syncthreads(); + int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int); + if (thread_idx == 0) { + not_need_stop[0] = stop_sum < stop_nums[0]; + } } -void UpdateInputes(const paddle::Tensor &stop_flags, - const paddle::Tensor ¬_need_stop, // only on cpu - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &input_ids, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const paddle::Tensor &is_block_step) { +void UpdateInputs(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // only on cpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step) { #ifdef PADDLE_WITH_CUSTOM_DEVICE - auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place())); - auto cu_stream = dev_ctx->stream(); + auto dev_ctx = static_cast( + paddle::experimental::DeviceContextPool::Instance().Get( + input_ids.place())); + auto cu_stream = dev_ctx->stream(); #else - auto cu_stream = input_ids.stream(); + auto cu_stream = input_ids.stream(); #endif - const int max_bsz = stop_flags.shape()[0]; - const int now_bsz = seq_lens_this_time.shape()[0]; - const int input_ids_stride = input_ids.shape()[1]; - auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); - update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>( - const_cast(not_need_stop_gpu.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(input_ids.data()), - stop_nums.data(), - stop_flags.data(), - is_block_step.data(), - next_tokens.data(), - now_bsz, - max_bsz, - input_ids_stride); - auto not_need_stop_cpu = - not_need_stop_gpu.copy_to(not_need_stop.place(), false); - bool *not_need_stop_data = const_cast(not_need_stop.data()); - not_need_stop_data[0] = not_need_stop_cpu.data()[0]; + const int max_bsz = stop_flags.shape()[0]; + const int now_bsz = seq_lens_this_time.shape()[0]; + const int input_ids_stride = input_ids.shape()[1]; + auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); + update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>( + const_cast(not_need_stop_gpu.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(input_ids.data()), + stop_nums.data(), + stop_flags.data(), + is_block_step.data(), + next_tokens.data(), + now_bsz, + max_bsz, + input_ids_stride); + auto not_need_stop_cpu = + not_need_stop_gpu.copy_to(not_need_stop.place(), false); + bool* not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; } PD_BUILD_STATIC_OP(update_inputs) @@ -124,4 +128,4 @@ PD_BUILD_STATIC_OP(update_inputs) {"seq_lens_encoder", "seq_lens_encoder_out"}, {"seq_lens_decoder", "seq_lens_decoder_out"}, {"input_ids", "input_ids_out"}}) - .SetKernelFn(PD_KERNEL(UpdateInputes)); + .SetKernelFn(PD_KERNEL(UpdateInputs)); diff --git a/custom_ops/gpu_ops/update_inputs_beam.cu b/custom_ops/gpu_ops/update_inputs_beam.cu index aea374661d..afa5420826 100644 --- a/custom_ops/gpu_ops/update_inputs_beam.cu +++ b/custom_ops/gpu_ops/update_inputs_beam.cu @@ -15,15 +15,14 @@ #include "helper.h" template -__global__ void update_inputs_beam_kernel( - int *seq_lens_this_time, - int *seq_lens_encoder, - int64_t *input_ids, - float *logits, - const int bsz, - const int seq_len, - const int hidden_size, - const int beam_width) { +__global__ void update_inputs_beam_kernel(int* seq_lens_this_time, + int* seq_lens_encoder, + int64_t* input_ids, + float* logits, + const int bsz, + const int seq_len, + const int hidden_size, + const int beam_width) { int thread_idx = threadIdx.x; int block_idx = blockIdx.x; @@ -35,23 +34,22 @@ __global__ void update_inputs_beam_kernel( seq_lens_encoder[thread_idx] = seq_lens_encoder[bsz_index]; } if (block_idx < seq_len) { - input_ids[thread_idx * seq_len + block_idx] = input_ids[bsz_index * seq_len + block_idx]; + input_ids[thread_idx * seq_len + block_idx] = + input_ids[bsz_index * seq_len + block_idx]; } - logits[thread_idx * hidden_size + block_idx] = logits[bsz_index * hidden_size + block_idx]; - + logits[thread_idx * hidden_size + block_idx] = + logits[bsz_index * hidden_size + block_idx]; } - } __syncthreads(); } -void UpdateInputesBeam( - const paddle::Tensor& beam_width, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& input_ids, - const paddle::Tensor& logits) { +void UpdateInputsBeam(const paddle::Tensor& beam_width, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& input_ids, + const paddle::Tensor& logits) { int beam_width_scalar = beam_width.data()[0]; if (beam_width_scalar > 1) { @@ -59,16 +57,16 @@ void UpdateInputesBeam( const int seq_len = input_ids.shape()[1]; const int hidden_size = logits.shape()[1]; - update_inputs_beam_kernel<1024><<>>( - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(input_ids.data()), - const_cast(logits.data()), - bsz, - seq_len, - hidden_size, - beam_width_scalar - ); + update_inputs_beam_kernel<1024> + <<>>( + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(input_ids.data()), + const_cast(logits.data()), + bsz, + seq_len, + hidden_size, + beam_width_scalar); } } @@ -86,4 +84,4 @@ PD_BUILD_STATIC_OP(update_inputs_beam) {"seq_lens_encoder", "seq_lens_encoder_out"}, {"input_ids", "input_ids_out"}, {"logits", "logits_out"}}) - .SetKernelFn(PD_KERNEL(UpdateInputesBeam)); + .SetKernelFn(PD_KERNEL(UpdateInputsBeam)); diff --git a/custom_ops/gpu_ops/update_inputs_v1.cu b/custom_ops/gpu_ops/update_inputs_v1.cu index 33076b073c..64230ae256 100644 --- a/custom_ops/gpu_ops/update_inputs_v1.cu +++ b/custom_ops/gpu_ops/update_inputs_v1.cu @@ -15,146 +15,150 @@ #include "helper.h" template -__global__ void update_inputs_kernel_v1(bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, - const int bsz, - const int max_bsz, - const int input_ids_stride, - const int block_num_per_seq, - const int block_size, - bool prefill_one_step_stop) { - int thread_idx = threadIdx.x; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; +__global__ void update_inputs_kernel_v1(bool* not_need_stop, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int* step_seq_lens_decoder, + int64_t* prompt_lens, + int64_t* topk_ids, + int64_t* input_ids, + int* block_tables, + const int64_t* stop_nums, + bool* stop_flags, + bool* is_block_step, + const int64_t* next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size, + bool prefill_one_step_stop) { + int thread_idx = threadIdx.x; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; - bool stop_flag_now = false; - int64_t stop_flag_now_int = 0; - if (thread_idx < max_bsz) { - if (thread_idx < bsz) { - stop_flag_now = stop_flags[thread_idx]; - stop_flag_now_int = static_cast(stop_flag_now); - } else { - stop_flag_now_int = 1; - } - } + bool stop_flag_now = false; + int64_t stop_flag_now_int = 0; + if (thread_idx < max_bsz) { if (thread_idx < bsz) { - if(stop_flag_now) { - seq_lens_this_time[thread_idx] = 0; // stop at next step - seq_lens_decoder[thread_idx] = 0; - seq_lens_encoder[thread_idx] = 0; + stop_flag_now = stop_flags[thread_idx]; + stop_flag_now_int = static_cast(stop_flag_now); + } else { + stop_flag_now_int = 1; + } + } + if (thread_idx < bsz) { + if (stop_flag_now) { + seq_lens_this_time[thread_idx] = 0; // stop at next step + seq_lens_decoder[thread_idx] = 0; + seq_lens_encoder[thread_idx] = 0; + } else { + if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= + prompt_lens[thread_idx]) { + if (prefill_one_step_stop) { + // prefill done, stop + stop_flags[thread_idx] = true; + seq_lens_this_time[thread_idx] = 0; + seq_lens_decoder[thread_idx] = 0; + seq_lens_encoder[thread_idx] = 0; + stop_flag_now_int = 1; } else { - if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) { - if (prefill_one_step_stop) { - // prefill done, stop - stop_flags[thread_idx] = true; - seq_lens_this_time[thread_idx] = 0; - seq_lens_decoder[thread_idx] = 0; - seq_lens_encoder[thread_idx] = 0; - stop_flag_now_int = 1; - } else{ - // decoding - seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx]; - seq_lens_this_time[thread_idx] = 1; - seq_lens_encoder[thread_idx] = 0; - int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride; - input_ids_now[0] = next_tokens[thread_idx]; + // decoding + seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx]; + seq_lens_this_time[thread_idx] = 1; + seq_lens_encoder[thread_idx] = 0; + int64_t* input_ids_now = input_ids + thread_idx * input_ids_stride; + input_ids_now[0] = next_tokens[thread_idx]; - // to judge whether block is not enough - int *block_table_now = block_tables + thread_idx * block_num_per_seq; - if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) { - // should be scheduled by server - is_block_step[thread_idx] = true; - seq_lens_this_time[thread_idx]= 0; - stop_flags[thread_idx] = true; - step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx]; - seq_lens_decoder[thread_idx] = 0; - stop_flag_now_int = 1; - } - } - } else - { - stop_flags[thread_idx] = true; - seq_lens_this_time[thread_idx] = 0; - seq_lens_decoder[thread_idx] = 0; - seq_lens_encoder[thread_idx] = 0; - topk_ids[thread_idx] = -1; - stop_flag_now_int = 1; - } + // to judge whether block is not enough + int* block_table_now = block_tables + thread_idx * block_num_per_seq; + if (seq_lens_this_time[thread_idx] != 0 && + block_table_now[seq_lens_decoder[thread_idx] / block_size] == + -1) { + // should be scheduled by server + is_block_step[thread_idx] = true; + seq_lens_this_time[thread_idx] = 0; + stop_flags[thread_idx] = true; + step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx]; + seq_lens_decoder[thread_idx] = 0; + stop_flag_now_int = 1; + } } + } else { + stop_flags[thread_idx] = true; + seq_lens_this_time[thread_idx] = 0; + seq_lens_decoder[thread_idx] = 0; + seq_lens_encoder[thread_idx] = 0; + topk_ids[thread_idx] = -1; + stop_flag_now_int = 1; + } } - __syncthreads(); - int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int); - if (thread_idx == 0) { - not_need_stop[0] = stop_sum < stop_nums[0]; - } + } + __syncthreads(); + int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int); + if (thread_idx == 0) { + not_need_stop[0] = stop_sum < stop_nums[0]; + } } -void UpdateInputesV1(const paddle::Tensor &stop_flags, - const paddle::Tensor ¬_need_stop, // only on cpu - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_seq_lens_decoder, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &topk_ids, - const paddle::Tensor &input_ids, - const paddle::Tensor &block_tables, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const paddle::Tensor &is_block_step, - const int block_size) { +void UpdateInputsV1(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // only on cpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& topk_ids, + const paddle::Tensor& input_ids, + const paddle::Tensor& block_tables, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step, + const int block_size) { #ifdef PADDLE_WITH_CUSTOM_DEVICE - auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place())); - auto cu_stream = dev_ctx->stream(); + auto dev_ctx = static_cast( + paddle::experimental::DeviceContextPool::Instance().Get( + input_ids.place())); + auto cu_stream = dev_ctx->stream(); #else - auto cu_stream = input_ids.stream(); + auto cu_stream = input_ids.stream(); #endif - bool prefill_one_step_stop = false; - if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) { - if (env_p[0] == '1') { - prefill_one_step_stop = true; - } + bool prefill_one_step_stop = false; + if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) { + if (env_p[0] == '1') { + prefill_one_step_stop = true; } - const int max_bsz = stop_flags.shape()[0]; - const int now_bsz = seq_lens_this_time.shape()[0]; - const int input_ids_stride = input_ids.shape()[1]; - const int block_num_per_seq = block_tables.shape()[1]; - auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); - update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>( - const_cast(not_need_stop_gpu.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(step_seq_lens_decoder.data()), - const_cast(prompt_lens.data()), - const_cast(topk_ids.data()), - const_cast(input_ids.data()), - const_cast(block_tables.data()), - stop_nums.data(), - const_cast(stop_flags.data()), - const_cast(is_block_step.data()), - next_tokens.data(), - now_bsz, - max_bsz, - input_ids_stride, - block_num_per_seq, - block_size, - prefill_one_step_stop); - auto not_need_stop_cpu = - not_need_stop_gpu.copy_to(not_need_stop.place(), false); - bool *not_need_stop_data = const_cast(not_need_stop.data()); - not_need_stop_data[0] = not_need_stop_cpu.data()[0]; + } + const int max_bsz = stop_flags.shape()[0]; + const int now_bsz = seq_lens_this_time.shape()[0]; + const int input_ids_stride = input_ids.shape()[1]; + const int block_num_per_seq = block_tables.shape()[1]; + auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); + update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>( + const_cast(not_need_stop_gpu.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(prompt_lens.data()), + const_cast(topk_ids.data()), + const_cast(input_ids.data()), + const_cast(block_tables.data()), + stop_nums.data(), + const_cast(stop_flags.data()), + const_cast(is_block_step.data()), + next_tokens.data(), + now_bsz, + max_bsz, + input_ids_stride, + block_num_per_seq, + block_size, + prefill_one_step_stop); + auto not_need_stop_cpu = + not_need_stop_gpu.copy_to(not_need_stop.place(), false); + bool* not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; } PD_BUILD_STATIC_OP(update_inputs_v1) @@ -190,4 +194,4 @@ PD_BUILD_STATIC_OP(update_inputs_v1) {"stop_flags", "stop_flags_out"}, {"step_seq_lens_decoder", "step_seq_lens_decoder_out"}, {"is_block_step", "is_block_step_out"}}) - .SetKernelFn(PD_KERNEL(UpdateInputesV1)); + .SetKernelFn(PD_KERNEL(UpdateInputsV1)); diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 288f432261..833cd04ecf 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -33,6 +33,20 @@ void prof_start(); void prof_stop(); +std::vector AdjustBatch( + const paddle::Tensor& x, // [token_num, dim_embed] + const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& encoder_batch_idx, + const paddle::Tensor& decoder_batch_idx, + const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& encoder_batch_idx_cpu, + const paddle::Tensor& decoder_batch_idx_cpu, + const paddle::Tensor& enc_batch_tensor, + const paddle::Tensor& dec_batch_tensor, + const paddle::optional& output_padding_offset, + int max_input_length); + void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor, const paddle::Tensor& seq_lens_this_time_tensor, const paddle::Tensor& seq_lens_decoder_tensor, @@ -73,6 +87,21 @@ std::vector BlockAttn( const std::string& pos_emb_type = "NORMAL", bool rope_3d = false); +std::vector MoeLayer( + const paddle::Tensor& x, + const paddle::Tensor& gate_weight, + const paddle::optional& gate_correction_bias, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& down_proj_bias, + const paddle::optional& up_gate_proj_weight_scale, + const paddle::optional& down_proj_weight_scale, + const paddle::optional& down_proj_in_scale, + const std::string& quant_method, + const int moe_top_k, + const bool moe_group); + std::vector MoERedundantTopKSelect( const paddle::Tensor& gating_logits, const paddle::Tensor& expert_id_to_ep_rank_array, @@ -294,6 +323,65 @@ std::vector EagleGetSelfHiddenStates( const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& step_idx); +std::vector GatherNextToken( + const paddle::Tensor& tmp_out, // [token_num, dim_embed] + const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& encoder_batch_map, + const paddle::Tensor& decoder_batch_map, + const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& encoder_batch_map_cpu, + const paddle::Tensor& decoder_batch_map_cpu, + const paddle::Tensor& enc_batch_tensor, + const paddle::Tensor& dec_batch_tensor, + const paddle::optional& output_padding_offset, + int max_input_length); + +std::vector GetImgBoundaries( + const paddle::Tensor& task_input_ids, + const paddle::Tensor& grid_thw, + const int64_t image_patch_id); + +std::vector GetInferParam( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& block_tables, + int block_size); + +void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag); + +void GetOutputDynamic(const paddle::Tensor& x, + int64_t rank_id, + bool wait_flag, + int msg_queue_id); + +std::vector GetPaddingOffset(const paddle::Tensor& input_ids, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& token_num, + const paddle::Tensor& seq_len); + +void GetStopFlagsMulti(const paddle::Tensor& topk_ids, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens, + const paddle::Tensor& end_ids, + const paddle::Tensor& next_tokens, + const bool beam_search); + +void RecoverDecodeTask(const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& block_tables, + const paddle::Tensor& is_block_step, + const int block_size); + +std::vector ShareExternalData(const paddle::Tensor& input, + const std::string shm_name, + const std::vector& shape, + bool use_ipc); + std::vector SpeculateGetOutputPaddingOffset( const paddle::Tensor& output_cum_offsets_tmp, const paddle::Tensor& out_token_num, @@ -308,6 +396,31 @@ std::vector SpeculateGetPaddingOffset( const paddle::Tensor& seq_len, const paddle::Tensor& seq_lens_encoder); +void StepPaddle(const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& ori_seq_lens_encoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor& encoder_block_lens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& step_block_list, + const paddle::Tensor& step_lens, + const paddle::Tensor& recover_block_list, + const paddle::Tensor& recover_lens, + const paddle::Tensor& need_block_list, + const paddle::Tensor& need_block_len, + const paddle::Tensor& used_list_len, + const paddle::Tensor& free_list, + const paddle::Tensor& free_list_len, + const paddle::Tensor& input_ids, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& next_tokens, + const paddle::Tensor& first_token_ids, + const int block_size, + const int encoder_decoder_block_num); + void MTPStepPaddle( const paddle::Tensor& base_model_stop_flags, const paddle::Tensor& stop_flags, @@ -323,6 +436,17 @@ void MTPStepPaddle( const int block_size, const int max_draft_tokens); +void SaveOutMmsgStatic(const paddle::Tensor& x, + const paddle::Tensor& not_need_stop, + int64_t rank_id, + bool save_each_rank); + +void SaveOutMmsgDynamic(const paddle::Tensor& x, + const paddle::Tensor& not_need_stop, + int64_t rank_id, + int msg_queue_id, + bool save_each_rank); + void SpeculateStepSchedule( const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens_this_time, @@ -356,7 +480,78 @@ std::vector SpeculateGetSeqLensOutput( const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder); +void SetDataIpc(const paddle::Tensor& tmp_input, const std::string& shm_name); + +void TextImageGatherScatter(paddle::Tensor& input, + paddle::Tensor& text_input, + paddle::Tensor& image_input, + paddle::Tensor& token_type_ids, + paddle::Tensor& text_index, + paddle::Tensor& image_index, + const bool is_scatter); + +void TextImageIndexOut(const paddle::Tensor& token_type_ids, + const paddle::Tensor& text_index, + const paddle::Tensor& image_index); + +void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_scores, + const paddle::Tensor& presence_scores, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id); + +void UpdateInputs(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // cpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step); + +void UpdateInputsV1(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // only on cpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& topk_ids, + const paddle::Tensor& input_ids, + const paddle::Tensor& block_tables, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step, + const int block_size); + +std::vector WeightQuantize(const paddle::Tensor& x, + const std::string& algo, + const int32_t arch, + const int32_t group_size); + PYBIND11_MODULE(fastdeploy_ops, m) { + m.def("adjust_batch", + &AdjustBatch, + py::arg("x"), + py::arg("cum_offsets"), + py::arg("encoder_seq_lod"), + py::arg("encoder_batch_idx"), + py::arg("decoder_batch_idx"), + py::arg("encoder_seq_lod_cpu"), + py::arg("encoder_batch_idx_cpu"), + py::arg("decoder_batch_idx_cpu"), + py::arg("enc_batch_tensor"), + py::arg("dec_batch_tensor"), + py::arg("output_padding_offset"), + py::arg("max_input_length"), + "adjust batch in XPU"); + m.def("block_attn", &BlockAttn, py::arg("qkv"), @@ -388,96 +583,107 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("pos_emb_type") = "NORMAL", py::arg("rope_3d") = false, "block attention in XPU"); + + m.def("create_kv_signal_sender", + &create_cachekv_signal_thread, + "init write cache kv signal thread"); + m.def("cuda_host_alloc", &custom_xpu_host_alloc, "Allocate pinned memory", py::arg("size"), py::arg("flags") = 0x00); + m.def("cuda_host_free", &custom_xpu_host_free, "Free pinned memory", py::arg("ptr")); - m.def("get_peer_mem_addr", - &xpu_get_peer_mem_addr, - "Get Host memory address of device pointer", - py::arg("ptr")); + m.def("cuda_host_register", &xpu_cuda_host_register, "Register pinned memory", py::arg("ptr"), py::arg("size"), py::arg("flags") = cudaHostRegisterDefault); - m.def("create_kv_signal_sender", - &create_cachekv_signal_thread, - "init write cache kv signal thread"); + m.def("destroy_kv_signal_sender", &destroy_cachekv_signal_thread, "write cache kv signal thread exit"); - m.def("prof_start", &prof_start, "prof_start"); - m.def("prof_stop", &prof_stop, "prof_stop"); - m.def("moe_redundant_topk_select", - &MoERedundantTopKSelect, - py::arg("gating_logits"), - py::arg("expert_id_to_ep_rank_array"), - py::arg("expert_in_rank_num_list"), - py::arg("tokens_per_expert_stats_list"), - py::arg("bias"), - py::arg("moe_topk"), - py::arg("apply_norm_weight"), - py::arg("enable_softmax_top_k_fused"), - py::arg("redundant_ep_rank_num_plus_one"), - "moe export RedundantTopKSelect function"); - m.def("set_ncluster", &set_ncluster, "set ncluster"); - /** - * open_shm_and_get_meta_signal.cc - * InitKVSingnalPerQuery - */ - m.def("init_kv_signal_per_query", - &InitKVSignalPerQuery, - py::arg("seq_lens_encoder_tensor"), - py::arg("seq_lens_this_time_tensor"), - py::arg("seq_lens_decoder_tensor"), - py::arg("rank"), - py::arg("num_layers"), - "init_kv_signal_per_query function"); + m.def("draft_model_preprocess", + &DraftModelPreprocess, + py::arg("draft_tokens"), + py::arg("input_ids"), + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("step_idx"), + py::arg("seq_lens_encoder_record"), + py::arg("seq_lens_decoder_record"), + py::arg("not_need_stop"), + py::arg("batch_drop"), + py::arg("accept_tokens"), + py::arg("accept_num"), + py::arg("base_model_seq_lens_encoder"), + py::arg("base_model_seq_lens_decoder"), + py::arg("base_model_step_idx"), + py::arg("base_model_stop_flags"), + py::arg("base_model_is_block_step"), + py::arg("base_model_draft_tokens"), + py::arg("max_draft_token"), + py::arg("truncate_first_token"), + py::arg("splitwise_prefill"), + "Preprocess data for draft model in speculative decoding"); - /** - * GetOutputKVSignal - */ - m.def("get_output_kv_signal", - &GetOutputKVSignal, - py::arg("x"), - py::arg("rank_id"), - py::arg("wait_flag"), - "get_output_kv_signal function"); + m.def("draft_model_postprocess", + &DraftModelPostprocess, + py::arg("base_model_draft_tokens"), + py::arg("base_model_seq_lens_this_time"), + py::arg("base_model_seq_lens_encoder"), + py::arg("base_model_stop_flags"), + "Postprocess data for draft model in speculative decoding"); - m.def("fused_rms_norm_xpu", - &RmsNorm, - "Fused RMS normalization for XPU", - py::arg("x"), // 输入张量 - py::arg("bias"), // 偏置(可选) - py::arg("residual"), // 残差连接(可选) - py::arg("norm_weight"), // 归一化权重 - py::arg("norm_bias"), // 归一化偏置(可选) - py::arg("epsilon"), // 数值稳定项 - py::arg("begin_norm_axis"), // 归一化起始维度 - py::arg("quant_scale"), // 量化缩放因子 - py::arg("quant_round_type"), // 量化舍入类型 - py::arg("quant_max_bound"), // 量化最大值边界 - py::arg("quant_min_bound") // 量化最小值边界 + m.def("draft_model_update", + &DraftModelUpdate, + "Update draft model states during speculative decoding", + py::arg("inter_next_tokens"), // 中间next tokens张量 + py::arg("draft_tokens"), // 草稿token张量 + py::arg("pre_ids"), // 前置ID张量 + py::arg("seq_lens_this_time"), // 当前步骤序列长度张量 + py::arg("seq_lens_encoder"), // 编码器序列长度张量 + py::arg("seq_lens_decoder"), // 解码器序列长度张量 + py::arg("step_idx"), // 步骤索引张量 + py::arg("output_cum_offsets"), // 输出累积偏移量张量 + py::arg("stop_flags"), // 停止标志张量 + py::arg("not_need_stop"), // 无需停止标志张量 + py::arg("max_dec_len"), // 最大解码长度张量 + py::arg("end_ids"), // 结束ID张量 + py::arg("base_model_draft_tokens"), // 基础模型草稿token张量 + py::arg("max_seq_len"), // 最大序列长度(int) + py::arg("substep") // 子步骤编号(int) ); - m.def("weight_only_linear_xpu", - &WeightOnlyLinear, - "Weight-only quantized linear layer", - py::arg("x"), - py::arg("weight"), - py::arg("weight_scale"), - py::arg("bias"), - py::arg("weight_dtype"), - py::arg("arch"), - py::arg("group_size") = -1); + m.def("eagle_get_hidden_states", + &EagleGetHiddenStates, + py::arg("input"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("stop_flags"), + py::arg("accept_nums"), + py::arg("base_model_seq_lens_this_time"), + py::arg("base_model_seq_lens_encoder"), + py::arg("actual_draft_token_num"), + "Get draft model hidden states"); + + m.def("eagle_get_self_hidden_states", + &EagleGetSelfHiddenStates, + py::arg("input"), + py::arg("last_seq_lens_this_time"), + py::arg("seq_lens_this_time"), + py::arg("step_idx"), + "Rebuild draft model hidden states"); m.def("ep_moe_expert_combine", &MoeEPCombine, @@ -502,6 +708,157 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("token_nums_this_rank"), py::arg("quant_method")); + m.def("fused_rms_norm_xpu", + &RmsNorm, + "Fused RMS normalization for XPU", + py::arg("x"), // 输入张量 + py::arg("bias"), // 偏置(可选) + py::arg("residual"), // 残差连接(可选) + py::arg("norm_weight"), // 归一化权重 + py::arg("norm_bias"), // 归一化偏置(可选) + py::arg("epsilon"), // 数值稳定项 + py::arg("begin_norm_axis"), // 归一化起始维度 + py::arg("quant_scale"), // 量化缩放因子 + py::arg("quant_round_type"), // 量化舍入类型 + py::arg("quant_max_bound"), // 量化最大值边界 + py::arg("quant_min_bound") // 量化最小值边界 + ); + + m.def("gather_next_token", + &GatherNextToken, + py::arg("tmp_out"), + py::arg("cum_offsets"), + py::arg("encoder_seq_lod"), + py::arg("encoder_batch_map"), + py::arg("decoder_batch_map"), + py::arg("encoder_seq_lod_cpu"), + py::arg("encoder_batch_map_cpu"), + py::arg("decoder_batch_map_cpu"), + py::arg("enc_batch_tensor"), + py::arg("dec_batch_tensor"), + py::arg("output_padding_offset"), + py::arg("max_input_length"), + "Gather next token for XPU"); + + m.def("get_img_boundaries", + &GetImgBoundaries, + py::arg("task_input_ids"), + py::arg("grid_thw"), + py::arg("image_patch_id"), + "Get image boundaries in VL model"); + + m.def("get_infer_param", + &GetInferParam, + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("seq_lens_this_time"), + py::arg("block_tables"), + py::arg("block_size"), + "Get infer parameters for block attention in XPU"); + + m.def("get_peer_mem_addr", + &xpu_get_peer_mem_addr, + "Get Host memory address of device pointer", + py::arg("ptr")); + + m.def("get_token_penalty_multi_scores", + &TokenPenaltyMultiScores, + py::arg("pre_ids"), + py::arg("logits"), + py::arg("penalty_scores"), + py::arg("frequency_scores"), + py::arg("presence_scores"), + py::arg("temperatures"), + py::arg("bad_tokens"), + py::arg("cur_len"), + py::arg("min_len"), + py::arg("eos_token_id"), + "get token_penalty_multi_scores function"); + + m.def("get_output", + &GetOutputStatic, + py::arg("x"), + py::arg("rank_id"), + py::arg("wait_flag"), + "get_output function"); + + m.def("get_output_ep", + &GetOutputStatic, + py::arg("x"), + py::arg("rank_id"), + py::arg("wait_flag"), + "get_output_ep function"); + + m.def("get_output_dynamic", + &GetOutputDynamic, + py::arg("x"), + py::arg("rank_id"), + py::arg("wait_flag"), + py::arg("msg_queue_id"), + "get_output_dynamic function"); + + m.def("get_output_ep_dynamic", + &GetOutputDynamic, + py::arg("x"), + py::arg("rank_id"), + py::arg("wait_flag"), + py::arg("msg_queue_id"), + "get_output_ep_dynamic function"); + + m.def("get_output_kv_signal", + &GetOutputKVSignal, + py::arg("x"), + py::arg("rank_id"), + py::arg("wait_flag"), + "get_output_kv_signal function"); + + m.def("get_padding_offset", + &GetPaddingOffset, + py::arg("input_ids"), + py::arg("cum_offsets"), + py::arg("token_num"), + py::arg("seq_len"), + "get padding offset function"); + + m.def("init_kv_signal_per_query", + &InitKVSignalPerQuery, + py::arg("seq_lens_encoder_tensor"), + py::arg("seq_lens_this_time_tensor"), + py::arg("seq_lens_decoder_tensor"), + py::arg("rank"), + py::arg("num_layers"), + "init_kv_signal_per_query function"); + + m.def("moe_redundant_topk_select", + &MoERedundantTopKSelect, + py::arg("gating_logits"), + py::arg("expert_id_to_ep_rank_array"), + py::arg("expert_in_rank_num_list"), + py::arg("tokens_per_expert_stats_list"), + py::arg("bias"), + py::arg("moe_topk"), + py::arg("apply_norm_weight"), + py::arg("enable_softmax_top_k_fused"), + py::arg("redundant_ep_rank_num_plus_one"), + "moe export RedundantTopKSelect function"); + + m.def("mtp_step_paddle", + &MTPStepPaddle, + py::arg("base_model_stop_flags"), + py::arg("stop_flags"), + py::arg("batch_drop"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("block_tables"), // [bsz, block_num_per_seq] + py::arg("encoder_block_lens"), + py::arg("used_list_len"), + py::arg("free_list"), + py::arg("free_list_len"), + py::arg("block_size"), + py::arg("max_draft_tokens"), + "MTP step paddle"); + m.def("moe_expert_ffn", &MoeExpertFFN, "MoE expert feed-forward network with quantization support", @@ -529,25 +886,46 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("moe_topk"), py::arg("apply_norm_weight")); - m.def("draft_model_update", - &DraftModelUpdate, - "Update draft model states during speculative decoding", - py::arg("inter_next_tokens"), // 中间next tokens张量 - py::arg("draft_tokens"), // 草稿token张量 - py::arg("pre_ids"), // 前置ID张量 - py::arg("seq_lens_this_time"), // 当前步骤序列长度张量 - py::arg("seq_lens_encoder"), // 编码器序列长度张量 - py::arg("seq_lens_decoder"), // 解码器序列长度张量 - py::arg("step_idx"), // 步骤索引张量 - py::arg("output_cum_offsets"), // 输出累积偏移量张量 - py::arg("stop_flags"), // 停止标志张量 - py::arg("not_need_stop"), // 无需停止标志张量 - py::arg("max_dec_len"), // 最大解码长度张量 - py::arg("end_ids"), // 结束ID张量 - py::arg("base_model_draft_tokens"), // 基础模型草稿token张量 - py::arg("max_seq_len"), // 最大序列长度(int) - py::arg("substep") // 子步骤编号(int) - ); + m.def("prof_start", &prof_start, "prof_start"); + + m.def("prof_stop", &prof_stop, "prof_stop"); + + m.def("recover_decode_task", + &RecoverDecodeTask, + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("step_seq_lens_decoder"), + py::arg("block_tables"), + py::arg("is_block_step"), + py::arg("block_size"), + "Recover decode task function"); + + m.def("save_output", + &SaveOutMmsgStatic, + py::arg("x"), + py::arg("not_need_stop"), + py::arg("rank_id"), + py::arg("save_each_rank"), + "Save output function"); + + m.def("save_output_dynamic", + &SaveOutMmsgDynamic, + py::arg("x"), + py::arg("not_need_stop"), + py::arg("rank_id"), + py::arg("msg_queue_id"), + py::arg("save_each_rank"), + "Save output dynamic function"); + + m.def("share_external_data", + &ShareExternalData, + py::arg("input"), + py::arg("shm_name"), + py::arg("shape"), + py::arg("use_ipc"), + "Share external data function"); m.def("speculate_get_token_penalty_multi_scores", &SpeculateTokenPenaltyMultiScores, @@ -582,15 +960,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("stop_nums"), "Update speculative decoding states (V3)"); - m.def("top_p_candidates", - &TopPCandidates, - py::arg("probs"), - py::arg("top_p"), - py::arg("output_padding_offset"), - py::arg("candidates_len"), - py::arg("max_seq_len"), - "Generate top-p candidates based on probability distributions"); - m.def("speculate_verify", &SpeculateVerify, py::arg("accept_tokens"), @@ -633,61 +1002,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("step_idx"), "Set values based on flags and indices in speculative decoding"); - m.def("draft_model_preprocess", - &DraftModelPreprocess, - py::arg("draft_tokens"), - py::arg("input_ids"), - py::arg("stop_flags"), - py::arg("seq_lens_this_time"), - py::arg("seq_lens_encoder"), - py::arg("seq_lens_decoder"), - py::arg("step_idx"), - py::arg("seq_lens_encoder_record"), - py::arg("seq_lens_decoder_record"), - py::arg("not_need_stop"), - py::arg("batch_drop"), - py::arg("accept_tokens"), - py::arg("accept_num"), - py::arg("base_model_seq_lens_encoder"), - py::arg("base_model_seq_lens_decoder"), - py::arg("base_model_step_idx"), - py::arg("base_model_stop_flags"), - py::arg("base_model_is_block_step"), - py::arg("base_model_draft_tokens"), - py::arg("max_draft_token"), - py::arg("truncate_first_token"), - py::arg("splitwise_prefill"), - "Preprocess data for draft model in speculative decoding"); - - m.def("draft_model_postprocess", - &DraftModelPostprocess, - py::arg("base_model_draft_tokens"), - py::arg("base_model_seq_lens_this_time"), - py::arg("base_model_seq_lens_encoder"), - py::arg("base_model_stop_flags"), - "Postprocess data for draft model in speculative decoding"); - - m.def("eagle_get_hidden_states", - &EagleGetHiddenStates, - py::arg("input"), - py::arg("seq_lens_this_time"), - py::arg("seq_lens_encoder"), - py::arg("seq_lens_decoder"), - py::arg("stop_flags"), - py::arg("accept_nums"), - py::arg("base_model_seq_lens_this_time"), - py::arg("base_model_seq_lens_encoder"), - py::arg("actual_draft_token_num"), - "Get draft model hidden states"); - - m.def("eagle_get_self_hidden_states", - &EagleGetSelfHiddenStates, - py::arg("input"), - py::arg("last_seq_lens_this_time"), - py::arg("seq_lens_this_time"), - py::arg("step_idx"), - "Rebuild draft model hidden states"); - m.def("speculate_get_output_padding_offset", &SpeculateGetOutputPaddingOffset, py::arg("output_cum_offsets_tmp"), @@ -706,23 +1020,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("seq_lens_encoder"), "Get padding offset"); - m.def("mtp_step_paddle", - &MTPStepPaddle, - py::arg("base_model_stop_flags"), - py::arg("stop_flags"), - py::arg("batch_drop"), - py::arg("seq_lens_this_time"), - py::arg("seq_lens_encoder"), - py::arg("seq_lens_decoder"), - py::arg("block_tables"), // [bsz, block_num_per_seq] - py::arg("encoder_block_lens"), - py::arg("used_list_len"), - py::arg("free_list"), - py::arg("free_list_len"), - py::arg("block_size"), - py::arg("max_draft_tokens"), - "MTP step paddle"); - m.def("speculate_step_reschedule", &SpeculateStepSchedule, py::arg("stop_flags"), @@ -760,6 +1057,147 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("seq_lens_decoder"), "Get sequence lengths output"); + m.def("set_data_ipc", + &SetDataIpc, + py::arg("tmp_input"), + py::arg("shm_name"), + "Set data IPC function"); + + m.def("set_ncluster", &set_ncluster, "set ncluster"); + + m.def("set_stop_value_multi_ends", + &GetStopFlagsMulti, + py::arg("topk_ids"), + py::arg("stop_flags"), + py::arg("seq_lens"), + py::arg("end_ids"), + py::arg("next_tokens"), + py::arg("beam_search"), + "Set stop value multi ends function"); + + m.def("step_paddle", + &StepPaddle, + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("ori_seq_lens_encoder"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("block_tables"), + py::arg("encoder_block_lens"), + py::arg("is_block_step"), + py::arg("step_block_list"), + py::arg("step_lens"), + py::arg("recover_block_list"), + py::arg("recover_lens"), + py::arg("need_block_list"), + py::arg("need_block_len"), + py::arg("used_list_len"), + py::arg("free_list"), + py::arg("free_list_len"), + py::arg("input_ids"), + py::arg("pre_ids"), + py::arg("step_idx"), + py::arg("next_tokens"), + py::arg("first_token_ids"), + py::arg("block_size"), + py::arg("encoder_decoder_block_num"), + "Step paddle function"); + + m.def("text_image_gather_scatter", + &TextImageGatherScatter, + py::arg("input"), + py::arg("text_input"), + py::arg("image_input"), + py::arg("token_type_ids"), + py::arg("text_index"), + py::arg("image_index"), + py::arg("is_scatter"), + "Scatter image and text from hidden states, or gather them to hidden " + "states"); + + m.def("text_image_index_out", + &TextImageIndexOut, + py::arg("token_type_ids"), + py::arg("text_index"), + py::arg("image_index"), + "Generate index for text and image"); + + m.def("top_p_candidates", + &TopPCandidates, + py::arg("probs"), + py::arg("top_p"), + py::arg("output_padding_offset"), + py::arg("candidates_len"), + py::arg("max_seq_len"), + "Generate top-p candidates based on probability distributions"); + + m.def("update_inputs", + &UpdateInputs, + py::arg("stop_flags"), + py::arg("not_need_stop"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("input_ids"), + py::arg("stop_nums"), + py::arg("next_tokens"), + py::arg("is_block_step"), + "Update inputs function"); + + m.def("update_inputs_v1", + &UpdateInputsV1, + py::arg("stop_flags"), + py::arg("not_need_stop"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("step_seq_lens_decoder"), + py::arg("prompt_lens"), + py::arg("topk_ids"), + py::arg("input_ids"), + py::arg("block_tables"), + py::arg("stop_nums"), + py::arg("next_tokens"), + py::arg("is_block_step"), + py::arg("block_size"), + "Update inputs v1 function"); + + m.def("weight_quantize_xpu", + &WeightQuantize, + py::arg("x"), + py::arg("algo"), + py::arg("arch"), + py::arg("group_size"), + "Quantize weights on XPU"); + + m.def("weight_only_linear_xpu", + &WeightOnlyLinear, + "Weight-only quantized linear layer", + py::arg("x"), + py::arg("weight"), + py::arg("weight_scale"), + py::arg("bias"), + py::arg("weight_dtype"), + py::arg("arch"), + py::arg("group_size") = -1); + + m.def("xpu_moe_layer", + &MoeLayer, + py::arg("x"), + py::arg("gate_weight"), + py::arg("gate_correction_bias"), + py::arg("up_gate_proj_weight"), + py::arg("down_proj_weight"), + py::arg("up_gate_proj_bias"), + py::arg("down_proj_bias"), + py::arg("up_gate_proj_weight_scale"), + py::arg("down_proj_weight_scale"), + py::arg("down_proj_in_scale"), + py::arg("quant_method"), + py::arg("moe_top_k"), + py::arg("moe_group"), + "fused moe op(topk + dispatch + ffn + combine) in XPU"); + // 添加XPU错误信息的异常处理类 py::register_exception(m, "XPUError"); } diff --git a/custom_ops/xpu_ops/src/ops/update_inputs.cc b/custom_ops/xpu_ops/src/ops/update_inputs.cc index 53b057e302..ce2fb3f3ca 100644 --- a/custom_ops/xpu_ops/src/ops/update_inputs.cc +++ b/custom_ops/xpu_ops/src/ops/update_inputs.cc @@ -17,18 +17,18 @@ #include "paddle/phi/core/enforce.h" #include "xpu/plugin.h" -void UpdateInputes(const paddle::Tensor &stop_flags, - const paddle::Tensor ¬_need_stop, // cpu - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &input_ids, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const paddle::Tensor &is_block_step) { +void UpdateInputs(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // cpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + auto xpu_ctx = static_cast(dev_ctx); const int max_bsz = stop_flags.shape()[0]; PADDLE_ENFORCE_LE( @@ -42,11 +42,11 @@ void UpdateInputes(const paddle::Tensor &stop_flags, int r = baidu::xpu::api::plugin::update_inputs( xpu_ctx->x_context(), - const_cast(not_need_stop_xpu.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(input_ids.data()), + const_cast(not_need_stop_xpu.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(input_ids.data()), stop_nums.data(), stop_flags.data(), is_block_step.data(), @@ -57,7 +57,7 @@ void UpdateInputes(const paddle::Tensor &stop_flags, PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed."); auto not_need_stop_cpu = not_need_stop_xpu.copy_to(not_need_stop.place(), false); - bool *not_need_stop_data = const_cast(not_need_stop.data()); + bool* not_need_stop_data = const_cast(not_need_stop.data()); not_need_stop_data[0] = not_need_stop_cpu.data()[0]; } @@ -81,4 +81,4 @@ PD_BUILD_OP(update_inputs) {"seq_lens_encoder", "seq_lens_encoder_out"}, {"seq_lens_decoder", "seq_lens_decoder_out"}, {"input_ids", "input_ids_out"}}) - .SetKernelFn(PD_KERNEL(UpdateInputes)); + .SetKernelFn(PD_KERNEL(UpdateInputs)); diff --git a/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc b/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc index 9e77e636f2..32e1619c14 100644 --- a/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc +++ b/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc @@ -17,23 +17,23 @@ #include "paddle/phi/core/enforce.h" #include "xpu/plugin.h" -void UpdateInputesV1(const paddle::Tensor &stop_flags, - const paddle::Tensor ¬_need_stop, // only on cpu - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_seq_lens_decoder, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &topk_ids, - const paddle::Tensor &input_ids, - const paddle::Tensor &block_tables, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const paddle::Tensor &is_block_step, - const int block_size) { +void UpdateInputsV1(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // only on cpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& topk_ids, + const paddle::Tensor& input_ids, + const paddle::Tensor& block_tables, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step, + const int block_size) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + auto xpu_ctx = static_cast(dev_ctx); const int max_bsz = stop_flags.shape()[0]; const int now_bsz = seq_lens_this_time.shape()[0]; @@ -43,18 +43,18 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags, auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); int r = baidu::xpu::api::plugin::update_inputs_v1( xpu_ctx->x_context(), - const_cast(not_need_stop_gpu.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(step_seq_lens_decoder.data()), - const_cast(prompt_lens.data()), - const_cast(topk_ids.data()), - const_cast(input_ids.data()), - const_cast(block_tables.data()), + const_cast(not_need_stop_gpu.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(prompt_lens.data()), + const_cast(topk_ids.data()), + const_cast(input_ids.data()), + const_cast(block_tables.data()), stop_nums.data(), - const_cast(stop_flags.data()), - const_cast(is_block_step.data()), + const_cast(stop_flags.data()), + const_cast(is_block_step.data()), next_tokens.data(), now_bsz, max_bsz, @@ -64,7 +64,7 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags, PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed."); auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false); - bool *not_need_stop_data = const_cast(not_need_stop.data()); + bool* not_need_stop_data = const_cast(not_need_stop.data()); not_need_stop_data[0] = not_need_stop_cpu.data()[0]; } @@ -101,4 +101,4 @@ PD_BUILD_OP(update_inputs_v1) {"stop_flags", "stop_flags_out"}, {"step_seq_lens_decoder", "step_seq_lens_decoder_out"}, {"is_block_step", "is_block_step_out"}}) - .SetKernelFn(PD_KERNEL(UpdateInputesV1)); + .SetKernelFn(PD_KERNEL(UpdateInputsV1)); diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index 9407849f58..540988266d 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -27,6 +27,7 @@ from fastdeploy.model_executor.ops.xpu import ( moe_expert_ffn, moe_topk_select, weight_quantize_xpu, + xpu_moe_layer, ) @@ -153,8 +154,6 @@ class XPUMoEMethod(MoEMethodBase): """ Apply TP Fused Op. """ - from fastdeploy.model_executor.ops.xpu import xpu_moe_layer - fused_moe_out = xpu_moe_layer( x, gate.weight.transpose([1, 0]),