diff --git a/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.cu new file mode 100644 index 0000000000..6e2d9eb2ba --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.cu @@ -0,0 +1,142 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "helper.h" +#include "multiquery_decoder_attention_kernel.h" +#include "utils.cuh" + +template +void DecodeMLAAttentionKernel( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out) { + const auto token_num = meta_data.token_nums; + const auto block_size = meta_data.block_size; + const auto bsz = meta_data.batch_size; + const auto num_heads = meta_data.q_num_heads; + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; + const auto head_dim_qk = 
meta_data.head_dims; + const auto head_dim_v = meta_data.head_dims_v; + const float rope_scale = 0.0; + const float rope_theta = 0.0; + const uint32_t deal_each_time = get_cascade_attention_deal_each_time(); + const uint32_t num_stage = get_cascade_attention_num_stages(); + const uint32_t num_threads = get_cascade_attention_num_threads(); + + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_MLA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_MLA_HEAD_DIM( + head_dim_qk, + HEAD_DIM_QK, + {DISPATCH_MLA_HEAD_DIM( + head_dim_v, + HEAD_DIM_V, + {DISPATCH_BLOCK_SIZE( + block_size, + BLOCK_SIZE, + {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, { + MultiQueryDecoderAttention( + meta_data, + stream, + q, + cache_k, + cache_v, + attn_mask, + shift_bias, + smooth_weight, + seq_lens_q, + seq_lens_kv, + batch_id_per_token, + cu_seqlens_q, + block_table, + max_seq_len, + max_dec_len, + rope_scale, + rope_theta, + softmax_scale, + in_scale, + out); + })})})})})}); +} + +template void DecodeMLAAttentionKernel( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); + +template void DecodeMLAAttentionKernel( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const 
paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.h b/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.h new file mode 100644 index 0000000000..1546f37685 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "helper.h" +#include "utils.cuh" + +template +void DecodeMLAAttentionKernel( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h b/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h deleted file mode 100644 index 54e4fd6de9..0000000000 --- a/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#pragma once - -#include "helper.h" -#include "utils.cuh" -#include "multiquery_decoder_attention_impl.cuh" - -template -void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out) { - const auto token_num = meta_data.token_nums; - const auto block_size = meta_data.block_size; - const auto bsz = meta_data.batch_size; - const auto num_heads = meta_data.q_num_heads; - const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; - const auto head_dim_qk = meta_data.head_dims; - const auto head_dim_v = meta_data.head_dims_v; - const float rope_scale = 0.0; - const float rope_theta = 0.0; - const uint32_t deal_each_time = get_cascade_attention_deal_each_time(); - const uint32_t num_stage = get_cascade_attention_num_stages(); - const uint32_t num_threads = get_cascade_attention_num_threads(); - - DISPATCH_CAUSAL(causal, CAUSAL, - {DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, - {DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK, - {DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V, - {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, - {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, - {MultiQueryDecoderAttention( - meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q, - block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})}); -} - 
-template void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out); - -template void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu deleted file mode 100644 index 1323cb4839..0000000000 --- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu +++ /dev/null @@ -1,1052 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fast_hardamard_kernel.h" - -#define FULL_MASK 0xffffffff - -struct uint8 { - uint4 u; - uint4 v; -}; - -template struct BytesToType {}; - -template<> -struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); -}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -template -struct nv_type_traits { - using type = T; -}; - -template <> -struct nv_type_traits { - using type = half; -}; - -template <> -struct nv_type_traits { - using type = __nv_bfloat16; -}; - -template <> -struct nv_type_traits { - using type = int8_t; -}; - -#define DISPATCH_SP_logN(logN, kLogN, ...) \ - if (logN == 10) { \ - constexpr int kLogN = 10; \ - __VA_ARGS__ \ - } else if (logN == 9) { \ - constexpr int kLogN = 9; \ - __VA_ARGS__ \ - } else if (logN == 8) { \ - constexpr int kLogN = 8; \ - __VA_ARGS__ \ - } else if (logN == 7) { \ - constexpr int kLogN = 7; \ - __VA_ARGS__ \ - } else { \ - PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \ - } - -#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...) 
\ - if (vec_size == 16) { \ - constexpr int VEC_SIZE = 16; \ - __VA_ARGS__ \ - } else if (vec_size == 8) { \ - constexpr int VEC_SIZE = 8; \ - __VA_ARGS__ \ - } else if (vec_size == 4) { \ - constexpr int VEC_SIZE = 4; \ - __VA_ARGS__ \ - } else if (vec_size == 2) { \ - constexpr int VEC_SIZE = 2; \ - __VA_ARGS__ \ - } else if (vec_size == 1) { \ - constexpr int VEC_SIZE = 1; \ - __VA_ARGS__ \ - } else { \ - PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", vec_size)); \ - } - -#define DISPATCH_logN(logN, kLogN, ...) \ - if (logN == 11) { \ - constexpr int kLogN = 11; \ - __VA_ARGS__ \ - } else if (logN == 12) { \ - constexpr int kLogN = 12; \ - __VA_ARGS__ \ - } else if (logN == 13) { \ - constexpr int kLogN = 13; \ - __VA_ARGS__ \ - } else if (logN == 14) { \ - constexpr int kLogN = 14; \ - __VA_ARGS__ \ - } else { \ - PADDLE_THROW(phi::errors::Unimplemented("unsupported logN")); \ - } - -template -__device__ __forceinline__ void hadamard_mult_thread_28_transpose(T x[28][VecSize]) { // 35 - T out[28]; -#pragma unroll - for (int vi = 0; vi < VecSize; vi++) { - out[0] = + x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] + x[27][vi]; - out[1] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + x[27][vi]; - out[2] = + x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + x[17][vi] - 
x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - x[27][vi]; - out[3] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] + x[27][vi]; - out[4] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] + x[27][vi]; - out[5] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi]; - out[6] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi]; - out[7] = + x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] - x[27][vi]; - out[8] = + x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] - 
x[15][vi] - x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - x[27][vi]; - out[9] = + x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] + x[27][vi]; - out[10] = + x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] + x[27][vi]; - out[11] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] - x[27][vi]; - out[12] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] + x[27][vi]; - out[13] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] - x[27][vi]; - out[14] = - x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + 
x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi]; - out[15] = + x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] - x[27][vi]; - out[16] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] + x[27][vi]; - out[17] = + x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] - x[27][vi]; - out[18] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] - x[27][vi]; - out[19] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] + x[27][vi]; - out[20] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] - x[8][vi] + 
x[9][vi] + x[10][vi] - x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] - x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] + x[27][vi]; - out[21] = + x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] + x[27][vi]; - out[22] = + x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] + x[27][vi]; - out[23] = + x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] - x[27][vi]; - out[24] = + x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] - x[27][vi]; - out[25] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] + x[27][vi]; - out[26] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] 
- x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi]; - out[27] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] - x[27][vi]; - #pragma unroll - for (int i = 0; i < 28; i++) { x[i][vi] = out[i]; } - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_36_transpose(T x[36][VecSize]) { // 4t - T out[36]; -#pragma unroll - for (int vi = 0; vi < VecSize; vi++) { - out[0] = + x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] + x[27][vi] + x[28][vi] + x[29][vi] + x[30][vi] + x[31][vi] + x[32][vi] + x[33][vi] + x[34][vi] + x[35][vi]; - out[1] = + x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[2] = + x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] - x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + x[22][vi] - 
x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] - x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[3] = + x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] - x[31][vi] - x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; - out[4] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] + x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; - out[5] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] + x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[6] = + x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; - out[7] = + x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] + x[7][vi] + 
x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] - x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[8] = + x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] - x[31][vi] - x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[9] = + x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] + x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; - out[10] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] - x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[11] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] + 
x[30][vi] + x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; - out[12] = + x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] + x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; - out[13] = + x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] - x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[14] = + x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] + x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[15] = + x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] + x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[16] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + 
x[15][vi] + x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] - x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[17] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] + x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[18] = - x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[19] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] + x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; - out[20] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] + x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; - 
out[21] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] + x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[22] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] + x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] - x[31][vi] + x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[23] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] + x[30][vi] - x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] + x[35][vi]; - out[24] = + x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] + x[31][vi] - x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[25] = + x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - 
x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] + x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; - out[26] = + x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] + x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; - out[27] = + x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] - x[31][vi] + x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[28] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] + x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] + x[35][vi]; - out[29] = + x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[30] = + x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - 
x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[31] = + x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] - x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; - out[32] = + x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] - x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; - out[33] = + x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] - x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[34] = + x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] + x[27][vi] + 
x[28][vi] + x[29][vi] - x[30][vi] + x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[35] = + x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] + x[30][vi] - x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; -#pragma unroll - for (int i = 0; i < 36; i++) { x[i][vi] = out[i]; } - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_28(T x[28]) { // 35 - T out[28]; - out[0] = + x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + x[9] + x[10] + x[11] + x[12] + x[13] - x[14] + x[15] + x[16] + x[17] + x[18] + x[19] + x[20] + x[21] + x[22] + x[23] + x[24] + x[25] + x[26] + x[27]; - out[1] = + x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] - x[22] - x[23] + x[24] + x[25] - x[26] + x[27]; - out[2] = + x[0] + x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + x[17] - x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] + x[25] + x[26] - x[27]; - out[3] = + x[0] - x[1] + x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] + x[13] + x[14] - x[15] + x[16] - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - x[25] + x[26] + x[27]; - out[4] = + x[0] + x[1] - x[2] + x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] + x[14] + x[15] - x[16] + x[17] - x[18] + x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - x[25] - x[26] + x[27]; - out[5] = + x[0] + x[1] + x[2] - x[3] + x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + x[14] + x[15] + x[16] - x[17] + x[18] - x[19] + 
x[20] - x[21] + x[22] + x[23] - x[24] - x[25] - x[26] - x[27]; - out[6] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] + x[23] + x[24] - x[25] - x[26] - x[27]; - out[7] = + x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] + x[8] - x[9] + x[10] + x[11] - x[12] - x[13] + x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] + x[25] - x[26] - x[27]; - out[8] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - x[17] + x[18] + x[19] - x[20] + x[21] - x[22] + x[23] - x[24] + x[25] + x[26] - x[27]; - out[9] = + x[0] - x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + x[9] + x[10] - x[11] + x[12] + x[13] + x[14] - x[15] - x[16] - x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - x[25] + x[26] + x[27]; - out[10] = + x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] + x[10] + x[11] - x[12] + x[13] + x[14] + x[15] - x[16] - x[17] - x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] + x[25] - x[26] + x[27]; - out[11] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] + x[11] + x[12] - x[13] + x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] + x[24] - x[25] + x[26] - x[27]; - out[12] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] + x[12] + x[13] + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] + x[22] + x[23] - x[24] + x[25] - x[26] + x[27]; - out[13] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] + x[12] + x[13] + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - x[25] + x[26] - x[27]; - out[14] = - x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + x[9] + x[10] + x[11] + x[12] + x[13] - x[14] - 
x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] - x[27]; - out[15] = + x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + x[17] - x[18] - x[19] + x[20] + x[21] + x[22] + x[23] - x[24] - x[25] + x[26] - x[27]; - out[16] = + x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] - x[16] - x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - x[25] - x[26] + x[27]; - out[17] = + x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] - x[27]; - out[18] = + x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] - x[15] + x[16] - x[17] - x[18] - x[19] + x[20] - x[21] - x[22] + x[23] + x[24] + x[25] + x[26] - x[27]; - out[19] = + x[0] + x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] - x[14] - x[15] - x[16] + x[17] - x[18] - x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + x[25] + x[26] + x[27]; - out[20] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] - x[14] + x[15] - x[16] - x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + x[25] + x[26] + x[27]; - out[21] = + x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - x[17] - x[18] + x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] + x[26] + x[27]; - out[22] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] - x[13] - x[14] + x[15] + x[16] + x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - x[25] - x[26] + x[27]; - out[23] = + x[0] - x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - 
x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] + x[17] + x[18] - x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + x[25] - x[26] - x[27]; - out[24] = + x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] - x[14] - x[15] + x[16] + x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - x[25] + x[26] - x[27]; - out[25] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] - x[16] + x[17] + x[18] + x[19] + x[20] - x[21] - x[22] + x[23] - x[24] - x[25] - x[26] + x[27]; - out[26] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] + x[13] - x[14] + x[15] - x[16] - x[17] + x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] - x[25] - x[26] - x[27]; - out[27] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - x[17] - x[18] + x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] - x[26] - x[27]; -#pragma unroll - for (int i = 0; i < 28; i++) { x[i] = out[i]; } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_36(T x[36]) { // 4t - T out[36]; - out[0] = + x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + x[9] + x[10] + x[11] + x[12] + x[13] + x[14] + x[15] + x[16] + x[17] - x[18] + x[19] + x[20] + x[21] + x[22] + x[23] + x[24] + x[25] + x[26] + x[27] + x[28] + x[29] + x[30] + x[31] + x[32] + x[33] + x[34] + x[35]; - out[1] = + x[0] + x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - x[25] - x[26] + x[27] + x[28] - x[29] - x[30] - x[31] + x[32] - x[33] + x[34] + x[35]; - out[2] = + x[0] + x[1] + x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] 
+ x[24] - x[25] - x[26] - x[27] + x[28] + x[29] - x[30] - x[31] - x[32] + x[33] - x[34] + x[35]; - out[3] = + x[0] + x[1] + x[2] + x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - x[17] + x[18] + x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + x[25] - x[26] - x[27] - x[28] + x[29] + x[30] - x[31] - x[32] - x[33] + x[34] - x[35]; - out[4] = + x[0] - x[1] + x[2] + x[3] + x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] + x[24] - x[25] + x[26] - x[27] - x[28] - x[29] + x[30] + x[31] - x[32] - x[33] - x[34] + x[35]; - out[5] = + x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] + x[7] - x[8] + x[9] - x[10] - x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] + x[24] + x[25] - x[26] + x[27] - x[28] - x[29] - x[30] + x[31] + x[32] - x[33] - x[34] - x[35]; - out[6] = + x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] - x[17] + x[18] - x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + x[25] + x[26] - x[27] + x[28] - x[29] - x[30] - x[31] + x[32] + x[33] - x[34] - x[35]; - out[7] = + x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - x[17] + x[18] - x[19] - x[20] + x[21] - x[22] + x[23] + x[24] - x[25] + x[26] + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] + x[33] + x[34] - x[35]; - out[8] = + x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] + x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] + x[25] - x[26] + x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - x[33] + x[34] + x[35]; - out[9] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + x[9] + x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + x[17] + x[18] + x[19] - 
x[20] - x[21] - x[22] + x[23] - x[24] + x[25] + x[26] - x[27] + x[28] + x[29] - x[30] + x[31] - x[32] - x[33] - x[34] + x[35]; - out[10] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] + x[9] + x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - x[25] + x[26] + x[27] - x[28] + x[29] + x[30] - x[31] + x[32] - x[33] - x[34] - x[35]; - out[11] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + x[9] + x[10] + x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] + x[25] - x[26] + x[27] + x[28] - x[29] + x[30] + x[31] - x[32] + x[33] - x[34] - x[35]; - out[12] = + x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] + x[11] + x[12] + x[13] + x[14] - x[15] + x[16] - x[17] + x[18] - x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - x[25] + x[26] - x[27] + x[28] + x[29] - x[30] + x[31] + x[32] - x[33] + x[34] - x[35]; - out[13] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] + x[12] + x[13] + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - x[25] - x[26] + x[27] - x[28] + x[29] + x[30] - x[31] + x[32] + x[33] - x[34] + x[35]; - out[14] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] + x[13] + x[14] + x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - x[25] - x[26] - x[27] + x[28] - x[29] + x[30] + x[31] - x[32] + x[33] + x[34] - x[35]; - out[15] = + x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] + x[14] + x[15] + x[16] + x[17] + x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] + x[25] - x[26] - x[27] - x[28] + x[29] - x[30] + x[31] + x[32] - x[33] + x[34] + x[35]; - out[16] = + x[0] + x[1] - x[2] + x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] + 
x[15] + x[16] + x[17] + x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + x[25] + x[26] - x[27] - x[28] - x[29] + x[30] - x[31] + x[32] + x[33] - x[34] + x[35]; - out[17] = + x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] + x[17] + x[18] + x[19] + x[20] - x[21] + x[22] - x[23] - x[24] - x[25] + x[26] + x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + x[33] + x[34] - x[35]; - out[18] = - x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + x[9] + x[10] + x[11] + x[12] + x[13] + x[14] + x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] - x[27] - x[28] - x[29] - x[30] - x[31] - x[32] - x[33] - x[34] - x[35]; - out[19] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] + x[25] + x[26] - x[27] - x[28] + x[29] + x[30] + x[31] - x[32] + x[33] - x[34] - x[35]; - out[20] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] + x[25] + x[26] + x[27] - x[28] - x[29] + x[30] + x[31] + x[32] - x[33] + x[34] - x[35]; - out[21] = + x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] - x[25] + x[26] + x[27] + x[28] - x[29] - x[30] + x[31] + x[32] + x[33] - x[34] + x[35]; - out[22] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] + x[17] - x[18] + x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + x[25] - x[26] + x[27] + x[28] + x[29] - x[30] - x[31] + x[32] + x[33] + x[34] - x[35]; - out[23] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - 
x[10] - x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] + x[26] - x[27] + x[28] + x[29] + x[30] - x[31] - x[32] + x[33] + x[34] + x[35]; - out[24] = + x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] - x[25] - x[26] + x[27] - x[28] + x[29] + x[30] + x[31] - x[32] - x[33] + x[34] + x[35]; - out[25] = + x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] - x[24] - x[25] - x[26] - x[27] + x[28] - x[29] + x[30] + x[31] + x[32] - x[33] - x[34] + x[35]; - out[26] = + x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] + x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - x[25] - x[26] - x[27] - x[28] + x[29] - x[30] + x[31] + x[32] + x[33] - x[34] - x[35]; - out[27] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] + x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] + x[24] - x[25] - x[26] - x[27] - x[28] - x[29] + x[30] - x[31] + x[32] + x[33] + x[34] - x[35]; - out[28] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] + x[25] - x[26] - x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + x[33] + x[34] + x[35]; - out[29] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] - x[25] + x[26] - x[27] - x[28] - x[29] - x[30] - x[31] + x[32] - x[33] + x[34] + x[35]; - out[30] = + x[0] - x[1] - x[2] + x[3] + x[4] 
- x[5] - x[6] - x[7] + x[8] - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] - x[17] - x[18] + x[19] + x[20] - x[21] - x[22] + x[23] + x[24] + x[25] - x[26] + x[27] - x[28] - x[29] - x[30] - x[31] - x[32] + x[33] - x[34] + x[35]; - out[31] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + x[17] - x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] + x[25] + x[26] - x[27] + x[28] - x[29] - x[30] - x[31] - x[32] - x[33] + x[34] - x[35]; - out[32] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] - x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] + x[26] + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] - x[33] - x[34] + x[35]; - out[33] = + x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] + x[17] - x[18] + x[19] - x[20] + x[21] + x[22] + x[23] - x[24] - x[25] + x[26] + x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - x[33] - x[34] - x[35]; - out[34] = + x[0] + x[1] - x[2] + x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] + x[23] + x[24] - x[25] - x[26] + x[27] + x[28] + x[29] - x[30] + x[31] - x[32] - x[33] - x[34] - x[35]; - out[35] = + x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] + x[23] + x[24] + x[25] - x[26] - x[27] + x[28] + x[29] + x[30] - x[31] + x[32] - x[33] - x[34] - x[35]; -#pragma unroll - for (int i = 0; i < 36; i++) { x[i] = out[i]; } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_chunk_28(T x[kNChunks][28]) { -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { hadamard_mult_thread_28(x[c]); } -} - -template -__device__ __forceinline__ void 
hadamard_mult_thread_chunk_36(T x[kNChunks][36]) { -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { hadamard_mult_thread_36(x[c]); } -} - -template -inline __device__ void load_input(const T *x, T x_vals[kNChunks][VecSize], int dim) { - using vec_t = typename BytesToType::Type; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - int offset; - if constexpr (UseDiagonalBlockMatrix) { - static_assert(kNChunks == 1); - offset = blockIdx.y * blockDim.x + threadIdx.x; - } else { - offset = c * blockDim.x + threadIdx.x; - } - if (offset * VecSize < dim) { - reinterpret_cast(x_vals)[c] = reinterpret_cast(x)[offset]; - } - } -} - -template -__forceinline__ __device__ OutType QuantHelperFunc(const InType input, - const float scale, - const int round_type, - const float max_bound, - const float min_bound) { - float quant_value = max_bound * scale * static_cast(input); - - if (round_type == 0) { - quant_value = static_cast(rint(quant_value)); - } else { - quant_value = static_cast(round(quant_value)); - } - return static_cast(ClipFunc(quant_value, min_bound, max_bound)); -} - -template -inline __device__ void smooth_quant_store_output( - OutT *out, - const T *shift, - const T *smooth, - T out_vals[kNChunks][VecSize], - const float quant_scale, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int dim) { - using DstVec = AlignedVector; - using Vec = AlignedVector; - DstVec dst_vec; - Vec shift_vec; - Vec smooth_vec; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - int base_idx; - if constexpr (UseDiagonalBlockMatrix) { - base_idx = blockIdx.y * blockDim.x + threadIdx.x; - } else { - base_idx = c * blockDim.x + threadIdx.x; - } - const int idx = base_idx * VecSize; - if (idx < dim) { - Load(shift + idx, &shift_vec); - Load(smooth + idx, &smooth_vec); -#pragma unroll - for (int vi = 0; vi < VecSize; ++vi) { - out_vals[c][vi] = (out_vals[c][vi] + shift_vec[vi]) * smooth_vec[vi]; - dst_vec[vi] = QuantHelperFunc( - 
static_cast(out_vals[c][vi]), - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound); - } - Store(dst_vec, out + idx); - } - } -} - -template -inline __device__ void quant_store_output( - OutT *out, - T out_vals[kNChunks][VecSize], - const float quant_scale, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int dim) { - using DstVec = AlignedVector; - using Vec = AlignedVector; - DstVec dst_vec; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - int base_idx; - if constexpr (UseDiagonalBlockMatrix) { - base_idx = blockIdx.y * blockDim.x + threadIdx.x; - } else { - base_idx = c * blockDim.x + threadIdx.x; - } - const int idx = base_idx * VecSize; - if (idx < dim) { -#pragma unroll - for (int vi = 0; vi < VecSize; ++vi) { - // out_vals[c][vi] = (out_vals[c][vi] + shift_vec[vi]) * smooth_vec[vi]; - dst_vec[vi] = QuantHelperFunc( - static_cast(out_vals[c][vi]), - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound); - } - Store(dst_vec, out + idx); - } - } -} - -template -inline __device__ void store_output(OutT *out, T out_vals[kNChunks][VecSize], int dim) { - using vec_t = typename BytesToType::Type; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - int offset; - if constexpr (UseDiagonalBlockMatrix) { - offset = blockIdx.y * blockDim.x + threadIdx.x; - } else { - offset = c * blockDim.x + threadIdx.x; - } - if (offset * VecSize < dim) { - reinterpret_cast(out)[offset] = reinterpret_cast(out_vals)[c]; - } - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_transpose(T x[1 << kLogN][kNChunks]) { - constexpr int N = 1 << kLogN; -#pragma unroll - for (int i = 0; i < kLogN; ++i) { - const int stride = 1 << i; -#pragma unroll - for (int j = 0; j < N / 2; ++j) { - const int lo = j & (stride - 1); - const int idx = (j - lo) * 2 + lo; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - const T a = x[idx][c]; - const T b = x[idx + stride][c]; - 
x[idx][c] = a + b; - x[idx + stride][c] = a - b; - } - } - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread(T x[kNChunks][1 << kLogN]) { - constexpr int N = 1 << kLogN; -#pragma unroll - for (int i = 0; i < kLogN; ++i) { - const int stride = 1 << i; -#pragma unroll - for (int j = 0; j < N / 2; ++j) { - const int lo = j & (stride - 1); - const int idx = (j - lo) * 2 + lo; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - const T a = x[c][idx]; - const T b = x[c][idx + stride]; - x[c][idx] = a + b; - x[c][idx + stride] = a - b; - } - } - } -} - -template -__device__ __forceinline__ void hadamard_mult_warp(T x[kNChunks][kNItems]) { - constexpr int N = 1 << kLogWarpSize; - int lane_id = threadIdx.x % N; -#pragma unroll - for (int step = kStepStart; step < kLogWarpSize; ++step) { - const int lane_mask = 1 << step; - const T sign = (lane_id & lane_mask) ? -1.f : 1.f; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { -#pragma unroll - for (int i = 0; i < kNItems; ++i) { - T x_val_other = __shfl_xor_sync(FULL_MASK, x[c][i], lane_mask); - x[c][i] = sign * x[c][i] + x_val_other; - } - } - } -} - -template -inline __device__ void exchange_smem_pre(T x_vals[kNChunks][kNElts], vec_t *smem) { - // kNChunks表示整体需要多少次循环才能处理完 - // kChunksPerExchange表示每次循环可以处理多少个chunk - // kNExchanges表示多少次循环才能处理完所有数据 - constexpr int kNThreads = kWarpSize * kNWarps; - const int warp_id = threadIdx.x / kWarpSize; - const int lane_id = threadIdx.x % kWarpSize; - const int row_t = threadIdx.x % kNWarps; - const int col_t = threadIdx.x / kNWarps; -#pragma unroll - for (int c0 = 0; c0 < kNChunks / kChunksPerExchange; ++c0) { - // 搬多少次chunk算完所有数据 - __syncthreads(); -#pragma unroll - for (int c1 = 0; c1 < kChunksPerExchange; ++c1) { - // 每次循环搬多少数据把smem塞满 - // smem[c1 * kNThreads + (Pre ? warp_id * kWarpSize + lane_id ^ warp_id : row_t * kWarpSize + col_t ^ row_t)] = *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]); - smem[c1 * kNThreads + (Pre ? 
warp_id * kWarpSize + lane_id : row_t * kWarpSize + col_t)] = *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]); - } - __syncthreads(); -#pragma unroll - for (int c1 = 0; c1 < kChunksPerExchange; ++c1) { - // *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]) = smem[c1 * kNThreads + (Pre ? row_t * kWarpSize + col_t ^ row_t : warp_id * kWarpSize + lane_id ^ warp_id)]; - *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]) = smem[c1 * kNThreads + (Pre ? row_t * kWarpSize + col_t : warp_id * kWarpSize + lane_id)]; - } - } -} - -constexpr int cilog2(int val) { return val > 0 ? 1 + cilog2(val >> 1) : -1; } - -template -__global__ __launch_bounds__(kThreads) -void moe_fast_hardamard_kernel(const T *x, - const int64_t *expert_idx_per_token, - const T *shift, - const T *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - OutT *out) { - using vec_t = typename BytesToType::Type; - constexpr int kLogVecSize = cilog2(VecSize); - constexpr int kLogWarpSize = cilog2(32); - constexpr int kWarpSize = 32; - constexpr int kNWarps = kThreads / kWarpSize; - constexpr int kLogNWarps = cilog2(kNWarps); - constexpr int kLogNChunks = cilog2(kNChunks); - - extern __shared__ char smem_[]; - vec_t *smem_exchange = reinterpret_cast(smem_); - - for (int token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) { - const T *x_now = x + token_id * dim; - OutT *out_now = out + token_id * dim; - T init_value = static_cast(0.f); - T x_vals[kNChunks][VecSize] = {init_value}; - - load_input(x_now, x_vals, dim); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id0: %d\n", i); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - - hadamard_mult_thread(x_vals); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && 
threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - hadamard_mult_warp(x_vals); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id2: %d\n", i); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - if constexpr (kNWarps > 1) { - // 先让连续的NWARPS个线程拿到其余warps上的数据 - exchange_smem_pre(x_vals, smem_exchange); - // 交叉计算 - hadamard_mult_warp(x_vals); - // 再换回来 - exchange_smem_pre(x_vals, smem_exchange); - } - if constexpr (kNChunks > 1) { - if constexpr (kNChunks == 28) { - hadamard_mult_thread_28_transpose(x_vals); - } else if constexpr (kNChunks == 36) { - hadamard_mult_thread_36_transpose(x_vals); - } else { - constexpr int kLogNChunks = cilog2(kNChunks); - static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2"); - hadamard_mult_thread_transpose(x_vals); - } - } - if (quant_scales) { - int64_t expert_id = expert_idx_per_token[token_id]; - float quant_scale = quant_scales[expert_id]; - if (shift) { - smooth_quant_store_output( - out_now, - shift, - smooth, - x_vals, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, - dim); - } else { - quant_store_output( - out_now, - x_vals, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, - dim); - } - } else { - store_output(out_now, x_vals, dim); - } - } -} - -template -__global__ __launch_bounds__(kThreads) -void masked_moe_fast_hardamard_kernel(const T *x, - const int64_t *recv_expert_count, - const T *shift, - const T *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int 
num_max_tokens_per_expert, - OutT *out) { - using vec_t = typename BytesToType::Type; - constexpr int kLogVecSize = cilog2(VecSize); - constexpr int kLogWarpSize = cilog2(32); - constexpr int kWarpSize = 32; - constexpr int kNWarps = kThreads / kWarpSize; - constexpr int kLogNWarps = cilog2(kNWarps); - constexpr int kLogNChunks = cilog2(kNChunks); - - extern __shared__ char smem_[]; - vec_t *smem_exchange = reinterpret_cast(smem_); - - for (int token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) { - const auto token_idx_in_expert = token_id % num_max_tokens_per_expert; - const auto expert_id = token_id / num_max_tokens_per_expert; - if (token_idx_in_expert >= recv_expert_count[expert_id]) { - auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert; - auto num_iters_to_next_expert = (next_expert_start_idx - token_id - 1) / gridDim.x; - token_id += num_iters_to_next_expert * gridDim.x; - continue; - } - const T *x_now = x + token_id * dim; - OutT *out_now = out + token_id * dim; - T init_value = static_cast(0.f); - T x_vals[kNChunks][VecSize] = {init_value}; - - load_input(x_now, x_vals, dim); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id0: %d\n", i); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - - hadamard_mult_thread(x_vals); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - hadamard_mult_warp(x_vals); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id2: %d\n", i); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - 
} - __syncthreads(); -#endif - if constexpr (kNWarps > 1) { - // 先让连续的NWARPS个线程拿到其余warps上的数据 - exchange_smem_pre(x_vals, smem_exchange); - // 交叉计算 - hadamard_mult_warp(x_vals); - // 再换回来 - exchange_smem_pre(x_vals, smem_exchange); - } - if constexpr (kNChunks > 1) { - if constexpr (kNChunks == 28) { - hadamard_mult_thread_28_transpose(x_vals); - } else if constexpr (kNChunks == 36) { - hadamard_mult_thread_36_transpose(x_vals); - } else { - constexpr int kLogNChunks = cilog2(kNChunks); - static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2"); - hadamard_mult_thread_transpose(x_vals); - } - } - if (quant_scales) { - float quant_scale = quant_scales[expert_id]; - if (shift) { - smooth_quant_store_output( - out_now, - shift, - smooth, - x_vals, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, - dim); - } else { - quant_store_output( - out_now, - x_vals, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, - dim); - } - } else { - store_output(out_now, x_vals, dim); - } - } -} - - -template -void MoeFastHardamardImplWrapper(const T *x, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const T *shift, - const T *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - OutT* out, - cudaStream_t stream) { - using nv_type = typename nv_type_traits::type; - using out_type = typename nv_type_traits::type; - constexpr int kNBytes = sizeof(T); - constexpr int N = 1 << kLogN; // pad - constexpr int kSmemSize = std::min(N * kNBytes, 32 * 1024); - constexpr int kRounds = N * kNBytes / kSmemSize; - constexpr int kChunksPerSmemSize = kSmemSize / (kThreads * VecSize * kNBytes); - VLOG(1) << "real_dim: " << dim << ", N: " << N; - VLOG(1) << "kNChunks: " << kNChunks; - VLOG(1) << "kNBytes: " << kNBytes; - 
VLOG(1) << "kSmemSize: " << kSmemSize; - VLOG(1) << "kRounds: " << kRounds; - VLOG(1) << "kChunksPerSmemSize: " << kChunksPerSmemSize; - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - - if (used_in_ep_low_latency) { - auto masked_kernel = masked_moe_fast_hardamard_kernel; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, masked_kernel, kThreads, kSmemSize); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - dim3 grid; - grid.x = min(static_cast(num_blocks_per_wave), token_num); - if constexpr (UseDiagonalBlockMatrix) { - grid.y = ceil(dim / (kThreads * VecSize)); - } - masked_kernel<<>>( - reinterpret_cast(x), - recv_expert_count, - reinterpret_cast(shift), - reinterpret_cast(smooth), - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - reinterpret_cast(out) - ); - } else { - auto kernel = moe_fast_hardamard_kernel; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, kernel, kThreads, kSmemSize); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - dim3 grid; - grid.x = min(static_cast(num_blocks_per_wave), token_num); - if constexpr (UseDiagonalBlockMatrix) { - grid.y = ceil(dim / (kThreads * VecSize)); - } - kernel<<>>( - reinterpret_cast(x), - expert_idx_per_token, - reinterpret_cast(shift), - reinterpret_cast(smooth), - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - reinterpret_cast(out) - ); - } -} - -template -void MoeFastHardamardWrapper(const T *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const T *shift, - const T *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool 
used_in_ep_low_latency, - const int hadamard_block_size, - OutT* out, - cudaStream_t &stream) { - bool FLAGS_hardamard_use_diagonal_block_matrix = true; - - constexpr int kThreads = 128; - if (FLAGS_hardamard_use_diagonal_block_matrix) { - const int VecSize = hadamard_block_size / kThreads; - const int logN = int(ceil(std::log2(kThreads * VecSize))); - constexpr int kNChunks = 1; - DISPATCH_SP_VS(VecSize, VEC_SIZE, { - DISPATCH_SP_logN(logN, kLogN, { - MoeFastHardamardImplWrapper( - x_data, - expert_idx_per_token, - recv_expert_count, - shift, - smooth, - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - used_in_ep_low_latency, - out, - stream); - })}); - } else { - if (!((dim / 28) & (dim / 28 - 1))) { - VLOG(1) << "28 * 2^n"; - const int logN = int(ceil(std::log2(dim / 28))); - constexpr int kNChunks = 28; - DISPATCH_SP_logN(logN, kLogN, { - constexpr int VecSize = (1 << kLogN) / kThreads; - MoeFastHardamardImplWrapper( - x_data, - expert_idx_per_token, - recv_expert_count, - shift, - smooth, - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - used_in_ep_low_latency, - out, - stream); - }); - } else if (!((dim / 36) & (dim / 36 - 1))) { - VLOG(1) << "36 * 2^n"; - const int logN = int(ceil(std::log2(dim / 36))); - constexpr int kNChunks = 36; - DISPATCH_SP_logN(logN, kLogN, { - constexpr int VecSize = (1 << kLogN) / kThreads; - MoeFastHardamardImplWrapper( - x_data, - expert_idx_per_token, - recv_expert_count, - shift, - smooth, - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - used_in_ep_low_latency, - out, - stream); - }); - } else { - VLOG(1) << "2^n"; - const int logN = int(ceil(std::log2(dim))); - constexpr int VecSize = 16 / sizeof(T); - DISPATCH_logN(logN, kLogN, { - constexpr int kNChunks = (1 << kLogN) / (kThreads * VecSize); - 
MoeFastHardamardImplWrapper( - x_data, - expert_idx_per_token, - recv_expert_count, - shift, - smooth, - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - used_in_ep_low_latency, - out, - stream); - }); - } - } -} - -template void MoeFastHardamardWrapper( - const phi::dtype::float16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::float16 *shift, - const phi::dtype::float16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - phi::dtype::float16 *out, - cudaStream_t &stream -); - -template void MoeFastHardamardWrapper( - const phi::dtype::float16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::float16 *shift, - const phi::dtype::float16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - int8_t *out, - cudaStream_t &stream -); - -template void MoeFastHardamardWrapper( - const phi::dtype::bfloat16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::bfloat16 *shift, - const phi::dtype::bfloat16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - phi::dtype::bfloat16 *out, - cudaStream_t &stream -); - -template void MoeFastHardamardWrapper( - const 
phi::dtype::bfloat16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::bfloat16 *shift, - const phi::dtype::bfloat16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - int8_t *out, - cudaStream_t &stream -); diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h deleted file mode 100644 index ccb624e5c2..0000000000 --- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include -#include "helper.h" - -template -void MoeFastHardamardWrapper(const T *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const T *shift, - const T *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - OutT* out, - cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu b/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu index f3e51bfcfa..288bbdf882 100644 --- a/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu +++ b/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu @@ -17,169 +17,190 @@ #include "cutlass/numeric_conversion.h" #include "group_swiglu_with_masked.h" #include "helper.h" -#include "moe/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" -template +template void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, - const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& up_gate_proj_weight, - const paddle::Tensor& down_proj_weight, - const paddle::Tensor* up_gate_proj_bias, - const paddle::Tensor* up_gate_proj_super_scale, - const paddle::Tensor* down_proj_super_scale, - const paddle::Tensor* up_gate_proj_local_scale, - const paddle::Tensor* up_gate_proj_code_scale, - const paddle::Tensor* up_gate_proj_code_zp, - const paddle::Tensor* down_proj_local_scale, - const paddle::Tensor* down_proj_code_scale, - const paddle::Tensor* down_proj_code_zp, - paddle::Tensor fc1_out, - paddle::Tensor ffn_out, - const int64_t total_rows_in_ll_else_minus1, - const int64_t actual_total_rows, - const int64_t inter_size, - const int64_t hidden_size, - const int num_experts, - bool used_in_ep_low_latency) { - using namespace phi; - using WeightOnlyTraits = cutlass::WintQuantTraits; - using WeightType = typename 
WeightOnlyTraits::WeightType; + const paddle::Tensor& tokens_expert_prefix_sum, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::Tensor* up_gate_proj_bias, + const paddle::Tensor* up_gate_proj_super_scale, + const paddle::Tensor* down_proj_super_scale, + const paddle::Tensor* up_gate_proj_local_scale, + const paddle::Tensor* up_gate_proj_code_scale, + const paddle::Tensor* up_gate_proj_code_zp, + const paddle::Tensor* down_proj_local_scale, + const paddle::Tensor* down_proj_code_scale, + const paddle::Tensor* down_proj_code_zp, + paddle::Tensor fc1_out, + paddle::Tensor ffn_out, + const int64_t total_rows_in_ll_else_minus1, + const int64_t actual_total_rows, + const int64_t inter_size, + const int64_t hidden_size, + const int num_experts, + bool used_in_ep_low_latency) { + using namespace phi; + using WeightOnlyTraits = cutlass::WintQuantTraits; + using WeightType = typename WeightOnlyTraits::WeightType; - typename WeightOnlyTraits::Arguments up_gate_proj_quant_args; - typename WeightOnlyTraits::Arguments down_proj_quant_args; - if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) { - up_gate_proj_quant_args.local_scale_ptr = const_cast(up_gate_proj_local_scale->data()); - up_gate_proj_quant_args.code_scale_ptr = const_cast(up_gate_proj_code_scale->data()); - up_gate_proj_quant_args.code_zp_ptr = const_cast(up_gate_proj_code_zp->data()); + typename WeightOnlyTraits::Arguments up_gate_proj_quant_args; + typename WeightOnlyTraits::Arguments down_proj_quant_args; + if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) { + up_gate_proj_quant_args.local_scale_ptr = + const_cast(up_gate_proj_local_scale->data()); + up_gate_proj_quant_args.code_scale_ptr = + const_cast(up_gate_proj_code_scale->data()); + up_gate_proj_quant_args.code_zp_ptr = + const_cast(up_gate_proj_code_zp->data()); - down_proj_quant_args.local_scale_ptr = const_cast(down_proj_local_scale->data()); - 
down_proj_quant_args.code_scale_ptr = const_cast(down_proj_code_scale->data()); - down_proj_quant_args.code_zp_ptr = const_cast(down_proj_code_zp->data()); - } + down_proj_quant_args.local_scale_ptr = + const_cast(down_proj_local_scale->data()); + down_proj_quant_args.code_scale_ptr = + const_cast(down_proj_code_scale->data()); + down_proj_quant_args.code_zp_ptr = + const_cast(down_proj_code_zp->data()); + } - auto moe_gemm_runner = MoeGemmRunner(); - auto stream = permute_input.stream(); + auto moe_gemm_runner = MoeGemmRunner(); + auto stream = permute_input.stream(); - moe_gemm_runner.moe_gemm_bias_act( - reinterpret_cast(permute_input.data()), - reinterpret_cast(up_gate_proj_weight.data()), - reinterpret_cast(up_gate_proj_super_scale ? up_gate_proj_super_scale->data() : nullptr), - reinterpret_cast(up_gate_proj_bias ? up_gate_proj_bias->data() : nullptr), - reinterpret_cast(fc1_out.data()), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - actual_total_rows, - inter_size, - hidden_size, - num_experts, - up_gate_proj_quant_args, - "none", - stream); + moe_gemm_runner.moe_gemm_bias_act( + reinterpret_cast(permute_input.data()), + reinterpret_cast( + up_gate_proj_weight.data()), + reinterpret_cast( + up_gate_proj_super_scale ? up_gate_proj_super_scale->data() + : nullptr), + reinterpret_cast( + up_gate_proj_bias ? 
up_gate_proj_bias->data() : nullptr), + reinterpret_cast(fc1_out.data()), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + actual_total_rows, + inter_size, + hidden_size, + num_experts, + up_gate_proj_quant_args, + "none", + stream); - paddle::Tensor act_out; - if (used_in_ep_low_latency) { - act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum); - } else { - act_out = paddle::experimental::swiglu(fc1_out, nullptr); - } + paddle::Tensor act_out; + if (used_in_ep_low_latency) { + act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum); + } else { + act_out = paddle::experimental::swiglu(fc1_out, nullptr); + } - moe_gemm_runner.moe_gemm( - reinterpret_cast(act_out.data()), - reinterpret_cast(down_proj_weight.data()), - reinterpret_cast(down_proj_super_scale ? down_proj_super_scale->data() : nullptr), - reinterpret_cast(ffn_out.data()), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - actual_total_rows, - hidden_size, - inter_size / 2, - num_experts, - down_proj_quant_args, - stream); + moe_gemm_runner.moe_gemm( + reinterpret_cast(act_out.data()), + reinterpret_cast( + down_proj_weight.data()), + reinterpret_cast(down_proj_super_scale + ? 
down_proj_super_scale->data() + : nullptr), + reinterpret_cast(ffn_out.data()), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + actual_total_rows, + hidden_size, + inter_size / 2, + num_experts, + down_proj_quant_args, + stream); } template -void MoeFFNWint2Kernel(const paddle::Tensor& permute_input, - const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& up_gate_proj_weight, - const paddle::Tensor& down_proj_weight, - const paddle::optional& up_gate_proj_bias, - const paddle::optional& up_gate_proj_scale, - const paddle::optional& down_proj_scale, - const paddle::optional& up_gate_proj_local_scale, - const paddle::optional& up_gate_proj_code_scale, - const paddle::optional& up_gate_proj_code_zp, - const paddle::optional& down_proj_local_scale, - const paddle::optional& down_proj_code_scale, - const paddle::optional& down_proj_code_zp, - paddle::Tensor ffn_out, - bool used_in_ep_low_latency) { - using namespace phi; - using data_t = typename PDTraits::data_t; - using NvType = typename PDTraits::DataType; +void MoeFFNWint2Kernel( + const paddle::Tensor& permute_input, + const paddle::Tensor& tokens_expert_prefix_sum, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_scale, + const paddle::optional& up_gate_proj_local_scale, + const paddle::optional& up_gate_proj_code_scale, + const paddle::optional& up_gate_proj_code_zp, + const paddle::optional& down_proj_local_scale, + const paddle::optional& down_proj_code_scale, + const paddle::optional& down_proj_code_zp, + paddle::Tensor ffn_out, + bool used_in_ep_low_latency) { + using namespace phi; + using data_t = typename PDTraits::data_t; + using NvType = typename PDTraits::DataType; - auto place = permute_input.place(); + auto place = permute_input.place(); - assert(permute_input.dims().size() == 3 || 
permute_input.dims().size() == 2); - assert(up_gate_proj_weight.dims().size() == 3); + assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2); + assert(up_gate_proj_weight.dims().size() == 3); - const int num_experts = up_gate_proj_weight.dims()[0]; - const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1]; + const int num_experts = up_gate_proj_weight.dims()[0]; + const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1]; - int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size; + int inter_dim = up_gate_proj_weight.dims()[1] * + up_gate_proj_weight.dims()[2] / hidden_size; - const int64_t inter_size = inter_dim * 4; + const int64_t inter_size = inter_dim * 4; - int num_experts_ = num_experts; - int num_max_tokens_per_expert = 0; - int expanded_active_expert_rows = 0; + int num_experts_ = num_experts; + int num_max_tokens_per_expert = 0; + int expanded_active_expert_rows = 0; - paddle::Tensor fc1_out_tensor; - if (permute_input.dims().size() == 3) { - num_experts_ = permute_input.dims()[0]; - assert(num_experts == num_experts_); + paddle::Tensor fc1_out_tensor; + if (permute_input.dims().size() == 3) { + num_experts_ = permute_input.dims()[0]; + assert(num_experts == num_experts_); - num_max_tokens_per_expert = permute_input.dims()[1]; - expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert; - fc1_out_tensor = GetEmptyTensor( - {num_experts_, num_max_tokens_per_expert, inter_size}, T, place); - } else { - expanded_active_expert_rows = permute_input.dims()[0]; - fc1_out_tensor = GetEmptyTensor( - {expanded_active_expert_rows, inter_size}, T, place); - } + num_max_tokens_per_expert = permute_input.dims()[1]; + expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert; + fc1_out_tensor = GetEmptyTensor( + {num_experts_, num_max_tokens_per_expert, inter_size}, T, place); + } else { + expanded_active_expert_rows = permute_input.dims()[0]; + 
fc1_out_tensor = + GetEmptyTensor({expanded_active_expert_rows, inter_size}, T, place); + } - // This is a trick. - // expanded_active_expert_rows is not needed in variable group gemm. - // but is needed in accommodating deepep low latency mode - const int64_t total_rows_in_ll_else_minus1 = used_in_ep_low_latency ? expanded_active_expert_rows : -1; + // This is a trick. + // expanded_active_expert_rows is not needed in variable group gemm. + // but is needed in accommodating deepep low latency mode + const int64_t total_rows_in_ll_else_minus1 = + used_in_ep_low_latency ? expanded_active_expert_rows : -1; - // When we tune the optimal configuration, we need the actual total_rows. - const int64_t actual_total_rows = expanded_active_expert_rows; + // When we tune the optimal configuration, we need the actual total_rows. + const int64_t actual_total_rows = expanded_active_expert_rows; - WeightOnlyMoeFFNKernel( - permute_input, - tokens_expert_prefix_sum, - up_gate_proj_weight, - down_proj_weight, - const_cast(up_gate_proj_bias.get_ptr()), - const_cast(up_gate_proj_scale.get_ptr()), - const_cast(down_proj_scale.get_ptr()), - const_cast(up_gate_proj_local_scale.get_ptr()), - const_cast(up_gate_proj_code_scale.get_ptr()), - const_cast(up_gate_proj_code_zp.get_ptr()), - const_cast(down_proj_local_scale.get_ptr()), - const_cast(down_proj_code_scale.get_ptr()), - const_cast(down_proj_code_zp.get_ptr()), - fc1_out_tensor, - ffn_out, - total_rows_in_ll_else_minus1, - actual_total_rows, - inter_size, - hidden_size, - num_experts, - used_in_ep_low_latency); + WeightOnlyMoeFFNKernel( + permute_input, + tokens_expert_prefix_sum, + up_gate_proj_weight, + down_proj_weight, + const_cast(up_gate_proj_bias.get_ptr()), + const_cast(up_gate_proj_scale.get_ptr()), + const_cast(down_proj_scale.get_ptr()), + const_cast(up_gate_proj_local_scale.get_ptr()), + const_cast(up_gate_proj_code_scale.get_ptr()), + const_cast(up_gate_proj_code_zp.get_ptr()), + 
const_cast(down_proj_local_scale.get_ptr()), + const_cast(down_proj_code_scale.get_ptr()), + const_cast(down_proj_code_zp.get_ptr()), + fc1_out_tensor, + ffn_out, + total_rows_in_ll_else_minus1, + actual_total_rows, + inter_size, + hidden_size, + num_experts, + used_in_ep_low_latency); } paddle::Tensor MoeExpertFFNWint2Func( @@ -197,49 +218,48 @@ paddle::Tensor MoeExpertFFNWint2Func( const paddle::optional& down_proj_code_scale, const paddle::optional& down_proj_code_zp, const bool used_in_ep_low_latency) { + const auto dtype = permute_input.dtype(); + auto ffn_out = paddle::empty_like(permute_input, dtype); - const auto dtype = permute_input.dtype(); - auto ffn_out = paddle::empty_like(permute_input, dtype); - - switch (dtype) { - case paddle::DataType::BFLOAT16: - MoeFFNWint2Kernel(permute_input, - tokens_expert_prefix_sum, - up_gate_proj_weight, - down_proj_weight, - up_gate_proj_bias, - up_gate_proj_scale, - down_proj_scale, - up_gate_proj_local_scale, - up_gate_proj_code_scale, - up_gate_proj_code_zp, - down_proj_local_scale, - down_proj_code_scale, - down_proj_code_zp, - ffn_out, - used_in_ep_low_latency); - break; - case paddle::DataType::FLOAT16: - MoeFFNWint2Kernel(permute_input, - tokens_expert_prefix_sum, - up_gate_proj_weight, - down_proj_weight, - up_gate_proj_bias, - up_gate_proj_scale, - down_proj_scale, - up_gate_proj_local_scale, - up_gate_proj_code_scale, - up_gate_proj_code_zp, - down_proj_local_scale, - down_proj_code_scale, - down_proj_code_zp, - ffn_out, - used_in_ep_low_latency); - break; - default: - PD_THROW("Unsupported data type for MoeExpertFFN"); - } - return ffn_out; + switch (dtype) { + case paddle::DataType::BFLOAT16: + MoeFFNWint2Kernel(permute_input, + tokens_expert_prefix_sum, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + up_gate_proj_local_scale, + up_gate_proj_code_scale, + up_gate_proj_code_zp, + down_proj_local_scale, + down_proj_code_scale, + down_proj_code_zp, + 
ffn_out, + used_in_ep_low_latency); + break; + case paddle::DataType::FLOAT16: + MoeFFNWint2Kernel(permute_input, + tokens_expert_prefix_sum, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + up_gate_proj_local_scale, + up_gate_proj_code_scale, + up_gate_proj_code_zp, + down_proj_local_scale, + down_proj_code_scale, + down_proj_code_zp, + ffn_out, + used_in_ep_low_latency); + break; + default: + PD_THROW("Unsupported data type for MoeExpertFFN"); + } + return ffn_out; } std::vector MoeExpertFFNWint2( @@ -257,21 +277,20 @@ std::vector MoeExpertFFNWint2( const paddle::optional& down_proj_code_scale, const paddle::optional& down_proj_code_zp, const bool used_in_ep_low_latency) { - - return {MoeExpertFFNWint2Func(permute_input, - tokens_expert_prefix_sum, - up_gate_proj_weight, - down_proj_weight, - up_gate_proj_bias, - up_gate_proj_scale, - down_proj_scale, - up_gate_proj_local_scale, - up_gate_proj_code_scale, - up_gate_proj_code_zp, - down_proj_local_scale, - down_proj_code_scale, - down_proj_code_zp, - used_in_ep_low_latency)}; + return {MoeExpertFFNWint2Func(permute_input, + tokens_expert_prefix_sum, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + up_gate_proj_local_scale, + up_gate_proj_code_scale, + up_gate_proj_code_zp, + down_proj_local_scale, + down_proj_code_scale, + down_proj_code_zp, + used_in_ep_low_latency)}; } std::vector> MoeExpertFFNWint2InferShape( @@ -282,53 +301,53 @@ std::vector> MoeExpertFFNWint2InferShape( const paddle::optional>& up_gate_proj_bias_shape, const paddle::optional>& up_gate_proj_scale_shape, const paddle::optional>& down_proj_scale_shape, - const paddle::optional>& up_gate_proj_local_scale_shape, + const paddle::optional>& + up_gate_proj_local_scale_shape, const paddle::optional>& up_gate_proj_code_scale_shape, const paddle::optional>& up_gate_proj_code_zp_shape, const paddle::optional>& down_proj_local_scale_shape, 
const paddle::optional>& down_proj_code_scale_shape, const paddle::optional>& down_proj_code_zp_shape, const bool used_in_ep_low_latency) { - - return {permute_input_shape}; + return {permute_input_shape}; } std::vector MoeExpertFFNWint2InferDtype( - const paddle::DataType &permute_input_dtype, - const paddle::DataType &tokens_expert_prefix_sum_dtype, - const paddle::DataType &up_gate_proj_weight_dtype, - const paddle::DataType &down_proj_weight_dtype, - const paddle::optional &up_gate_proj_bias_dtype, - const paddle::optional &up_gate_proj_scale_dtype, - const paddle::optional &down_proj_scale_dtype, - const paddle::optional &up_gate_proj_local_scale_dtype, - const paddle::optional &up_gate_proj_code_scale_dtype, - const paddle::optional &up_gate_proj_code_zp_dtype, - const paddle::optional &down_proj_local_scale_dtype, - const paddle::optional &down_proj_code_scale_dtype, - const paddle::optional &down_proj_code_zp_dtype, + const paddle::DataType& permute_input_dtype, + const paddle::DataType& tokens_expert_prefix_sum_dtype, + const paddle::DataType& up_gate_proj_weight_dtype, + const paddle::DataType& down_proj_weight_dtype, + const paddle::optional& up_gate_proj_bias_dtype, + const paddle::optional& up_gate_proj_scale_dtype, + const paddle::optional& down_proj_scale_dtype, + const paddle::optional& up_gate_proj_local_scale_dtype, + const paddle::optional& up_gate_proj_code_scale_dtype, + const paddle::optional& up_gate_proj_code_zp_dtype, + const paddle::optional& down_proj_local_scale_dtype, + const paddle::optional& down_proj_code_scale_dtype, + const paddle::optional& down_proj_code_zp_dtype, const bool used_in_ep_low_latency) { - - return {permute_input_dtype}; + return {permute_input_dtype}; } /** - * @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network Operator + * @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network + * Operator * * This operator performs the expert computation in MoE architecture, including: * 
1. First linear transformation (up_gate_proj) with optional quantization * 2. SwiGLU activation function * 3. Second linear transformation (down_proj) with optional quantization * - * Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization. + * Supports multiple quantization methods including weight-only int4/int8 and + * w4a8 quantization. * * Inputs: * - permute_input: Permuted input tensor organized by expert * Shape: [total_tokens * top_k, hidden_size] * dtype: bfloat16/float16 (or int8 for w4a8) - * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm - * Shape: [num_experts] - * dtype: int64 + * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for + * group_gemm Shape: [num_experts] dtype: int64 * - up_gate_proj_weight: First FFN layer weights * Shape: [num_experts, inter_size * 2, hidden_size] * dtype: Same as input (unquantized) or int8 (quantized) diff --git a/custom_ops/gpu_ops/moe/moe_fast_hardamard_impl.cuh b/custom_ops/gpu_ops/moe/moe_fast_hardamard_impl.cuh new file mode 100644 index 0000000000..926a9c9047 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_fast_hardamard_impl.cuh @@ -0,0 +1,1468 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "helper.h" +#include "moe_fast_hardamard_impl_common.h" + +template +__device__ __forceinline__ void hadamard_mult_thread_28_transpose( + T x[28][VecSize]) { // 35 + T out[28]; +#pragma unroll + for (int vi = 0; vi < VecSize; vi++) { + out[0] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] + + x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] + + x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] + + x[27][vi]; + out[1] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] - + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] - + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + + x[27][vi]; + out[2] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] + + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] - + x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - + x[27][vi]; + out[3] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] - + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] + + x[27][vi]; + out[4] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] - + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - 
x[26][vi] + + x[27][vi]; + out[5] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] - + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] - + x[27][vi]; + out[6] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] - + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] + + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] - + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] - + x[27][vi]; + out[7] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] + + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] - + x[27][vi]; + out[8] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] + + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] - + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - + x[27][vi]; + out[9] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] + + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] + + x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] + + x[27][vi]; + out[10] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] + + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - + x[16][vi] - x[17][vi] - x[18][vi] - 
x[19][vi] + x[20][vi] + + x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - + x[26][vi] + x[27][vi]; + out[11] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] + + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + + x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + + x[26][vi] - x[27][vi]; + out[12] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - + x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - + x[26][vi] + x[27][vi]; + out[13] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - + x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - + x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - + x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + + x[26][vi] - x[27][vi]; + out[14] = -x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + + x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - + x[26][vi] - x[27][vi]; + out[15] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + + x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] + + x[26][vi] - x[27][vi]; + out[16] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + + x[11][vi] 
+ x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] - + x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + + x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - + x[26][vi] + x[27][vi]; + out[17] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] - + x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - + x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] - + x[26][vi] - x[27][vi]; + out[18] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] + + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - + x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] + + x[26][vi] - x[27][vi]; + out[19] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] - + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] + + x[26][vi] + x[27][vi]; + out[20] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] - + x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + + x[26][vi] + x[27][vi]; + out[21] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - + x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] + + x[26][vi] + x[27][vi]; + out[22] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + 
x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + + x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - + x[26][vi] + x[27][vi]; + out[23] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - + x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + + x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - + x[26][vi] - x[27][vi]; + out[24] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] + + x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + + x[26][vi] - x[27][vi]; + out[25] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - + x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] - + x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - + x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - + x[26][vi] + x[27][vi]; + out[26] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - + x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] + + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - + x[26][vi] - x[27][vi]; + out[27] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - + x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - + x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] + 
x[25][vi] - + x[26][vi] - x[27][vi]; +#pragma unroll + for (int i = 0; i < 28; i++) { + x[i][vi] = out[i]; + } + } +} + +template +__device__ __forceinline__ void hadamard_mult_thread_36_transpose( + T x[36][VecSize]) { // 4t + T out[36]; +#pragma unroll + for (int vi = 0; vi < VecSize; vi++) { + out[0] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] + + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] + + x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] + + x[27][vi] + x[28][vi] + x[29][vi] + x[30][vi] + x[31][vi] + + x[32][vi] + x[33][vi] + x[34][vi] + x[35][vi]; + out[1] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - + x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] - + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] + + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] + + x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] + + x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; + out[2] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - + x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] - x[16][vi] + + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + + x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] - + x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] - x[31][vi] - + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; + out[3] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] + + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] - + x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] + + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] - + x[27][vi] - x[28][vi] + 
x[29][vi] + x[30][vi] - x[31][vi] - + x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; + out[4] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + + x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] - + x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] + x[31][vi] - + x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; + out[5] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + + x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] + x[31][vi] + + x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; + out[6] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] + + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - x[11][vi] - + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] - + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - + x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] + + x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; + out[7] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] - + x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] - + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] - + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] + + x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] - x[31][vi] - + x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; + out[8] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] 
- x[11][vi] + + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] + + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] + + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + + x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] - x[31][vi] - + x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; + out[9] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] - + x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] - + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - + x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] + x[31][vi] - + x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; + out[10] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - + x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + + x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] - + x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; + out[11] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] + + x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - + x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - + x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] + + x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; + out[12] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + + x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + + x[26][vi] - x[27][vi] + 
x[28][vi] + x[29][vi] - x[30][vi] + + x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; + out[13] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - + x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - + x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - + x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - + x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] - + x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; + out[14] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] - + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + + x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - + x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - + x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] + + x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; + out[15] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + + x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - + x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] + + x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; + out[16] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] + + x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + + x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] - + x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; + out[17] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] - + x[6][vi] - x[7][vi] + x[8][vi] + 
x[9][vi] - x[10][vi] - + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + + x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + + x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] + + x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; + out[18] = -x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + + x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + + x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - + x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] - + x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; + out[19] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - + x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - + x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] + + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - + x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] + + x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] + + x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; + out[20] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] - + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - + x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + + x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] + + x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; + out[21] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + + x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + + 
x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] + + x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; + out[22] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - + x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - + x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - + x[26][vi] + x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] - + x[31][vi] + x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; + out[23] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] - + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] - + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + + x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] + x[30][vi] - + x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] + x[35][vi]; + out[24] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - + x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + + x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - + x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] + + x[31][vi] - x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; + out[25] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + + x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - + x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] + + x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; + out[26] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + + x[6][vi] + 
x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - + x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + + x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + + x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - + x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] + + x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; + out[27] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + + x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - + x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] - + x[31][vi] + x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; + out[28] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + + x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - + x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] + + x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] + x[35][vi]; + out[29] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - + x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - + x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - + x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + + x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] - + x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; + out[30] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - + x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + + x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - + x[21][vi] - x[22][vi] + x[23][vi] + 
x[24][vi] + x[25][vi] - + x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] - + x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; + out[31] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - + x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - + x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + + x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] + + x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] - + x[31][vi] - x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; + out[32] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + + x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] - + x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + + x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + + x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] - + x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; + out[33] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + + x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + + x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + + x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + + x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] + + x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] - + x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; + out[34] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - + x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - + x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - + x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - + x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - + x[26][vi] + x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] + + x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; + out[35] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - 
x[5][vi] - + x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - + x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + + x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + + x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] - + x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] + x[30][vi] - + x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; +#pragma unroll + for (int i = 0; i < 36; i++) { + x[i][vi] = out[i]; + } + } +} + +template +__device__ __forceinline__ void hadamard_mult_thread_28(T x[28]) { // 35 + T out[28]; + out[0] = +x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + + x[9] + x[10] + x[11] + x[12] + x[13] - x[14] + x[15] + x[16] + + x[17] + x[18] + x[19] + x[20] + x[21] + x[22] + x[23] + x[24] + + x[25] + x[26] + x[27]; + out[1] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - + x[9] + x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] - + x[17] + x[18] + x[19] - x[20] - x[21] - x[22] - x[23] + x[24] + + x[25] - x[26] + x[27]; + out[2] = +x[0] + x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + + x[17] - x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] + + x[25] + x[26] - x[27]; + out[3] = +x[0] - x[1] + x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - + x[9] - x[10] - x[11] + x[12] + x[13] + x[14] - x[15] + x[16] - + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - + x[25] + x[26] + x[27]; + out[4] = +x[0] + x[1] - x[2] + x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - + x[9] - x[10] - x[11] - x[12] + x[13] + x[14] + x[15] - x[16] + + x[17] - x[18] + x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - + x[25] - x[26] + x[27]; + out[5] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] + x[6] - x[7] + x[8] + + x[9] - x[10] - x[11] - x[12] - x[13] + x[14] + x[15] + x[16] - + x[17] + x[18] - x[19] + x[20] - x[21] + x[22] + x[23] - x[24] - + x[25] - x[26] - x[27]; + out[6] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + 
x[6] + x[7] - x[8] + + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] + x[23] + x[24] - + x[25] - x[26] - x[27]; + out[7] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] + x[8] - + x[9] + x[10] + x[11] - x[12] - x[13] + x[14] - x[15] - x[16] + + x[17] + x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] + + x[25] - x[26] - x[27]; + out[8] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] + + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - + x[17] + x[18] + x[19] - x[20] + x[21] - x[22] + x[23] - x[24] + + x[25] + x[26] - x[27]; + out[9] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + + x[9] + x[10] - x[11] + x[12] + x[13] + x[14] - x[15] - x[16] - + x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - + x[25] + x[26] + x[27]; + out[10] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + + x[9] + x[10] + x[11] - x[12] + x[13] + x[14] + x[15] - x[16] - + x[17] - x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] + + x[25] - x[26] + x[27]; + out[11] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - + x[9] + x[10] + x[11] + x[12] - x[13] + x[14] + x[15] + x[16] - + x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] + x[24] - + x[25] + x[26] - x[27]; + out[12] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + + x[9] - x[10] + x[11] + x[12] + x[13] + x[14] - x[15] + x[16] + + x[17] - x[18] - x[19] - x[20] - x[21] + x[22] + x[23] - x[24] + + x[25] - x[26] + x[27]; + out[13] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + + x[9] + x[10] - x[11] + x[12] + x[13] + x[14] + x[15] - x[16] + + x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - + x[25] + x[26] - x[27]; + out[14] = -x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + + x[9] + x[10] + x[11] + x[12] + x[13] - x[14] - x[15] - x[16] - + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - 
x[24] - + x[25] - x[26] - x[27]; + out[15] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - + x[9] + x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + + x[17] - x[18] - x[19] + x[20] + x[21] + x[22] + x[23] - x[24] - + x[25] + x[26] - x[27]; + out[16] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - + x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] - x[16] - + x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - + x[25] - x[26] + x[27]; + out[17] = +x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - + x[9] - x[10] - x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - + x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] + + x[25] - x[26] - x[27]; + out[18] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - + x[9] - x[10] - x[11] - x[12] + x[13] - x[14] - x[15] + x[16] - + x[17] - x[18] - x[19] + x[20] - x[21] - x[22] + x[23] + x[24] + + x[25] + x[26] - x[27]; + out[19] = +x[0] + x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + + x[9] - x[10] - x[11] - x[12] - x[13] - x[14] - x[15] - x[16] + + x[17] - x[18] - x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + + x[25] + x[26] + x[27]; + out[20] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] + x[7] - x[8] + + x[9] + x[10] - x[11] - x[12] - x[13] - x[14] + x[15] - x[16] - + x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + + x[25] + x[26] + x[27]; + out[21] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] - + x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - + x[17] - x[18] + x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - + x[25] + x[26] + x[27]; + out[22] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + + x[9] - x[10] + x[11] + x[12] - x[13] - x[14] + x[15] + x[16] + + x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - + x[25] - x[26] + x[27]; + out[23] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - + x[9] + x[10] - x[11] + x[12] + 
x[13] - x[14] + x[15] + x[16] + + x[17] + x[18] - x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + + x[25] - x[26] - x[27]; + out[24] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + + x[9] - x[10] + x[11] - x[12] + x[13] - x[14] - x[15] + x[16] + + x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - + x[25] + x[26] - x[27]; + out[25] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - + x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] - x[16] + + x[17] + x[18] + x[19] + x[20] - x[21] - x[22] + x[23] - x[24] - + x[25] - x[26] + x[27]; + out[26] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + + x[9] - x[10] + x[11] - x[12] + x[13] - x[14] + x[15] - x[16] - + x[17] + x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] - + x[25] - x[26] - x[27]; + out[27] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + + x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - + x[17] - x[18] + x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + + x[25] - x[26] - x[27]; +#pragma unroll + for (int i = 0; i < 28; i++) { + x[i] = out[i]; + } +} + +template +__device__ __forceinline__ void hadamard_mult_thread_36(T x[36]) { // 4t + T out[36]; + out[0] = +x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + + x[9] + x[10] + x[11] + x[12] + x[13] + x[14] + x[15] + x[16] + + x[17] - x[18] + x[19] + x[20] + x[21] + x[22] + x[23] + x[24] + + x[25] + x[26] + x[27] + x[28] + x[29] + x[30] + x[31] + x[32] + + x[33] + x[34] + x[35]; + out[1] = +x[0] + x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - + x[25] - x[26] + x[27] + x[28] - x[29] - x[30] - x[31] + x[32] - + x[33] + x[34] + x[35]; + out[2] = +x[0] + x[1] + x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - + x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + + x[17] + x[18] + x[19] - x[20] + x[21] + 
x[22] - x[23] + x[24] - + x[25] - x[26] - x[27] + x[28] + x[29] - x[30] - x[31] - x[32] + + x[33] - x[34] + x[35]; + out[3] = +x[0] + x[1] + x[2] + x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - + x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - + x[17] + x[18] + x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + + x[25] - x[26] - x[27] - x[28] + x[29] + x[30] - x[31] - x[32] - + x[33] + x[34] - x[35]; + out[4] = +x[0] - x[1] + x[2] + x[3] + x[4] + x[5] + x[6] - x[7] + x[8] - + x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] + + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] + x[24] - + x[25] + x[26] - x[27] - x[28] - x[29] + x[30] + x[31] - x[32] - + x[33] - x[34] + x[35]; + out[5] = +x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] + x[7] - x[8] + + x[9] - x[10] - x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - + x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] + x[24] + + x[25] - x[26] + x[27] - x[28] - x[29] - x[30] + x[31] + x[32] - + x[33] - x[34] - x[35]; + out[6] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] + x[8] - + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] - + x[17] + x[18] - x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + + x[25] + x[26] - x[27] + x[28] - x[29] - x[30] - x[31] + x[32] + + x[33] - x[34] - x[35]; + out[7] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] + + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - + x[17] + x[18] - x[19] - x[20] + x[21] - x[22] + x[23] + x[24] - + x[25] + x[26] + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] + + x[33] + x[34] - x[35]; + out[8] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + + x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] + + x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] + + x[25] - x[26] + x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - + x[33] + x[34] + x[35]; + out[9] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + + x[9] + x[10] + x[11] - 
x[12] + x[13] - x[14] - x[15] - x[16] + + x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] - x[24] + + x[25] + x[26] - x[27] + x[28] + x[29] - x[30] + x[31] - x[32] - + x[33] - x[34] + x[35]; + out[10] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] + + x[9] + x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - + x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - + x[25] + x[26] + x[27] - x[28] + x[29] + x[30] - x[31] + x[32] - + x[33] - x[34] - x[35]; + out[11] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + + x[9] + x[10] + x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] + + x[25] - x[26] + x[27] + x[28] - x[29] + x[30] + x[31] - x[32] + + x[33] - x[34] - x[35]; + out[12] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - + x[9] + x[10] + x[11] + x[12] + x[13] + x[14] - x[15] + x[16] - + x[17] + x[18] - x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - + x[25] + x[26] - x[27] + x[28] + x[29] - x[30] + x[31] + x[32] - + x[33] + x[34] - x[35]; + out[13] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] + + x[9] - x[10] + x[11] + x[12] + x[13] + x[14] + x[15] - x[16] + + x[17] + x[18] - x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - + x[25] - x[26] + x[27] - x[28] + x[29] + x[30] - x[31] + x[32] + + x[33] - x[34] + x[35]; + out[14] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - + x[9] + x[10] - x[11] + x[12] + x[13] + x[14] + x[15] + x[16] - + x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - + x[25] - x[26] - x[27] + x[28] - x[29] + x[30] + x[31] - x[32] + + x[33] + x[34] - x[35]; + out[15] = +x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - + x[9] - x[10] + x[11] - x[12] + x[13] + x[14] + x[15] + x[16] + + x[17] + x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] + + x[25] - x[26] - x[27] - x[28] + x[29] - x[30] + x[31] + x[32] - + x[33] + x[34] + x[35]; + out[16] = 
+x[0] + x[1] - x[2] + x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - + x[9] - x[10] - x[11] + x[12] - x[13] + x[14] + x[15] + x[16] + + x[17] + x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + + x[25] + x[26] - x[27] - x[28] - x[29] + x[30] - x[31] + x[32] + + x[33] - x[34] + x[35]; + out[17] = +x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] + + x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] + + x[17] + x[18] + x[19] + x[20] - x[21] + x[22] - x[23] - x[24] - + x[25] + x[26] + x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + + x[33] + x[34] - x[35]; + out[18] = -x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + + x[9] + x[10] + x[11] + x[12] + x[13] + x[14] + x[15] + x[16] + + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - x[24] - + x[25] - x[26] - x[27] - x[28] - x[29] - x[30] - x[31] - x[32] - + x[33] - x[34] - x[35]; + out[19] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + + x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] + + x[25] + x[26] - x[27] - x[28] + x[29] + x[30] + x[31] - x[32] + + x[33] - x[34] - x[35]; + out[20] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - + x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] + + x[25] + x[26] + x[27] - x[28] - x[29] + x[30] + x[31] + x[32] - + x[33] + x[34] - x[35]; + out[21] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - + x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] - + x[25] + x[26] + x[27] + x[28] - x[29] - x[30] + x[31] + x[32] + + x[33] - x[34] + x[35]; + out[22] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - + x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] + + x[17] - x[18] + x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + + x[25] - x[26] 
+ x[27] + x[28] + x[29] - x[30] - x[31] + x[32] + + x[33] + x[34] - x[35]; + out[23] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] + + x[9] - x[10] - x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - + x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - + x[25] + x[26] - x[27] + x[28] + x[29] + x[30] - x[31] - x[32] + + x[33] + x[34] + x[35]; + out[24] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] - + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] - + x[25] - x[26] + x[27] - x[28] + x[29] + x[30] + x[31] - x[32] - + x[33] + x[34] + x[35]; + out[25] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - + x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] - x[24] - + x[25] - x[26] - x[27] + x[28] - x[29] + x[30] + x[31] + x[32] - + x[33] - x[34] + x[35]; + out[26] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] + + x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] + + x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - + x[25] - x[26] - x[27] - x[28] + x[29] - x[30] + x[31] + x[32] + + x[33] - x[34] - x[35]; + out[27] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - + x[9] + x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + + x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] + x[24] - + x[25] - x[26] - x[27] - x[28] - x[29] + x[30] - x[31] + x[32] + + x[33] + x[34] - x[35]; + out[28] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] + + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - + x[17] - x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] + + x[25] - x[26] - x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + + x[33] + x[34] + x[35]; + out[29] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + + x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] - 
x[16] - + x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] - + x[25] + x[26] - x[27] - x[28] - x[29] - x[30] - x[31] + x[32] - + x[33] + x[34] + x[35]; + out[30] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - + x[9] + x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] - + x[17] - x[18] + x[19] + x[20] - x[21] - x[22] + x[23] + x[24] + + x[25] - x[26] + x[27] - x[28] - x[29] - x[30] - x[31] - x[32] + + x[33] - x[34] + x[35]; + out[31] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] + + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + + x[17] - x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] + + x[25] + x[26] - x[27] + x[28] - x[29] - x[30] - x[31] - x[32] - + x[33] + x[34] - x[35]; + out[32] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - + x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] - + x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + + x[25] + x[26] + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] - + x[33] - x[34] + x[35]; + out[33] = +x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - + x[9] - x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] + + x[17] - x[18] + x[19] - x[20] + x[21] + x[22] + x[23] - x[24] - + x[25] + x[26] + x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - + x[33] - x[34] - x[35]; + out[34] = +x[0] + x[1] - x[2] + x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - + x[9] - x[10] - x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] + x[23] + x[24] - + x[25] - x[26] + x[27] + x[28] + x[29] - x[30] + x[31] - x[32] - + x[33] - x[34] - x[35]; + out[35] = +x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] + + x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] - + x[17] - x[18] - x[19] - x[20] + x[21] - x[22] + x[23] + x[24] + + x[25] - x[26] - x[27] + x[28] + x[29] + x[30] - x[31] + x[32] - + x[33] - x[34] - x[35]; +#pragma unroll + for (int i = 0; i < 36; i++) 
{
    x[i] = out[i];
  }
}

// Applies the 28-point Hadamard-like transform independently to each of the
// kNChunks chunks held in registers.
// NOTE(review): the original `template <...>` parameter lists in this file
// were lost in extraction (angle-bracket contents stripped); each `template`
// below is reproduced as found — restore `<int kNChunks, typename T>` style
// heads from the upstream repository.
template
__device__ __forceinline__ void hadamard_mult_thread_chunk_28(
    T x[kNChunks][28]) {
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    hadamard_mult_thread_28(x[c]);
  }
}

// Same as above for the 36-point transform.
template
__device__ __forceinline__ void hadamard_mult_thread_chunk_36(
    T x[kNChunks][36]) {
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    hadamard_mult_thread_36(x[c]);
  }
}

// Vectorized global -> register load of one token's activations.
// Each thread reads VecSize contiguous elements per chunk as a single wide
// load (vec_t). In diagonal-block mode the y-block index selects the tile,
// so only one chunk is allowed (enforced by the static_assert).
template
inline __device__ void load_input(const T *x,
                                  T x_vals[kNChunks][VecSize],
                                  int dim) {
  using vec_t = typename BytesToType::Type;
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    int offset;
    if constexpr (UseDiagonalBlockMatrix) {
      static_assert(kNChunks == 1);
      offset = blockIdx.y * blockDim.x + threadIdx.x;
    } else {
      offset = c * blockDim.x + threadIdx.x;
    }
    // NOTE(review): this guards only the FIRST element of the vector; if
    // `dim` is not a multiple of VecSize the last wide load reads past the
    // logical end — presumably `dim` is always VecSize-aligned; confirm.
    if (offset * VecSize < dim) {
      reinterpret_cast(x_vals)[c] =
          reinterpret_cast(x)[offset];
    }
  }
}

// Quantizes one value: scale by `max_bound * scale`, round (round_type 0 =
// rint, i.e. current FP rounding mode / half-to-even; otherwise round, i.e.
// half-away-from-zero), then clip into [min_bound, max_bound].
// NOTE(review): `rint`/`round` are the double-precision overloads applied to
// a float — `rintf`/`roundf` would avoid the float->double->float trip.
template
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
                                                   const float scale,
                                                   const int round_type,
                                                   const float max_bound,
                                                   const float min_bound) {
  float quant_value = max_bound * scale * static_cast(input);

  if (round_type == 0) {
    quant_value = static_cast(rint(quant_value));
  } else {
    quant_value = static_cast(round(quant_value));
  }
  return static_cast(
      ClipFunc(quant_value, min_bound, max_bound));
}

// Stores transformed values with smooth-quant post-processing:
// out = Quant((val + shift[idx]) * smooth[idx]).
// shift/smooth are per-channel vectors indexed by the element's position in
// the token, loaded with the same vector width as the output store.
template
inline __device__ void smooth_quant_store_output(OutT *out,
                                                 const T *shift,
                                                 const T *smooth,
                                                 T out_vals[kNChunks][VecSize],
                                                 const float quant_scale,
                                                 const int quant_round_type,
                                                 const float quant_max_bound,
                                                 const float quant_min_bound,
                                                 const int dim) {
  using DstVec = AlignedVector;
  using Vec = AlignedVector;
  DstVec dst_vec;
  Vec shift_vec;
  Vec smooth_vec;
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    int base_idx;
    if constexpr (UseDiagonalBlockMatrix) {
      base_idx = blockIdx.y * blockDim.x + threadIdx.x;
    } else {
      base_idx = c * blockDim.x + threadIdx.x;
    }
    const int idx = base_idx * VecSize;
    if (idx < dim) {
      Load(shift + idx, &shift_vec);
      Load(smooth + idx, &smooth_vec);
#pragma unroll
      for (int vi = 0; vi < VecSize; ++vi) {
        out_vals[c][vi] = (out_vals[c][vi] + shift_vec[vi]) * smooth_vec[vi];
        dst_vec[vi] =
            QuantHelperFunc(static_cast(out_vals[c][vi]),
                            quant_scale,
                            quant_round_type,
                            quant_max_bound,
                            quant_min_bound);
      }
      Store(dst_vec, out + idx);
    }
  }
}

// Quantized store without the shift/smooth pre-scaling (see
// smooth_quant_store_output for the smoothed variant).
template
inline __device__ void quant_store_output(OutT *out,
                                          T out_vals[kNChunks][VecSize],
                                          const float quant_scale,
                                          const int quant_round_type,
                                          const float quant_max_bound,
                                          const float quant_min_bound,
                                          const int dim) {
  using DstVec = AlignedVector;
  using Vec = AlignedVector;
  DstVec dst_vec;
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    int base_idx;
    if constexpr (UseDiagonalBlockMatrix) {
      base_idx = blockIdx.y * blockDim.x + threadIdx.x;
    } else {
      base_idx = c * blockDim.x + threadIdx.x;
    }
    const int idx = base_idx * VecSize;
    if (idx < dim) {
#pragma unroll
      for (int vi = 0; vi < VecSize; ++vi) {
        // out_vals[c][vi] = (out_vals[c][vi] + shift_vec[vi]) * smooth_vec[vi];
        dst_vec[vi] =
            QuantHelperFunc(static_cast(out_vals[c][vi]),
                            quant_scale,
                            quant_round_type,
                            quant_max_bound,
                            quant_min_bound);
      }
      Store(dst_vec, out + idx);
    }
  }
}

// Plain (non-quantized) vectorized register -> global store; mirror image of
// load_input, with the same first-element-only bounds check.
template
inline __device__ void store_output(OutT *out,
                                    T out_vals[kNChunks][VecSize],
                                    int dim) {
  using vec_t = typename BytesToType::Type;
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    int offset;
    if constexpr (UseDiagonalBlockMatrix) {
      offset = blockIdx.y * blockDim.x + threadIdx.x;
    } else {
      offset = c * blockDim.x + threadIdx.x;
    }
    if (offset * VecSize < dim) {
      reinterpret_cast(out)[offset] =
          reinterpret_cast(out_vals)[c];
    }
  }
}

// Butterfly transform applied across the FIRST array dimension (the
// transposed layout, x[N][kNChunks]); companion of hadamard_mult_thread.
template
__device__ __forceinline__ void hadamard_mult_thread_transpose(
    T x[1 << kLogN][kNChunks]) {
  constexpr int N = 1 << kLogN;
#pragma unroll
  for (int i = 0; i < kLogN;
++i) {
    // Stage i pairs elements `stride` apart: (a,b) -> (a+b, a-b).
    const int stride = 1 << i;
#pragma unroll
    for (int j = 0; j < N / 2; ++j) {
      const int lo = j & (stride - 1);
      const int idx = (j - lo) * 2 + lo;
#pragma unroll
      for (int c = 0; c < kNChunks; ++c) {
        const T a = x[idx][c];
        const T b = x[idx + stride][c];
        x[idx][c] = a + b;
        x[idx + stride][c] = a - b;
      }
    }
  }
}

// In-register, unnormalized fast Walsh-Hadamard transform of length 2^kLogN,
// applied independently to each of the kNChunks rows: kLogN butterfly stages
// of (a,b) -> (a+b, a-b) with doubling stride.
// NOTE(review): the `template` head lost its parameter list in extraction
// (presumably <typename T, int kNChunks, int kLogN>); restore from upstream.
template
__device__ __forceinline__ void hadamard_mult_thread(
    T x[kNChunks][1 << kLogN]) {
  constexpr int N = 1 << kLogN;
#pragma unroll
  for (int i = 0; i < kLogN; ++i) {
    const int stride = 1 << i;
#pragma unroll
    for (int j = 0; j < N / 2; ++j) {
      // `idx` enumerates the lower element of each butterfly pair so that
      // pairs (idx, idx + stride) partition [0, N) without overlap.
      const int lo = j & (stride - 1);
      const int idx = (j - lo) * 2 + lo;
#pragma unroll
      for (int c = 0; c < kNChunks; ++c) {
        const T a = x[c][idx];
        const T b = x[c][idx + stride];
        x[c][idx] = a + b;
        x[c][idx + stride] = a - b;
      }
    }
  }
}

// Continues the transform ACROSS warp lanes: each remaining butterfly stage
// is realized with __shfl_xor_sync partner exchange (stages kStepStart ..
// kLogWarpSize-1), so the per-thread stages above plus these lane stages
// compose a longer Hadamard transform.
template
__device__ __forceinline__ void hadamard_mult_warp(T x[kNChunks][kNItems]) {
  constexpr int N = 1 << kLogWarpSize;
  int lane_id = threadIdx.x % N;
#pragma unroll
  for (int step = kStepStart; step < kLogWarpSize; ++step) {
    const int lane_mask = 1 << step;
    // Lane with the masked bit set takes the (a-b) role of the butterfly.
    const T sign = (lane_id & lane_mask) ?
-1.f : 1.f;
#pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
#pragma unroll
      for (int i = 0; i < kNItems; ++i) {
        // Butterfly with the XOR-partner lane: sign*x + partner_x gives
        // a+b on one lane and a-b (via sign = -1) on the other.
        T x_val_other = __shfl_xor_sync(FULL_MASK, x[c][i], lane_mask);
        x[c][i] = sign * x[c][i] + x_val_other;
      }
    }
  }
}

// Block-wide data exchange through shared memory so that, after the warp
// stage, a warp can operate on elements that were previously spread across
// different warps. `Pre` selects the direction of the (warp_id, lane_id) <->
// (row_t, col_t) relayout; calling once with Pre=true and once with Pre=false
// restores the original layout.
// NOTE(review): stripped `template` head — presumably carried
// <int kNChunks, int kChunksPerExchange, int kNElts, int kWarpSize,
//  int kNWarps, bool Pre, typename T, typename vec_t>; restore from upstream.
template
inline __device__ void exchange_smem_pre(T x_vals[kNChunks][kNElts],
                                         vec_t *smem) {
  // kNChunks: total number of chunk iterations needed to process all data.
  // kChunksPerExchange: how many chunks one pass through smem can hold.
  // kNExchanges: how many passes are needed to process all the data.
  constexpr int kNThreads = kWarpSize * kNWarps;
  const int warp_id = threadIdx.x / kWarpSize;
  const int lane_id = threadIdx.x % kWarpSize;
  const int row_t = threadIdx.x % kNWarps;
  const int col_t = threadIdx.x / kNWarps;
#pragma unroll
  for (int c0 = 0; c0 < kNChunks / kChunksPerExchange; ++c0) {
    // Number of smem round-trips needed to cover all chunks.
    __syncthreads();
#pragma unroll
    for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {
      // Amount of data written per pass to fill shared memory.
      // smem[c1 * kNThreads + (Pre ? warp_id * kWarpSize + lane_id ^ warp_id :
      // row_t * kWarpSize + col_t ^ row_t)] =
      //     *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]);
      smem[c1 * kNThreads +
           (Pre ? warp_id * kWarpSize + lane_id : row_t * kWarpSize + col_t)] =
          *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]);
    }
    __syncthreads();
#pragma unroll
    for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {
      // Read back with the indexing roles swapped, completing the transpose.
      // *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]) =
      //     smem[c1 * kNThreads + (Pre ? row_t * kWarpSize + col_t ^ row_t :
      //     warp_id * kWarpSize + lane_id ^ warp_id)];
      *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]) =
          smem[c1 * kNThreads + (Pre ? row_t * kWarpSize + col_t
                                     : warp_id * kWarpSize + lane_id)];
    }
  }
}

// Ceiling-free integer log2: cilog2(v) = floor(log2(v)) for v > 0, -1 for 0.
constexpr int cilog2(int val) { return val > 0 ?
1 + cilog2(val >> 1) : -1; }

// Fused fast-Hadamard-transform kernel over MoE activations.
// One grid-stride loop over tokens (blockIdx.x); each block transforms one
// token's `dim` values held in registers (kNChunks x VecSize per thread),
// composing: per-thread butterflies -> intra-warp shuffle butterflies ->
// (optionally) cross-warp stages via shared-memory exchange -> per-chunk
// transform, then stores plain / quantized / smooth-quantized output.
// Requires `kSmemSize` dynamic shared memory (see the launch wrapper).
// NOTE(review): stripped `template` head — presumably
// <typename T, typename OutT, int VecSize, int kNChunks, int kThreads,
//  bool UseDiagonalBlockMatrix>; restore from upstream.
// NOTE(review): `token_id` is `int` while `token_num` is int64_t — overflows
// for > 2^31 tokens; consider int64_t loop counters.
template
__global__ __launch_bounds__(kThreads) void moe_fast_hardamard_kernel(
    const T *x,
    const int64_t *expert_idx_per_token,
    const T *shift,
    const T *smooth,
    const float *quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    OutT *out) {
  using vec_t = typename BytesToType::Type;
  constexpr int kLogVecSize = cilog2(VecSize);
  constexpr int kLogWarpSize = cilog2(32);
  constexpr int kWarpSize = 32;
  constexpr int kNWarps = kThreads / kWarpSize;
  constexpr int kLogNWarps = cilog2(kNWarps);
  constexpr int kLogNChunks = cilog2(kNChunks);

  extern __shared__ char smem_[];
  vec_t *smem_exchange = reinterpret_cast(smem_);

  for (int token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) {
    const T *x_now = x + token_id * dim;
    OutT *out_now = out + token_id * dim;
    T init_value = static_cast(0.f);
    T x_vals[kNChunks][VecSize] = {init_value};

    load_input(
        x_now, x_vals, dim);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id0: %d\n", i);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif

    // Stage 1: butterflies entirely within each thread's registers.
    hadamard_mult_thread(x_vals);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif
    // Stage 2: butterflies across the 32 lanes of each warp.
    hadamard_mult_warp(x_vals);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id2: %d\n", i);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif
    if constexpr (kNWarps > 1) {
      // First let kNWarps consecutive threads pick up the data held by the
      // other warps (relayout through shared memory).
      exchange_smem_pre(x_vals, smem_exchange);
      // Cross-warp butterfly stage on the exchanged layout.
      hadamard_mult_warp(x_vals);
      // Swap the data back to its original layout.
      exchange_smem_pre(x_vals, smem_exchange);
    }
    if constexpr (kNChunks > 1) {
      // Stage 3: transform across chunks; 28 and 36 have dedicated
      // hard-coded sign matrices, other counts must be powers of two.
      if constexpr (kNChunks == 28) {
        hadamard_mult_thread_28_transpose(x_vals);
      } else if constexpr (kNChunks == 36) {
        hadamard_mult_thread_36_transpose(x_vals);
      } else {
        constexpr int kLogNChunks = cilog2(kNChunks);
        static_assert(1 << kLogNChunks == kNChunks,
                      "kNChunks must be a power of 2");
        hadamard_mult_thread_transpose(x_vals);
      }
    }
    if (quant_scales) {
      // Per-expert quant scale, looked up via this token's routed expert.
      int64_t expert_id = expert_idx_per_token[token_id];
      float quant_scale = quant_scales[expert_id];
      if (shift) {
        smooth_quant_store_output(out_now,
                                  shift,
                                  smooth,
                                  x_vals,
                                  quant_scale,
                                  quant_round_type,
                                  quant_max_bound,
                                  quant_min_bound,
                                  dim);
      } else {
        quant_store_output(
            out_now,
            x_vals,
            quant_scale,
            quant_round_type,
            quant_max_bound,
            quant_min_bound,
            dim);
      }
    } else {
      store_output(
          out_now, x_vals, dim);
    }
  }
}

// Variant for the EP low-latency path: tokens are laid out in fixed-size
// per-expert slots (num_max_tokens_per_expert each); slots beyond the
// expert's actual count (recv_expert_count) are skipped, fast-forwarding
// token_id to just before the next expert's range.
// NOTE(review): stripped `template` head — see moe_fast_hardamard_kernel.
template
__global__ __launch_bounds__(kThreads) void masked_moe_fast_hardamard_kernel(
    const T *x,
    const int64_t *recv_expert_count,
    const T *shift,
    const T *smooth,
    const float *quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    const int num_max_tokens_per_expert,
    OutT *out) {
  using vec_t = typename BytesToType::Type;
  constexpr int kLogVecSize = cilog2(VecSize);
  constexpr int kLogWarpSize = cilog2(32);
  constexpr int kWarpSize = 32;
  constexpr int kNWarps = kThreads / kWarpSize;
  constexpr int kLogNWarps = cilog2(kNWarps);
  constexpr int kLogNChunks = cilog2(kNChunks);

  extern __shared__ char smem_[];
  vec_t *smem_exchange = reinterpret_cast(smem_);

  for (int token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) {
    const auto token_idx_in_expert = token_id % num_max_tokens_per_expert;
    const auto expert_id = token_id / num_max_tokens_per_expert;
    if (token_idx_in_expert >= recv_expert_count[expert_id]) {
      // Slot is unused: advance token_id by whole grid strides so the next
      // loop increment lands inside the next expert's slot range.
      auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
      auto num_iters_to_next_expert =
          (next_expert_start_idx - token_id - 1) / gridDim.x;
      token_id += num_iters_to_next_expert * gridDim.x;
      continue;
    }
    const T *x_now = x + token_id * dim;
    OutT *out_now = out + token_id * dim;
    T init_value = static_cast(0.f);
    T x_vals[kNChunks][VecSize] = {init_value};

    load_input(
        x_now, x_vals, dim);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id0: %d\n", i);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif

    // Same three-stage transform pipeline as moe_fast_hardamard_kernel.
    hadamard_mult_thread(x_vals);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif
    hadamard_mult_warp(x_vals);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id2: %d\n", i);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif
    if constexpr (kNWarps > 1) {
      // First let kNWarps consecutive threads pick up the data held by the
      // other warps (relayout through shared memory).
      exchange_smem_pre(x_vals, smem_exchange);
      // Cross-warp butterfly stage on the exchanged layout.
      hadamard_mult_warp(x_vals);
      // Swap the data back to its original layout.
      exchange_smem_pre(x_vals, smem_exchange);
    }
    if constexpr (kNChunks > 1) {
      if constexpr (kNChunks == 28) {
        hadamard_mult_thread_28_transpose(x_vals);
      } else if constexpr (kNChunks == 36) {
        hadamard_mult_thread_36_transpose(x_vals);
      } else {
        constexpr int kLogNChunks = cilog2(kNChunks);
        static_assert(1 << kLogNChunks == kNChunks,
                      "kNChunks must be a power of 2");
hadamard_mult_thread_transpose(x_vals); + } + } + if (quant_scales) { + float quant_scale = quant_scales[expert_id]; + if (shift) { + smooth_quant_store_output(out_now, + shift, + smooth, + x_vals, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + dim); + } else { + quant_store_output( + out_now, + x_vals, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + dim); + } + } else { + store_output( + out_now, x_vals, dim); + } + } +} + +template +void MoeFastHardamardImplWrapper(const T *x, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const T *shift, + const T *smooth, + const float *quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + OutT *out, + cudaStream_t stream) { + using nv_type = typename nv_type_traits::type; + using out_type = typename nv_type_traits::type; + constexpr int kNBytes = sizeof(T); + constexpr int N = 1 << kLogN; // pad + constexpr int kSmemSize = std::min(N * kNBytes, 32 * 1024); + constexpr int kRounds = N * kNBytes / kSmemSize; + constexpr int kChunksPerSmemSize = kSmemSize / (kThreads * VecSize * kNBytes); + VLOG(1) << "real_dim: " << dim << ", N: " << N; + VLOG(1) << "kNChunks: " << kNChunks; + VLOG(1) << "kNBytes: " << kNBytes; + VLOG(1) << "kSmemSize: " << kSmemSize; + VLOG(1) << "kRounds: " << kRounds; + VLOG(1) << "kChunksPerSmemSize: " << kChunksPerSmemSize; + const int dev_id = 0; + int sm_count; + int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + if (used_in_ep_low_latency) { + auto masked_kernel = + masked_moe_fast_hardamard_kernel; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, masked_kernel, kThreads, kSmemSize); + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + dim3 grid; + grid.x = 
min(static_cast(num_blocks_per_wave), token_num); + if constexpr (UseDiagonalBlockMatrix) { + grid.y = ceil(dim / (kThreads * VecSize)); + } + masked_kernel<<>>( + reinterpret_cast(x), + recv_expert_count, + reinterpret_cast(shift), + reinterpret_cast(smooth), + quant_scales, + quant_round_type, + quant_max_bound, + quant_min_bound, + token_num, + dim, + num_max_tokens_per_expert, + reinterpret_cast(out)); + } else { + auto kernel = moe_fast_hardamard_kernel; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, kernel, kThreads, kSmemSize); + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + dim3 grid; + grid.x = min(static_cast(num_blocks_per_wave), token_num); + if constexpr (UseDiagonalBlockMatrix) { + grid.y = ceil(dim / (kThreads * VecSize)); + } + kernel<<>>( + reinterpret_cast(x), + expert_idx_per_token, + reinterpret_cast(shift), + reinterpret_cast(smooth), + quant_scales, + quant_round_type, + quant_max_bound, + quant_min_bound, + token_num, + dim, + reinterpret_cast(out)); + } +} diff --git a/custom_ops/gpu_ops/moe/moe_fast_hardamard_impl_common.h b/custom_ops/gpu_ops/moe/moe_fast_hardamard_impl_common.h new file mode 100644 index 0000000000..d937a252be --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_fast_hardamard_impl_common.h @@ -0,0 +1,164 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 

#pragma once

// NOTE(review): the original #include targets were lost in extraction
// (angle-bracket contents stripped); restore them from upstream.
#include
#include
#include "helper.h"

// Participation mask for warp-wide shuffle intrinsics: all 32 lanes.
#define FULL_MASK 0xffffffff

// 32-byte vector type (two uint4) used as the widest load/store unit.
// NOTE(review): the name `uint8` is easy to confuse with uint8_t.
struct uint8 {
  uint4 u;
  uint4 v;
};

// Maps a byte count to an unsigned storage type of exactly that size,
// used to issue vectorized loads/stores via reinterpret_cast.
template
struct BytesToType {};

template <>
struct BytesToType<32> {
  using Type = uint8;
  static_assert(sizeof(Type) == 32);
};

template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};

// Maps framework scalar types to the native CUDA types used inside kernels;
// identity for types that are already native.
// NOTE(review): the specialization keys were stripped in extraction —
// presumably the paddle float16 / bfloat16 / int8 dtypes; restore upstream.
template
struct nv_type_traits {
  using type = T;
};

template <>
struct nv_type_traits {
  using type = half;
};

template <>
struct nv_type_traits {
  using type = __nv_bfloat16;
};

template <>
struct nv_type_traits {
  using type = int8_t;
};

// Dispatches a runtime logN in {7..10} (diagonal-block / "SP" path) to a
// constexpr kLogN visible inside the macro body __VA_ARGS__.
#define DISPATCH_SP_logN(logN, kLogN, ...)                              \
  if (logN == 10) {                                                     \
    constexpr int kLogN = 10;                                           \
    __VA_ARGS__                                                         \
  } else if (logN == 9) {                                               \
    constexpr int kLogN = 9;                                            \
    __VA_ARGS__                                                         \
  } else if (logN == 8) {                                               \
    constexpr int kLogN = 8;                                            \
    __VA_ARGS__                                                         \
  } else if (logN == 7) {                                               \
    constexpr int kLogN = 7;                                            \
    __VA_ARGS__                                                         \
  } else {                                                              \
    PADDLE_THROW(                                                       \
        phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \
  }

// Dispatches a runtime vector size in {1,2,4,8,16} to a constexpr VEC_SIZE.
#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...)                           \
  if (vec_size == 16) {                                                   \
    constexpr int VEC_SIZE = 16;                                          \
    __VA_ARGS__                                                           \
  } else if (vec_size == 8) {                                             \
    constexpr int VEC_SIZE = 8;                                           \
    __VA_ARGS__                                                           \
  } else if (vec_size == 4) {                                             \
    constexpr int VEC_SIZE = 4;                                           \
    __VA_ARGS__                                                           \
  } else if (vec_size == 2) {                                             \
    constexpr int VEC_SIZE = 2;                                           \
    __VA_ARGS__                                                           \
  } else if (vec_size == 1) {                                             \
    constexpr int VEC_SIZE = 1;                                           \
    __VA_ARGS__                                                           \
  } else {                                                                \
    PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", \
                                            vec_size));                  \
  }

// Dispatches a runtime logN in {11..14} (full-dimension path) to constexpr.
#define DISPATCH_logN(logN, kLogN, ...)                              \
  if (logN == 11) {                                                  \
    constexpr int kLogN = 11;                                        \
    __VA_ARGS__                                                      \
  } else if (logN == 12) {                                           \
    constexpr int kLogN = 12;                                        \
    __VA_ARGS__                                                      \
  } else if (logN == 13) {                                           \
    constexpr int kLogN = 13;                                        \
    __VA_ARGS__                                                      \
  } else if (logN == 14) {                                           \
    constexpr int kLogN = 14;                                        \
    __VA_ARGS__                                                      \
  } else {                                                           \
    PADDLE_THROW(phi::errors::Unimplemented("unsupported logN"));    \
  }

// Forward declaration of the templated launch wrapper defined in
// moe_fast_hardamard_impl.cu (explicit instantiations live there).
// NOTE(review): stripped `template` head — must match the definition's
// reconstructed parameter list.
template
void MoeFastHardamardImplWrapper(const T *x,
                                 const int64_t *expert_idx_per_token,
                                 const int64_t *recv_expert_count,
                                 const T *shift,
                                 const T *smooth,
                                 const float *quant_scales,
                                 const int quant_round_type,
                                 const float quant_max_bound,
                                 const float quant_min_bound,
                                 const int64_t token_num,
                                 const int64_t dim,
                                 const int num_max_tokens_per_expert,
                                 bool used_in_ep_low_latency,
                                 OutT *out,
                                 cudaStream_t stream);
diff --git a/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.cu b/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.cu
new file mode 100644
index 0000000000..02302b0a00
--- /dev/null
+++ b/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.cu
@@ -0,0 +1,230 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "helper.h" +#include "moe_fast_hardamard_impl_common.h" + +template +void MoeFastHardamardWrapper(const T *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const T *shift, + const T *smooth, + const float *quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + OutT *out, + cudaStream_t &stream) { + bool FLAGS_hardamard_use_diagonal_block_matrix = true; + + constexpr int kThreads = 128; + if (FLAGS_hardamard_use_diagonal_block_matrix) { + const int VecSize = hadamard_block_size / kThreads; + const int logN = int(ceil(std::log2(kThreads * VecSize))); + constexpr int kNChunks = 1; + DISPATCH_SP_VS(VecSize, VEC_SIZE, {DISPATCH_SP_logN(logN, kLogN, { + MoeFastHardamardImplWrapper( + x_data, + expert_idx_per_token, + recv_expert_count, + shift, + smooth, + quant_scales, + quant_round_type, + quant_max_bound, + quant_min_bound, + token_num, + dim, + num_max_tokens_per_expert, + used_in_ep_low_latency, + out, + stream); + })}); + } else { + if (!((dim / 28) & (dim / 28 - 1))) { + VLOG(1) << "28 * 2^n"; + const int logN = int(ceil(std::log2(dim / 28))); + constexpr int kNChunks = 28; + DISPATCH_SP_logN(logN, kLogN, { + constexpr int VecSize = (1 << kLogN) / kThreads; + MoeFastHardamardImplWrapper(x_data, + expert_idx_per_token, + recv_expert_count, + shift, + smooth, + 
quant_scales, + quant_round_type, + quant_max_bound, + quant_min_bound, + token_num, + dim, + num_max_tokens_per_expert, + used_in_ep_low_latency, + out, + stream); + }); + } else if (!((dim / 36) & (dim / 36 - 1))) { + VLOG(1) << "36 * 2^n"; + const int logN = int(ceil(std::log2(dim / 36))); + constexpr int kNChunks = 36; + DISPATCH_SP_logN(logN, kLogN, { + constexpr int VecSize = (1 << kLogN) / kThreads; + MoeFastHardamardImplWrapper(x_data, + expert_idx_per_token, + recv_expert_count, + shift, + smooth, + quant_scales, + quant_round_type, + quant_max_bound, + quant_min_bound, + token_num, + dim, + num_max_tokens_per_expert, + used_in_ep_low_latency, + out, + stream); + }); + } else { + VLOG(1) << "2^n"; + const int logN = int(ceil(std::log2(dim))); + constexpr int VecSize = 16 / sizeof(T); + DISPATCH_logN(logN, kLogN, { + constexpr int kNChunks = (1 << kLogN) / (kThreads * VecSize); + MoeFastHardamardImplWrapper(x_data, + expert_idx_per_token, + recv_expert_count, + shift, + smooth, + quant_scales, + quant_round_type, + quant_max_bound, + quant_min_bound, + token_num, + dim, + num_max_tokens_per_expert, + used_in_ep_low_latency, + out, + stream); + }); + } + } +} + +template void MoeFastHardamardWrapper( + const phi::dtype::float16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::float16 *shift, + const phi::dtype::float16 *smooth, + const float *quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + phi::dtype::float16 *out, + cudaStream_t &stream); + +template void MoeFastHardamardWrapper( + const phi::dtype::float16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::float16 *shift, + const phi::dtype::float16 *smooth, + const float *quant_scales, + 
const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + int8_t *out, + cudaStream_t &stream); + +template void +MoeFastHardamardWrapper( + const phi::dtype::bfloat16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::bfloat16 *shift, + const phi::dtype::bfloat16 *smooth, + const float *quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + phi::dtype::bfloat16 *out, + cudaStream_t &stream); + +template void MoeFastHardamardWrapper( + const phi::dtype::bfloat16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::bfloat16 *shift, + const phi::dtype::bfloat16 *smooth, + const float *quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + int8_t *out, + cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.h b/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.h new file mode 100644 index 0000000000..d1fe9d8825 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "helper.h" + +template +void MoeFastHardamardWrapper(const T *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const T *shift, + const T *smooth, + const float *quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + OutT *out, + cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu index de0256375f..ce7a604c5c 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -18,10 +18,10 @@ #include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h" #include "group_swiglu_with_masked.h" #include "helper.h" -#include "moe/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" -#include "w4afp8_gemm/w4afp8_gemm.h" +#include "moe/moe_fast_hardamard_kernel.h" #include "swigluoai.h" +#include "w4afp8_gemm/w4afp8_gemm.h" template void MoeFFNKernel(const paddle::Tensor& permute_input, @@ -39,367 +39,402 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, const int estimate_total_token_nums, const int hadamard_block_size, const std::string& activation) { - using namespace phi; - typedef PDTraits traits_; - typedef typename traits_::DataType DataType_; - typedef typename traits_::data_t data_t; - auto quant_mode = cutlass::epilogue::QuantMode::PerChannelQuant; + using namespace phi; + typedef PDTraits traits_; 
+ typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + auto quant_mode = cutlass::epilogue::QuantMode::PerChannelQuant; - auto ffn_out_data = ffn_out.data(); - auto place = permute_input.place(); - auto stream = permute_input.stream(); + auto ffn_out_data = ffn_out.data(); + auto place = permute_input.place(); + auto stream = permute_input.stream(); - auto fp16_moe_gemm_runner = MoeGemmRunner>(); - auto int8_moe_gemm_runner = MoeGemmRunner>(); - auto int4_moe_gemm_runner = MoeGemmRunner>(); - auto w4a8_moe_gemm_runner = W4A8MoeGemmRunner(); + auto fp16_moe_gemm_runner = MoeGemmRunner< + DataType_, + cutlass::WintQuantTraits>(); + auto int8_moe_gemm_runner = MoeGemmRunner< + DataType_, + cutlass::WintQuantTraits>(); + auto int4_moe_gemm_runner = MoeGemmRunner< + DataType_, + cutlass::WintQuantTraits>(); + auto w4a8_moe_gemm_runner = + W4A8MoeGemmRunner(); - assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2); + assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2); - const int num_experts = up_gate_proj_weight.dims()[0]; - const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1]; + const int num_experts = up_gate_proj_weight.dims()[0]; + const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1]; - assert(up_gate_proj_weight.dims().size() == 3); - int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size; + assert(up_gate_proj_weight.dims().size() == 3); + int inter_dim = up_gate_proj_weight.dims()[1] * + up_gate_proj_weight.dims()[2] / hidden_size; - constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k - Allocator* allocator = paddle::GetAllocator(place); - Allocator::AllocationPtr workspace; - if (quant_method == "weight_only_int4" || quant_method == "w4a8" || quant_method == "w4afp8") { - inter_dim = inter_dim * 2; - } - if (quant_method == "w4a8" || quant_method == "w4afp8") { - 
workspace = allocator->Allocate( - SizeOf(paddle::DataType::INT8) * workspace_size); - } + constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k + Allocator* allocator = paddle::GetAllocator(place); + Allocator::AllocationPtr workspace; + if (quant_method == "weight_only_int4" || quant_method == "w4a8" || + quant_method == "w4afp8") { + inter_dim = inter_dim * 2; + } + if (quant_method == "w4a8" || quant_method == "w4afp8") { + workspace = + allocator->Allocate(SizeOf(paddle::DataType::INT8) * workspace_size); + } - const int64_t inter_size = inter_dim; + const int64_t inter_size = inter_dim; + typedef PDTraits traits_fp8; + typedef typename traits_fp8::DataType DataType_fp8; + typedef typename traits_fp8::data_t data_t_fp8; + + int num_experts_ = num_experts; + int num_max_tokens_per_expert = 256; + int expanded_active_expert_rows; + + paddle::Tensor fc1_out_tensor; + + if (permute_input.dims().size() == 3) { + num_experts_ = permute_input.dims()[0]; + assert(num_experts == num_experts_); + + num_max_tokens_per_expert = permute_input.dims()[1]; + expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert; + fc1_out_tensor = GetEmptyTensor( + {num_experts_, num_max_tokens_per_expert, inter_size}, T, place); + } else { + expanded_active_expert_rows = permute_input.dims()[0]; + fc1_out_tensor = + GetEmptyTensor({expanded_active_expert_rows, inter_size}, T, place); + } + + auto fc1_out = fc1_out_tensor.data(); + + using NvType = typename traits_::DataType; + + auto fc1_expert_biases = + up_gate_proj_bias + ? const_cast(up_gate_proj_bias.get_ptr()) + ->data() + : nullptr; + + // This is a trick. + // expanded_active_expert_rows is not needed in variable group gemm. + // but is needed in accommodating deepep low latency mode + const int64_t total_rows_in_ll_else_minus1 = + used_in_ep_low_latency ? expanded_active_expert_rows : -1; + + // When we tune the optimal configuration, we need the actual total_rows. 
+ const int64_t tune_total_rows = expanded_active_expert_rows; + + if (quant_method == "weight_only_int8") { + typename cutlass::WintQuantTraits< + DataType_, + cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args; + int8_moe_gemm_runner.moe_gemm_bias_act( + reinterpret_cast(permute_input.data()), + reinterpret_cast(up_gate_proj_weight.data()), + reinterpret_cast( + const_cast(up_gate_proj_scale.get_ptr()) + ->data()), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_out), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + tune_total_rows, + inter_size, + hidden_size, + num_experts, + quant_args, + "none", + stream); + } else if (quant_method == "weight_only_int4") { + typename cutlass::WintQuantTraits< + DataType_, + cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args; + int4_moe_gemm_runner.moe_gemm_bias_act( + reinterpret_cast(permute_input.data()), + reinterpret_cast( + up_gate_proj_weight.data()), + reinterpret_cast( + const_cast(up_gate_proj_scale.get_ptr()) + ->data()), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_out), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + tune_total_rows, + inter_size, + hidden_size, + num_experts, + quant_args, + "none", + stream); + } else if (quant_method == "w4a8") { + w4a8_moe_gemm_runner.moe_gemm( + reinterpret_cast(permute_input.data()), + reinterpret_cast( + up_gate_proj_weight.data()), + quant_mode, + reinterpret_cast( + const_cast(up_gate_proj_scale.get_ptr()) + ->data()), + nullptr, // up_gate_proj_scale_dyquant + nullptr, // nf4_look_up_table + reinterpret_cast(fc1_out), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + used_in_ep_low_latency ? 
estimate_total_token_nums : tune_total_rows, + inter_size, + hidden_size, + reinterpret_cast(workspace->ptr()), + workspace_size, + num_experts, + stream); + } else if (quant_method == "w4afp8") { typedef PDTraits traits_fp8; typedef typename traits_fp8::DataType DataType_fp8; typedef typename traits_fp8::data_t data_t_fp8; - int num_experts_ = num_experts; - int num_max_tokens_per_expert = 256; - int expanded_active_expert_rows; + Allocator::AllocationPtr ffn1_input_row_sum; + ffn1_input_row_sum = + allocator->Allocate(sizeof(float) * expanded_active_expert_rows); - paddle::Tensor fc1_out_tensor; + compute_row_sum( + permute_input.data(), + expanded_active_expert_rows, + hidden_size, + reinterpret_cast(ffn1_input_row_sum->ptr()), + const_cast(tokens_expert_prefix_sum.data()), + num_max_tokens_per_expert, + used_in_ep_low_latency, + stream); - if (permute_input.dims().size() == 3) { - num_experts_ = permute_input.dims()[0]; - assert(num_experts == num_experts_); + float* row_scale = nullptr; + DisPatchW4AFp8GemmWrapper( + reinterpret_cast(permute_input.data()), + reinterpret_cast( + up_gate_proj_weight.data()), + const_cast(tokens_expert_prefix_sum.data()), + reinterpret_cast(ffn1_input_row_sum->ptr()), + row_scale, + const_cast(up_gate_proj_scale.get_ptr()) + ->data(), + reinterpret_cast(fc1_out), + used_in_ep_low_latency ? num_max_tokens_per_expert : 0, + used_in_ep_low_latency ? 
num_max_tokens_per_expert + : permute_input.dims()[0], + num_experts, + inter_size, + hidden_size, + stream); + } else { + typename cutlass::WintQuantTraits< + DataType_, + cutlass::WintQuantMethod::kNone>::Arguments quant_args; + fp16_moe_gemm_runner.moe_gemm_bias_act( + reinterpret_cast(permute_input.data()), + reinterpret_cast(up_gate_proj_weight.data()), + nullptr, + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_out), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + tune_total_rows, + inter_size, + hidden_size, + num_experts, + quant_args, + "none", + stream); + } - num_max_tokens_per_expert = permute_input.dims()[1]; - expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert; - fc1_out_tensor = GetEmptyTensor( - {num_experts_, num_max_tokens_per_expert, inter_size}, T, place); + paddle::Tensor act_out_tensor; + if (used_in_ep_low_latency) { + act_out_tensor = + GroupSwigluWithMasked(fc1_out_tensor, tokens_expert_prefix_sum); + } else { + if (activation == "swigluoai") { + act_out_tensor = SwigluOAI(fc1_out_tensor, 1.702, 7.0, "interleave"); } else { - expanded_active_expert_rows = permute_input.dims()[0]; - fc1_out_tensor = GetEmptyTensor( - {expanded_active_expert_rows, inter_size}, T, place); + act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr); } + } - auto fc1_out = fc1_out_tensor.data(); - - using NvType = typename traits_::DataType; - - auto fc1_expert_biases = - up_gate_proj_bias - ? const_cast(up_gate_proj_bias.get_ptr())->data() - : nullptr; - - // This is a trick. - // expanded_active_expert_rows is not needed in variable group gemm. - // but is needed in accommodating deepep low latency mode - const int64_t total_rows_in_ll_else_minus1 = used_in_ep_low_latency ? expanded_active_expert_rows : -1; - - // When we tune the optimal configuration, we need the actual total_rows. 
- const int64_t tune_total_rows = expanded_active_expert_rows; - - if (quant_method == "weight_only_int8") { - typename cutlass::WintQuantTraits::Arguments quant_args; - int8_moe_gemm_runner.moe_gemm_bias_act( - reinterpret_cast(permute_input.data()), - reinterpret_cast(up_gate_proj_weight.data()), - reinterpret_cast( - const_cast(up_gate_proj_scale.get_ptr()) - ->data()), - reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - tune_total_rows, - inter_size, - hidden_size, - num_experts, - quant_args, - "none", - stream); - } else if (quant_method == "weight_only_int4") { - typename cutlass::WintQuantTraits::Arguments quant_args; - int4_moe_gemm_runner.moe_gemm_bias_act( - reinterpret_cast(permute_input.data()), - reinterpret_cast( - up_gate_proj_weight.data()), - reinterpret_cast( - const_cast(up_gate_proj_scale.get_ptr()) - ->data()), - reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - tune_total_rows, - inter_size, - hidden_size, - num_experts, - quant_args, - "none", - stream); - } else if (quant_method == "w4a8") { - w4a8_moe_gemm_runner.moe_gemm( - reinterpret_cast(permute_input.data()), - reinterpret_cast( - up_gate_proj_weight.data()), - quant_mode, - reinterpret_cast( - const_cast(up_gate_proj_scale.get_ptr()) - ->data()), - nullptr, // up_gate_proj_scale_dyquant - nullptr, // nf4_look_up_table - reinterpret_cast(fc1_out), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - used_in_ep_low_latency ? 
estimate_total_token_nums : tune_total_rows, - inter_size, - hidden_size, - reinterpret_cast(workspace->ptr()), - workspace_size, - num_experts, - stream); - } else if (quant_method == "w4afp8") { - typedef PDTraits traits_fp8; - typedef typename traits_fp8::DataType DataType_fp8; - typedef typename traits_fp8::data_t data_t_fp8; - - Allocator::AllocationPtr ffn1_input_row_sum; - ffn1_input_row_sum = allocator->Allocate( - sizeof(float) * expanded_active_expert_rows); - - compute_row_sum( - permute_input.data(), - expanded_active_expert_rows, - hidden_size, - reinterpret_cast(ffn1_input_row_sum->ptr()), - const_cast(tokens_expert_prefix_sum.data()), - num_max_tokens_per_expert, - used_in_ep_low_latency, - stream); - - - float* row_scale = nullptr; - DisPatchW4AFp8GemmWrapper( - reinterpret_cast(permute_input.data()), - reinterpret_cast(up_gate_proj_weight.data()), - const_cast(tokens_expert_prefix_sum.data()), - reinterpret_cast(ffn1_input_row_sum->ptr()), - row_scale, - const_cast(up_gate_proj_scale.get_ptr()) - ->data(), - reinterpret_cast(fc1_out), - used_in_ep_low_latency ? num_max_tokens_per_expert : 0, - used_in_ep_low_latency ? 
num_max_tokens_per_expert : permute_input.dims()[0], - num_experts, - inter_size, - hidden_size, - stream); - } else { - typename cutlass::WintQuantTraits::Arguments quant_args; - fp16_moe_gemm_runner.moe_gemm_bias_act( - reinterpret_cast(permute_input.data()), - reinterpret_cast(up_gate_proj_weight.data()), - nullptr, - reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - tune_total_rows, - inter_size, - hidden_size, - num_experts, - quant_args, - "none", - stream); - } - - paddle::Tensor act_out_tensor; - if (used_in_ep_low_latency) { - act_out_tensor = GroupSwigluWithMasked(fc1_out_tensor, tokens_expert_prefix_sum); - } else { - if (activation == "swigluoai") { - act_out_tensor = SwigluOAI(fc1_out_tensor, 1.702, 7.0, "interleave"); - } else { - act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr); - } - } - - auto act_out = act_out_tensor.data(); - if (quant_method == "weight_only_int8") { - typename cutlass::WintQuantTraits::Arguments quant_args; - int8_moe_gemm_runner.moe_gemm( - reinterpret_cast(act_out), - reinterpret_cast(down_proj_weight.data()), - reinterpret_cast( - const_cast(down_proj_scale.get_ptr()) - ->data()), - reinterpret_cast(ffn_out_data), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - tune_total_rows, - hidden_size, - inter_size / 2, - num_experts, - quant_args, - stream); - - } else if (quant_method == "weight_only_int4") { - typename cutlass::WintQuantTraits::Arguments quant_args; - int4_moe_gemm_runner.moe_gemm( - reinterpret_cast(act_out), - reinterpret_cast( - down_proj_weight.data()), - reinterpret_cast( - const_cast(down_proj_scale.get_ptr()) - ->data()), - reinterpret_cast(ffn_out_data), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - tune_total_rows, - hidden_size, - inter_size / 2, - num_experts, - quant_args, - stream); - } else if (quant_method == "w4a8") 
{ - data_t *down_proj_shift = nullptr; - data_t *down_proj_smooth = nullptr; - Allocator::AllocationPtr int8_act_out; - int8_act_out = allocator->Allocate( - SizeOf(paddle::DataType::INT8) * act_out_tensor.numel()); - MoeFastHardamardWrapper( - act_out_tensor.data(), - expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, - const_cast(tokens_expert_prefix_sum.data()), - down_proj_shift, // down_proj_shift->data(), - down_proj_smooth, // down_proj_smooth->data(), - down_proj_in_scale ? const_cast(down_proj_in_scale.get_ptr())->data() : nullptr, - 1, - 127.0, - -127.0, - expanded_active_expert_rows, - inter_size / 2, - num_max_tokens_per_expert, - used_in_ep_low_latency, - hadamard_block_size, - reinterpret_cast(int8_act_out->ptr()), - stream - ); - w4a8_moe_gemm_runner.moe_gemm( - reinterpret_cast(int8_act_out->ptr()), - reinterpret_cast( - down_proj_weight.data()), - quant_mode, - reinterpret_cast( - const_cast(down_proj_scale.get_ptr()) - ->data()), - nullptr, // down_proj_scale_dyquant - nullptr, // reinterpret_cast(d_nf4_look_up_table), // nf4_look_up_table - reinterpret_cast(ffn_out_data), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows, - hidden_size, - inter_size / 2, - reinterpret_cast(workspace->ptr()), - workspace_size, - num_experts, - stream); - } else if (quant_method == "w4afp8") { - data_t *ffn2_shift = nullptr; - data_t *ffn2_smooth = nullptr; - float* row_scale = nullptr; - Allocator::AllocationPtr fp8_act_out; - fp8_act_out = allocator->Allocate( - SizeOf(paddle::DataType::INT8) * act_out_tensor.numel()); - Allocator::AllocationPtr ffn2_input_row_sum; - ffn2_input_row_sum = allocator->Allocate( - sizeof(float) * expanded_active_expert_rows); - - // note(yuanxiaolan): optimize this - MoeFastHardamardWrapper( - act_out_tensor.data(), - expert_idx_per_token ? 
expert_idx_per_token.get().data() : nullptr, - const_cast(tokens_expert_prefix_sum.data()), - ffn2_shift, // ffn2_shift->data(), - ffn2_smooth, // ffn2_smooth->data(), - nullptr, - 1, - 448.0f, - -448.0f, - expanded_active_expert_rows, - inter_size / 2, - num_max_tokens_per_expert, - used_in_ep_low_latency, - hadamard_block_size, - act_out_tensor.data(), - stream - ); - - quantize_moe_input(act_out_tensor.data(), - expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, - down_proj_in_scale ? const_cast(down_proj_in_scale.get_ptr())->data() : nullptr, - 448.0f, - -448.0f, - expanded_active_expert_rows, - inter_size / 2, - reinterpret_cast(ffn2_input_row_sum->ptr()), - const_cast(tokens_expert_prefix_sum.data()), - num_max_tokens_per_expert, - used_in_ep_low_latency, - reinterpret_cast(fp8_act_out->ptr()), - stream - ); - - DisPatchW4AFp8GemmWrapper( - reinterpret_cast(fp8_act_out->ptr()), - reinterpret_cast(down_proj_weight.data()), - const_cast(tokens_expert_prefix_sum.data()), - reinterpret_cast(ffn2_input_row_sum->ptr()), - row_scale, + auto act_out = act_out_tensor.data(); + if (quant_method == "weight_only_int8") { + typename cutlass::WintQuantTraits< + DataType_, + cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args; + int8_moe_gemm_runner.moe_gemm( + reinterpret_cast(act_out), + reinterpret_cast(down_proj_weight.data()), + reinterpret_cast( const_cast(down_proj_scale.get_ptr()) - ->data(), - reinterpret_cast(ffn_out_data), - used_in_ep_low_latency ? num_max_tokens_per_expert : 0, - used_in_ep_low_latency ? 
num_max_tokens_per_expert : act_out_tensor.dims()[0], - num_experts, - hidden_size, - inter_size / 2, - stream); - } else { - typename cutlass::WintQuantTraits::Arguments quant_args; - fp16_moe_gemm_runner.moe_gemm( - reinterpret_cast(act_out), - reinterpret_cast(down_proj_weight.data()), - nullptr, - reinterpret_cast(ffn_out_data), - const_cast(tokens_expert_prefix_sum.data()), - total_rows_in_ll_else_minus1, - tune_total_rows, - hidden_size, - inter_size / 2, - num_experts, - quant_args, - stream); - } + ->data()), + reinterpret_cast(ffn_out_data), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + tune_total_rows, + hidden_size, + inter_size / 2, + num_experts, + quant_args, + stream); + + } else if (quant_method == "weight_only_int4") { + typename cutlass::WintQuantTraits< + DataType_, + cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments quant_args; + int4_moe_gemm_runner.moe_gemm( + reinterpret_cast(act_out), + reinterpret_cast( + down_proj_weight.data()), + reinterpret_cast( + const_cast(down_proj_scale.get_ptr()) + ->data()), + reinterpret_cast(ffn_out_data), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + tune_total_rows, + hidden_size, + inter_size / 2, + num_experts, + quant_args, + stream); + } else if (quant_method == "w4a8") { + data_t* down_proj_shift = nullptr; + data_t* down_proj_smooth = nullptr; + Allocator::AllocationPtr int8_act_out; + int8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) * + act_out_tensor.numel()); + MoeFastHardamardWrapper( + act_out_tensor.data(), + expert_idx_per_token ? expert_idx_per_token.get().data() + : nullptr, + const_cast(tokens_expert_prefix_sum.data()), + down_proj_shift, // down_proj_shift->data(), + down_proj_smooth, // down_proj_smooth->data(), + down_proj_in_scale + ? 
const_cast(down_proj_in_scale.get_ptr()) + ->data() + : nullptr, + 1, + 127.0, + -127.0, + expanded_active_expert_rows, + inter_size / 2, + num_max_tokens_per_expert, + used_in_ep_low_latency, + hadamard_block_size, + reinterpret_cast(int8_act_out->ptr()), + stream); + w4a8_moe_gemm_runner.moe_gemm( + reinterpret_cast(int8_act_out->ptr()), + reinterpret_cast( + down_proj_weight.data()), + quant_mode, + reinterpret_cast( + const_cast(down_proj_scale.get_ptr()) + ->data()), + nullptr, // down_proj_scale_dyquant + nullptr, // reinterpret_cast(d_nf4_look_up_table), // + // nf4_look_up_table + reinterpret_cast(ffn_out_data), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows, + hidden_size, + inter_size / 2, + reinterpret_cast(workspace->ptr()), + workspace_size, + num_experts, + stream); + } else if (quant_method == "w4afp8") { + data_t* ffn2_shift = nullptr; + data_t* ffn2_smooth = nullptr; + float* row_scale = nullptr; + Allocator::AllocationPtr fp8_act_out; + fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) * + act_out_tensor.numel()); + Allocator::AllocationPtr ffn2_input_row_sum; + ffn2_input_row_sum = + allocator->Allocate(sizeof(float) * expanded_active_expert_rows); + + // note(yuanxiaolan): optimize this + MoeFastHardamardWrapper( + act_out_tensor.data(), + expert_idx_per_token ? expert_idx_per_token.get().data() + : nullptr, + const_cast(tokens_expert_prefix_sum.data()), + ffn2_shift, // ffn2_shift->data(), + ffn2_smooth, // ffn2_smooth->data(), + nullptr, + 1, + 448.0f, + -448.0f, + expanded_active_expert_rows, + inter_size / 2, + num_max_tokens_per_expert, + used_in_ep_low_latency, + hadamard_block_size, + act_out_tensor.data(), + stream); + + quantize_moe_input( + act_out_tensor.data(), + expert_idx_per_token ? expert_idx_per_token.get().data() + : nullptr, + down_proj_in_scale + ? 
const_cast(down_proj_in_scale.get_ptr()) + ->data() + : nullptr, + 448.0f, + -448.0f, + expanded_active_expert_rows, + inter_size / 2, + reinterpret_cast(ffn2_input_row_sum->ptr()), + const_cast(tokens_expert_prefix_sum.data()), + num_max_tokens_per_expert, + used_in_ep_low_latency, + reinterpret_cast(fp8_act_out->ptr()), + stream); + + DisPatchW4AFp8GemmWrapper( + reinterpret_cast(fp8_act_out->ptr()), + reinterpret_cast(down_proj_weight.data()), + const_cast(tokens_expert_prefix_sum.data()), + reinterpret_cast(ffn2_input_row_sum->ptr()), + row_scale, + const_cast(down_proj_scale.get_ptr())->data(), + reinterpret_cast(ffn_out_data), + used_in_ep_low_latency ? num_max_tokens_per_expert : 0, + used_in_ep_low_latency ? num_max_tokens_per_expert + : act_out_tensor.dims()[0], + num_experts, + hidden_size, + inter_size / 2, + stream); + } else { + typename cutlass::WintQuantTraits< + DataType_, + cutlass::WintQuantMethod::kNone>::Arguments quant_args; + fp16_moe_gemm_runner.moe_gemm( + reinterpret_cast(act_out), + reinterpret_cast(down_proj_weight.data()), + nullptr, + reinterpret_cast(ffn_out_data), + const_cast(tokens_expert_prefix_sum.data()), + total_rows_in_ll_else_minus1, + tune_total_rows, + hidden_size, + inter_size / 2, + num_experts, + quant_args, + stream); + } } paddle::Tensor MoeExpertFFNFunc( @@ -414,55 +449,56 @@ paddle::Tensor MoeExpertFFNFunc( const paddle::optional& expert_idx_per_token, const std::string& quant_method, const bool used_in_ep_low_latency, - const int estimate_total_token_nums, const int hadamard_block_size, + const int estimate_total_token_nums, + const int hadamard_block_size, const std::string& activation) { - -const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() : - (quant_method == "w4afp8") ? 
paddle::DataType::BFLOAT16 : - permute_input.dtype(); - auto ffn_out = paddle::empty_like(permute_input, t_type); - if(permute_input.numel() == 0){ - return ffn_out; - } - switch (t_type) { - case paddle::DataType::BFLOAT16: - MoeFFNKernel(permute_input, - tokens_expert_prefix_sum, - up_gate_proj_weight, - down_proj_weight, - up_gate_proj_bias, - up_gate_proj_scale, - down_proj_scale, - down_proj_in_scale, - expert_idx_per_token, - quant_method, - ffn_out, - used_in_ep_low_latency, - estimate_total_token_nums, - hadamard_block_size, - activation); - break; - case paddle::DataType::FLOAT16: - MoeFFNKernel(permute_input, - tokens_expert_prefix_sum, - up_gate_proj_weight, - down_proj_weight, - up_gate_proj_bias, - up_gate_proj_scale, - down_proj_scale, - down_proj_in_scale, - expert_idx_per_token, - quant_method, - ffn_out, - used_in_ep_low_latency, - estimate_total_token_nums, - hadamard_block_size, - activation); - break; - default: - PD_THROW("Unsupported data type for MoeExpertFFN"); - } + const auto t_type = (quant_method == "w4a8") + ? up_gate_proj_scale.get().dtype() + : (quant_method == "w4afp8") ? 
paddle::DataType::BFLOAT16 + : permute_input.dtype(); + auto ffn_out = paddle::empty_like(permute_input, t_type); + if (permute_input.numel() == 0) { return ffn_out; + } + switch (t_type) { + case paddle::DataType::BFLOAT16: + MoeFFNKernel(permute_input, + tokens_expert_prefix_sum, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + down_proj_in_scale, + expert_idx_per_token, + quant_method, + ffn_out, + used_in_ep_low_latency, + estimate_total_token_nums, + hadamard_block_size, + activation); + break; + case paddle::DataType::FLOAT16: + MoeFFNKernel(permute_input, + tokens_expert_prefix_sum, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + down_proj_in_scale, + expert_idx_per_token, + quant_method, + ffn_out, + used_in_ep_low_latency, + estimate_total_token_nums, + hadamard_block_size, + activation); + break; + default: + PD_THROW("Unsupported data type for MoeExpertFFN"); + } + return ffn_out; } std::vector MoeExpertFFN( @@ -475,24 +511,25 @@ std::vector MoeExpertFFN( const paddle::optional& down_proj_scale, const paddle::optional& down_proj_in_scale, const paddle::optional& expert_idx_per_token, - const std::string& quant_method, const bool used_in_ep_low_latency, + const std::string& quant_method, + const bool used_in_ep_low_latency, const int estimate_total_token_nums, const int hadamard_block_size, const std::string& activation) { - return {MoeExpertFFNFunc(permute_input, - tokens_expert_prefix_sum, - up_gate_proj_weight, - down_proj_weight, - up_gate_proj_bias, - up_gate_proj_scale, - down_proj_scale, - down_proj_in_scale, - expert_idx_per_token, - quant_method, - used_in_ep_low_latency, - estimate_total_token_nums, - hadamard_block_size, - activation)}; + return {MoeExpertFFNFunc(permute_input, + tokens_expert_prefix_sum, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + down_proj_in_scale, 
+ expert_idx_per_token, + quant_method, + used_in_ep_low_latency, + estimate_total_token_nums, + hadamard_block_size, + activation)}; } std::vector> MoeExpertFFNInferShape( @@ -510,21 +547,23 @@ std::vector> MoeExpertFFNInferShape( const int estimate_total_token_nums, const int hadamard_block_size, const std::string& activation) { - return {permute_input_shape}; + return {permute_input_shape}; } std::vector MoeExpertFFNInferDtype( - const paddle::DataType &permute_input_dtype, - const paddle::DataType &tokens_expert_prefix_sum_dtype, - const paddle::DataType &up_gate_proj_weight_dtype, - const paddle::DataType &down_proj_weight_dtype, - const paddle::optional &up_gate_proj_bias_dtype, - const paddle::optional &up_gate_proj_scale_dtype, - const paddle::optional &down_proj_scale_dtype, - const paddle::optional &down_proj_in_scale_dtype, - const std::string &quant_method, const bool used_in_ep_low_latency, - const int estimate_total_token_nums, const int hadamard_block_size, - const std::string &activation) { + const paddle::DataType& permute_input_dtype, + const paddle::DataType& tokens_expert_prefix_sum_dtype, + const paddle::DataType& up_gate_proj_weight_dtype, + const paddle::DataType& down_proj_weight_dtype, + const paddle::optional& up_gate_proj_bias_dtype, + const paddle::optional& up_gate_proj_scale_dtype, + const paddle::optional& down_proj_scale_dtype, + const paddle::optional& down_proj_in_scale_dtype, + const std::string& quant_method, + const bool used_in_ep_low_latency, + const int estimate_total_token_nums, + const int hadamard_block_size, + const std::string& activation) { if (quant_method == "w4a8" || quant_method == "w4afp8") { return {up_gate_proj_scale_dtype.get()}; } else { @@ -540,15 +579,15 @@ std::vector MoeExpertFFNInferDtype( * 2. SwiGLU activation function * 3. Second linear transformation (down_proj) with optional quantization * - * Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization. 
+ * Supports multiple quantization methods including weight-only int4/int8 and + * w4a8 quantization. * * Inputs: * - permute_input: Permuted input tensor organized by expert * Shape: [total_tokens * top_k, hidden_size] * dtype: bfloat16/float16 (or int8 for w4a8) - * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm - * Shape: [num_experts] - * dtype: int64 + * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for + * group_gemm Shape: [num_experts] dtype: int64 * - up_gate_proj_weight: First FFN layer weights * Shape: [num_experts, inter_size * 2, hidden_size] * dtype: Same as input (unquantized) or int8 (quantized) @@ -564,8 +603,8 @@ std::vector MoeExpertFFNInferDtype( * - down_proj_scale: Quantization scales for second FFN layer * Shape: [num_experts, hidden_size] * dtype: Same as input - * - down_proj_in_scale: Optional input scales for second FFN layer (w4a8 only) - * dtype: float32 + * - down_proj_in_scale: Optional input scales for second FFN layer (w4a8 + * only) dtype: float32 * - expert_idx_per_token: Optional expert indices per token (w4a8 only) * Shape: [total_tokens] * dtype: int64 @@ -577,7 +616,8 @@ std::vector MoeExpertFFNInferDtype( * * Attributes: * - quant_method: Quantization method to use - * Options: "none", "weight_only_int4", "weight_only_int8", "w4a8" + * Options: "none", "weight_only_int4", "weight_only_int8", + * "w4a8" * - used_in_ep_low_latency: Whether running in low latency mode * Affects activation function implementation * - estimate_total_token_nums: estimate total token numbers @@ -598,7 +638,11 @@ PD_BUILD_STATIC_OP(moe_expert_ffn) paddle::Optional("down_proj_in_scale"), paddle::Optional("expert_idx_per_token")}) .Outputs({"output_tensor"}) - .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int", "activation:std::string"}) + .Attrs({"quant_method:std::string", + "used_in_ep_low_latency:bool", + 
"estimate_total_token_nums:int", + "hadamard_block_size:int", + "activation:std::string"}) .SetKernelFn(PD_KERNEL(MoeExpertFFN)) .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype)); diff --git a/custom_ops/gpu_ops/moe/template_config.json b/custom_ops/gpu_ops/moe/template_config.json new file mode 100644 index 0000000000..9e2302110e --- /dev/null +++ b/custom_ops/gpu_ops/moe/template_config.json @@ -0,0 +1,26 @@ +{ + "moe_fast_hardamard_impl": { + "name": "moe_fast_hardamard_impl", + "function_name": "MoeFastHardamardImplWrapper", + "impl_file": "moe_fast_hardamard_impl.cuh", + "template_params": [ + "T", + "OutT", + "kLogN", + "VecSize", + "kNChunks", + "kThreads", + "UseDiagonalBlockMatrix" + ], + "dispatch_params": {}, + "data_types": [ + ["phi::dtype::float16", "phi::dtype::float16", "float16_float16"], + ["phi::dtype::float16", "int8_t", "float16_int8"], + ["phi::dtype::bfloat16", "phi::dtype::bfloat16", "bfloat16_bfloat16"], + ["phi::dtype::bfloat16", "int8_t", "bfloat16_int8"] + ], + "max_instances_per_file": 16, + "file_prefix": "moe_fast_hardamard_impl_", + "function_signature": "template void {function_name}{template_args}(\n const T *x,\n const int64_t *expert_idx_per_token,\n const int64_t *recv_expert_count,\n const T *shift,\n const T *smooth,\n const float* quant_scales,\n const int quant_round_type,\n const float quant_max_bound,\n const float quant_min_bound,\n const int64_t token_num,\n const int64_t dim,\n const int num_max_tokens_per_expert,\n bool used_in_ep_low_latency,\n OutT* out,\n cudaStream_t stream);\n\n" + } +} diff --git a/custom_ops/gpu_ops/multi_head_latent_attention.cu b/custom_ops/gpu_ops/multi_head_latent_attention.cu index 6e804f3ebd..a08e7c4e87 100644 --- a/custom_ops/gpu_ops/multi_head_latent_attention.cu +++ b/custom_ops/gpu_ops/multi_head_latent_attention.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "append_attn/multi_head_latent_attention_kernel.h" +#include "append_attn/decoder_mla_attention_kernel.h" #include "helper.h" #include "mla_attn/batch_mla_with_paged_kv_cache.h" @@ -66,10 +66,12 @@ std::vector MultiHeadLatentAttentionKernel( // int chunk_size = decoder_chunk_size_cpu.data()[0]; // - const bool mla_use_tensorcore = true; //get_mla_use_tensorcore(); + const bool mla_use_tensorcore = true; // get_mla_use_tensorcore(); auto sm_version = GetSMVersion(); if ((speculate_decoder || mla_use_tensorcore) && sm_version < 90) { - PD_THROW("Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm < 90."); + PD_THROW( + "Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm " + "< 90."); } auto main_stream = query.stream(); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 863163f303..c82dfe411e 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -381,7 +381,9 @@ elif paddle.is_compiled_with_cuda(): if cc >= 80: # append_attention - os.system("python gpu_ops/append_attn/autogen_template_instantiation.py") + os.system( + "python utils/auto_gen_template_instantiation.py --config gpu_ops/append_attn/template_config.json --output gpu_ops/append_attn/template_instantiation/autogen" + ) sources += ["gpu_ops/append_attention.cu"] sources += find_end_files("gpu_ops/append_attn", ".cu") # mla @@ -394,6 +396,9 @@ elif paddle.is_compiled_with_cuda(): nvcc_compile_args += ["-DENABLE_BF16"] # moe os.system("python gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py") + os.system( + "python utils/auto_gen_template_instantiation.py --config gpu_ops/moe/template_config.json 
--output gpu_ops/moe/template_instantiation/autogen" + ) sources += find_end_files("gpu_ops/cutlass_kernels/moe_gemm/", ".cu") sources += find_end_files("gpu_ops/cutlass_kernels/w4a8_moe/", ".cu") sources += find_end_files("gpu_ops/moe/", ".cu") diff --git a/custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py b/custom_ops/utils/auto_gen_template_instantiation.py similarity index 73% rename from custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py rename to custom_ops/utils/auto_gen_template_instantiation.py index 1a2f27a878..4288afbb4d 100644 --- a/custom_ops/gpu_ops/append_attn/autogen_template_instantiation.py +++ b/custom_ops/utils/auto_gen_template_instantiation.py @@ -15,6 +15,7 @@ import argparse import json +import shutil from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -65,6 +66,10 @@ class UniversalTemplateInstantiator: f"Configuration '{config.name}' has T or OutT in template_params but no data_types configured" ) + # Skip validation for special handled functions + if config.name == "moe_fast_hardamard_impl": + return + special_params = {"T", "OutT", "NUM_WARP_Q"} for param_name in config.template_params: if param_name not in special_params and param_name not in config.dispatch_params: @@ -112,10 +117,20 @@ class UniversalTemplateInstantiator: return f"<{', '.join(template_args_parts)}>" - def _generate_function_signature(self, config: TemplateConfig, template_args: str) -> str: + def _generate_function_signature( + self, config: TemplateConfig, template_args: str, t_in: str = "", t_out: str = "" + ) -> str: """Generate function signature.""" if config.function_signature: - return config.function_signature.format(function_name=config.function_name, template_args=template_args) + signature = config.function_signature.format( + function_name=config.function_name, template_args=template_args + ) + # Replace T and OutT with actual types if provided + if t_in: + signature 
= signature.replace("const T *", f"const {t_in} *") + if t_out: + signature = signature.replace("OutT*", f"{t_out}*") + return signature else: raise ValueError(f"Function signature not found for {config.name}") @@ -133,25 +148,73 @@ class UniversalTemplateInstantiator: ) -> str: """Generate template instantiation.""" template_args = self._build_template_args(config, t_in, t_out, params) - return self._generate_function_signature(config, template_args) + return self._generate_function_signature(config, template_args, t_in, t_out) + + def _clean_output_directory(self, output_dir: str): + """Clean output directory before generating new files.""" + output_path = Path(output_dir) + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) def generate_combinations_for_type(self, config: TemplateConfig, t_in: str, t_out: str) -> List[Dict[str, Any]]: """Generate parameter combinations for specific type.""" combinations = [] - def _generate_recursive( - params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str] - ): - if not param_names: - combinations.append(current_params.copy()) - return + if config.name == "moe_fast_hardamard_impl": + combinations = self._generate_moe_hardamard_combinations(config, t_in, t_out) + else: - param_name = param_names[0] - for value in params_dict[param_name]: - current_params[param_name] = value - _generate_recursive(params_dict, current_params, param_names[1:]) + def _generate_recursive( + params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str] + ): + if not param_names: + combinations.append(current_params.copy()) + return + + param_name = param_names[0] + for value in params_dict[param_name]: + current_params[param_name] = value + _generate_recursive(params_dict, current_params, param_names[1:]) + + _generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys())) + + return combinations + + def 
_generate_moe_hardamard_combinations( + self, config: TemplateConfig, t_in: str, t_out: str + ) -> List[Dict[str, Any]]: + """Generate combinations for MoeFastHardamardImplWrapper based on code logic.""" + combinations = [] + + for vec_size in [1, 2, 4, 8, 16]: + for log_n in [7, 8, 9, 10]: + combinations.append( + {"kLogN": log_n, "VecSize": vec_size, "kNChunks": 1, "kThreads": 128, "UseDiagonalBlockMatrix": 1} + ) + + for log_n in [7, 8, 9, 10]: + vec_size = (1 << log_n) // 128 + combinations.append( + {"kLogN": log_n, "VecSize": vec_size, "kNChunks": 28, "kThreads": 128, "UseDiagonalBlockMatrix": 0} + ) + combinations.append( + {"kLogN": log_n, "VecSize": vec_size, "kNChunks": 36, "kThreads": 128, "UseDiagonalBlockMatrix": 0} + ) + + for log_n in [11, 12, 13, 14]: + vec_size = 8 + n_chunks = (1 << log_n) // (128 * vec_size) + combinations.append( + { + "kLogN": log_n, + "VecSize": vec_size, + "kNChunks": n_chunks, + "kThreads": 128, + "UseDiagonalBlockMatrix": 0, + } + ) - _generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys())) return combinations def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]: @@ -186,7 +249,7 @@ class UniversalTemplateInstantiator: config = self.configs[function_name] output_path = Path(output_dir) - output_path.mkdir(exist_ok=True) + output_path.mkdir(parents=True, exist_ok=True) if not config.data_types: data_types = [("", "", "")] @@ -206,6 +269,7 @@ class UniversalTemplateInstantiator: def generate_all(self, output_dir: str): """Generate all configured function types.""" + self._clean_output_directory(output_dir) for function_name in self.configs.keys(): print(f"Generating template instantiations for {function_name}...") self.generate_for_function_type(function_name, output_dir) @@ -219,14 +283,12 @@ def main(): "--config", "-c", type=str, - default="gpu_ops/append_attn/template_config.json", help="Configuration file path (JSON format)", ) 
parser.add_argument( "--output", "-o", type=str, - default="gpu_ops/append_attn/template_instantiation/autogen", help="Output directory", )