diff --git a/.clang-format b/.clang-format index a4de8e7be8..4bb0ebed18 100644 --- a/.clang-format +++ b/.clang-format @@ -26,4 +26,5 @@ BinPackParameters: false BinPackArguments: false IncludeBlocks: Preserve IncludeIsMainSourceRegex: (\.cu)$ +SortIncludes: false ... diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eeb50d4201..1f5791e398 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,7 @@ exclude: | (?x)^( - dockerfiles/.+ + dockerfiles/.+| + custom_ops/third_party/.+ )$ default_install_hook_types: - pre-commit diff --git a/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh b/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh index 4ff4a02298..8f7b096e6b 100644 --- a/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh @@ -42,13 +42,10 @@ struct softmax_state_t { } } - __device__ __forceinline__ softmax_state_t() { - init(); - } + __device__ __forceinline__ softmax_state_t() { init(); } - __device__ __forceinline__ void merge(const AlignedVector& other_o, - T other_m, - T other_d) { + __device__ __forceinline__ void merge( + const AlignedVector& other_o, T other_m, T other_d) { // using kType = typename cascade_attn_nv_type2_traits::type; T m_prev = m, d_prev = d; m = m_prev > other_m ? m_prev : other_m; @@ -63,13 +60,11 @@ struct softmax_state_t { } __device__ __forceinline__ void normalize() { - #pragma unroll for (size_t i = 0; i < vec_size; ++i) { o[i] /= d; } } - }; template @@ -102,65 +97,79 @@ struct softmax_state_ts { } } - __device__ __forceinline__ softmax_state_ts() { - init(); - } + __device__ __forceinline__ softmax_state_ts() { init(); } __device__ __forceinline__ void normalize(const uint32_t tile_id) { - #pragma unroll for (size_t i = 0; i < vec_size; i++) { o[tile_id][i] /= d; } } - }; -template -__device__ __forceinline__ void produce_kv(CacheT *smem, - CacheT *kv_base_gptr, - const int * block_table_smem, - const uint32_t seq_offset_gmem, - const uint32_t seq_offset_smem, - const uint32_t kv_head_idx, - const uint32_t kv_num_heads, - const uint32_t tidx, - const uint32_t chunk_start, - const uint32_t chunk_end) { +template +__device__ __forceinline__ void produce_kv(CacheT* smem, + CacheT* kv_base_gptr, + const int* block_table_smem, + const uint32_t seq_offset_gmem, + const uint32_t seq_offset_smem, + const uint32_t kv_head_idx, + const uint32_t kv_num_heads, + const uint32_t tidx, + const uint32_t chunk_start, + const uint32_t chunk_end) { int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]); if (block_id < 0) { block_id = 0; } const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE; // 8/16 T/int8 each time - const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK; + const uint32_t k_offset_base = + ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * + HEAD_DIM_QK; const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK; - for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { + for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>( - smem + smem_offset_base + vid * CACHE_VEC_SIZE, - kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE, - seq_offset_gmem < chunk_end - ); + smem + smem_offset_base + vid * CACHE_VEC_SIZE, + kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE, + seq_offset_gmem < chunk_end); } } -template -__device__ __forceinline__ void compute_qk(const T* cu_q_smem, - const CacheT* k_smem, - const uint32_t kv_idx_base, - const uint32_t stage_idx, - const uint32_t iter_base, - const uint32_t iter_bound, - const uint32_t tidx, - const uint32_t gid, - const float scale, - float *s, - softmax_state_ts& st) { +template +__device__ __forceinline__ void compute_qk( + const T* cu_q_smem, + const CacheT* k_smem, + const uint32_t kv_idx_base, + const uint32_t stage_idx, + const uint32_t iter_base, + const uint32_t iter_bound, + const uint32_t tidx, + const uint32_t gid, + const float scale, + float* s, + softmax_state_ts& st) { const CacheT* smem; AlignedVector q_vec; AlignedVector k_vec; float m_prev = st.m; - // smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM; + // smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * + // HEAD_DIM; smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM; #pragma unroll for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) { @@ -171,7 +180,7 @@ __device__ __forceinline__ void compute_qk(const T* cu_q_smem, s[j] = 0.f; } #pragma unroll - for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { + for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { Load(cu_q_smem + vid * vec_size, &q_vec); Load(smem + j * HEAD_DIM + vid * vec_size, &k_vec); for (uint32_t i = 0; i < vec_size; ++i) { @@ -211,20 +220,29 @@ __device__ __forceinline__ void compute_qk(const T* cu_q_smem, } } -template -__device__ __forceinline__ void compute_sv(const float *s, - const CacheT *base_v_smem, - const uint32_t stage_idx, - const uint32_t iter_base, - const uint32_t iter_bound, - const uint32_t tidx, - softmax_state_ts& st) { +template +__device__ __forceinline__ void compute_sv( + const float* s, + const CacheT* base_v_smem, + const uint32_t stage_idx, + const uint32_t iter_base, + const uint32_t iter_bound, + const uint32_t tidx, + softmax_state_ts& st) { const CacheT* v_smem; AlignedVector v_vec; #pragma unroll for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) { - v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK; - for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { + v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + + j * HEAD_DIM_QK; + for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { Load(v_smem + vid * vec_size, &v_vec); uint32_t tile_id = vid / bdx; #pragma unroll diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h index 8872761956..5095ef1062 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h @@ -41,4 +41,5 @@ void DecoderWriteCacheWithRoPEKernel( paddle::Tensor* key_cache_out, paddle::Tensor* value_cache_out, const paddle::optional& q_norm_weight, - const paddle::optional& k_norm_weight, const float rms_norm_eps); + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h index 5ea9ce213f..2349a0e5e1 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h @@ -56,46 +56,53 @@ void EncoderWriteCacheWithRopeKernel( auto head_dim = meta_data.head_dims; bool is_scale_channel_wise = false; int rotary_dim = head_dim; - if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) { + if (cache_k_scale && + cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) { is_scale_channel_wise = true; } - if (rotary_embs){ - rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2; - if(rotary_dim < head_dim){ - if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){ + if (rotary_embs) { + rotary_dim = + rotary_embs.get().dims()[rotary_embs.get().dims().size() - 1] * 2; + if (rotary_dim < head_dim) { + if (!use_neox_style || q_norm_weight || k_norm_weight || + num_heads == kv_num_heads || is_scale_channel_wise) { PADDLE_THROW(phi::errors::Fatal( - "partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false.")); + "partial_rotary_factor < 1.0 only supports " + "use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, " + "GQA and is_scale_channel_wise=false.")); } } } if (q_norm_weight && k_norm_weight) { - if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) { + if (num_heads != kv_num_heads && !is_scale_channel_wise && + !use_neox_style) { gqa_rotary_qk_norm_variable( - qkv_out->data(), - qkv.data(), - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? qkv_biases.get().data() : nullptr, - rotary_embs.get().data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - token_num, - num_heads, - kv_num_heads, - max_seq_len, - rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2], - head_dim, - stream, - use_neox_style, - rope_3d, - q_norm_weight ? q_norm_weight.get().data() : nullptr, - k_norm_weight ? k_norm_weight.get().data() : nullptr, - rms_norm_eps); + qkv_out->data(), + qkv.data(), + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style, + rope_3d, + q_norm_weight ? q_norm_weight.get().data() : nullptr, + k_norm_weight ? k_norm_weight.get().data() : nullptr, + rms_norm_eps); } else { PD_THROW( - "gqa_rotary_qk_norm_variable only support gqa mode. channel wise scale and neox style are not supported"); + "gqa_rotary_qk_norm_variable only support gqa mode. channel wise " + "scale and neox style are not supported"); } } else { if (num_heads == kv_num_heads) { @@ -120,49 +127,48 @@ void EncoderWriteCacheWithRopeKernel( } else { if (!is_scale_channel_wise) { gqa_rotary_qk_variable( - qkv_out->data(), - qkv.data(), - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? qkv_biases.get().data() : nullptr, - rotary_embs.get().data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - token_num, - num_heads, - kv_num_heads, - max_seq_len, - rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2], - head_dim, - rotary_dim, - stream, - use_neox_style, - rope_3d); + qkv_out->data(), + qkv.data(), + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2], + head_dim, + rotary_dim, + stream, + use_neox_style, + rope_3d); } else { gqa_rotary_qk_quant_variable( - qkv_out->data(), - qkv.data(), - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? qkv_biases.get().data() : nullptr, - cache_k_scale ? cache_k_scale.get().data() : nullptr, - cache_v_scale ? cache_v_scale.get().data() : nullptr, - rotary_embs.get().data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - token_num, - num_heads, - kv_num_heads, - max_seq_len, - rotary_embs.get().dims()[2], - head_dim, - stream, - use_neox_style, - rope_3d); + qkv_out->data(), + qkv.data(), + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + cache_k_scale ? cache_k_scale.get().data() : nullptr, + cache_v_scale ? cache_v_scale.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style, + rope_3d); } - } } const uint32_t block_size = meta_data.block_size; @@ -178,7 +184,9 @@ void EncoderWriteCacheWithRopeKernel( stream, key_cache_out, value_cache_out); - } else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") { + } else if (cache_quant_type_str == "cache_int8" or + cache_quant_type_str == "cache_fp8" or + cache_quant_type_str == "block_wise_fp8") { DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { CascadeAppendWriteCacheKVC8QKV( @@ -234,23 +242,29 @@ void EncoderWriteCacheWithRopeKernel( "cache_int4_zp]"); } - const char* fmt_write_cache_completed_signal_str = std::getenv("FLAGS_fmt_write_cache_completed_signal"); - const char* FLAGS_use_pd_disaggregation_per_chunk = std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); + const char* fmt_write_cache_completed_signal_str = + std::getenv("FLAGS_fmt_write_cache_completed_signal"); + const char* FLAGS_use_pd_disaggregation_per_chunk = + std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); if (fmt_write_cache_completed_signal_str && (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 || std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) { - if (FLAGS_use_pd_disaggregation_per_chunk && - (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || - std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { - cudaLaunchHostFunc(qkv.stream(), - &(RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query), - (void*)nullptr); - } else { - if (kv_signal_data) { - cudaLaunchHostFunc(qkv.stream(), - &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, - (void*)(const_cast(kv_signal_data.get().data()))); - } + if (FLAGS_use_pd_disaggregation_per_chunk && + (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || + std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { + cudaLaunchHostFunc( + qkv.stream(), + &(RemoteCacheKvIpc:: + save_cache_kv_complete_signal_layerwise_per_query), + (void*)nullptr); + } else { + if (kv_signal_data) { + cudaLaunchHostFunc( + qkv.stream(), + &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, + (void*)(const_cast( + kv_signal_data.get().data()))); } + } } } diff --git a/custom_ops/gpu_ops/append_attn/mem_util.cuh b/custom_ops/gpu_ops/append_attn/mem_util.cuh index fb735be7ae..25c7a62388 100644 --- a/custom_ops/gpu_ops/append_attn/mem_util.cuh +++ b/custom_ops/gpu_ops/append_attn/mem_util.cuh @@ -66,10 +66,10 @@ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { #ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); } else { memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); } #else if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { @@ -100,19 +100,23 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, int src_in_bytes = predicate ? 16 : 0; if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); } else { memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); } } else { if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { if (predicate) { - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); } } else { if (predicate) { - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); } } } @@ -169,10 +173,11 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr, if constexpr (fill_mode == SharedMemFillMode::kFillZero) { int src_in_bytes = predicate ? 8 : 0; memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8); - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); } else { if (predicate) { - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 8); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 8); } } #else @@ -207,10 +212,11 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr, if constexpr (fill_mode == SharedMemFillMode::kFillZero) { int src_in_bytes = predicate ? 4 : 0; memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4); - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); } else { if (predicate) { - memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 4); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 4); } } #else @@ -275,7 +281,6 @@ struct smem_t { template __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) {} - template static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) { diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh index 2efcb7a8c6..ec5b428bda 100644 --- a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh @@ -20,10 +20,10 @@ template __global__ void decode_absorb_cache_kernel( const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 - const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 - T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, - // nope_size] - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 + T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // nope_size] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -62,26 +62,25 @@ __global__ void decode_absorb_cache_kernel( const int block_idx = block_table_now[write_seq_id / block_size]; const int block_offset = write_seq_id % block_size; - if (bias < nope_hidden_size) { // pe + if (bias < nope_hidden_size) { // pe const uint32_t inner_bias = bias; const uint32_t hi = inner_bias / nope_size; const uint32_t h_bias = inner_bias % nope_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + h_bias; - const uint32_t ori_idx = - start_token_idx * nope_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + h_bias; + const uint32_t ori_idx = start_token_idx * nope_hidden_size + inner_bias; Load(&kv_nope[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } else { const uint32_t inner_bias = bias - nope_hidden_size; const uint32_t hi = inner_bias / pe_size; const uint32_t h_bias = inner_bias % pe_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + nope_size + h_bias; - const uint32_t ori_idx = - start_token_idx * pe_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + nope_size + + h_bias; + const uint32_t ori_idx = start_token_idx * pe_hidden_size + inner_bias; Load(&kv_pe[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } @@ -91,10 +90,10 @@ __global__ void decode_absorb_cache_kernel( template __global__ void speculate_decode_absorb_cache_kernel( const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 - const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 - T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, - // nope_size] - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 + T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // nope_size] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -125,8 +124,7 @@ __global__ void speculate_decode_absorb_cache_kernel( if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % hidden_size; const int start_token_idx = cu_seqlens_q[ori_bi]; - const int write_seq_id = - seq_lens[ori_bi] + token_id - start_token_idx; + const int write_seq_id = seq_lens[ori_bi] + token_id - start_token_idx; if (write_seq_id == 0) continue; const int* block_table_now = nullptr; @@ -145,26 +143,25 @@ __global__ void speculate_decode_absorb_cache_kernel( token_id, cu_seqlens_q[ori_bi]); } - if (bias < nope_hidden_size) { // pe + if (bias < nope_hidden_size) { // pe const uint32_t inner_bias = bias; const uint32_t hi = inner_bias / nope_size; const uint32_t h_bias = inner_bias % nope_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + h_bias; - const uint32_t ori_idx = - token_id * nope_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + h_bias; + const uint32_t ori_idx = token_id * nope_hidden_size + inner_bias; Load(&kv_nope[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } else { const uint32_t inner_bias = bias - nope_hidden_size; const uint32_t hi = inner_bias / pe_size; const uint32_t h_bias = inner_bias % pe_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + nope_size + h_bias; - const uint32_t ori_idx = - token_id * pe_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + nope_size + + h_bias; + const uint32_t ori_idx = token_id * pe_hidden_size + inner_bias; Load(&kv_pe[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } @@ -174,10 +171,10 @@ __global__ void speculate_decode_absorb_cache_kernel( template __global__ void prefill_absorb_cache_kernel( const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 - const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 - T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, - // nope_size] - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 + T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // nope_size] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -206,33 +203,33 @@ __global__ void prefill_absorb_cache_kernel( const uint32_t bias = linear_index % hidden_size; const uint32_t ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; - const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const uint32_t ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int* block_table_now = nullptr; block_table_now = block_tables + ori_bi * max_blocks_per_seq; const uint32_t block_idx = block_table_now[ori_seq_id / block_size]; const uint32_t block_offset = ori_seq_id % block_size; - if (bias < nope_hidden_size) { // pe + if (bias < nope_hidden_size) { // pe const uint32_t inner_bias = bias; const uint32_t hi = inner_bias / nope_size; const uint32_t h_bias = inner_bias % nope_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + h_bias; - const uint32_t ori_idx = - token_idx * nope_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + h_bias; + const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias; Load(&kv_nope[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } else { const uint32_t inner_bias = bias - nope_hidden_size; const uint32_t hi = inner_bias / pe_size; const uint32_t h_bias = inner_bias % pe_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + nope_size + h_bias; - const uint32_t ori_idx = - token_idx * pe_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + nope_size + + h_bias; + const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias; Load(&kv_pe[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } diff --git a/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_impl.cuh index 759aa7fbd6..2eb6d6bde5 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_impl.cuh @@ -16,37 +16,40 @@ #include "decode_attention_func.cuh" #include "multiquery_decoder_attention_kernel.h" -#define CHECK(call) \ -do \ -{ \ - const cudaError_t error_code = call; \ - if (error_code != cudaSuccess) \ - { \ - printf("CUDA Error:\n"); \ - printf(" File: %s\n", __FILE__); \ - printf(" Line %d:\n", __LINE__); \ - printf(" Error code:%d\n", error_code); \ - printf(" Error text:%s\n", cudaGetErrorString(error_code)); \ - exit(1); \ - } \ -}while(0) +#define CHECK(call) \ + do { \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) { \ + printf("CUDA Error:\n"); \ + printf(" File: %s\n", __FILE__); \ + printf(" Line %d:\n", __LINE__); \ + printf(" Error code:%d\n", error_code); \ + printf(" Error text:%s\n", cudaGetErrorString(error_code)); \ + exit(1); \ + } \ + } while (0) -template -__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim] - const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads] - const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads] - const int * __restrict__ seq_lens_q, - const int * __restrict__ seq_lens_kv, - const int * __restrict__ cu_seqlens_q, - const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - OutT * __restrict__ out, // [token_num, num_heads, head_dim] - const float in_scale, - const int num_chunks, - const int chunk_size, - const int max_seq_len, - const int num_heads, - const int head_dim) { +template +__global__ void merge_varlen_multi_chunks_v2_kernel( + const T *__restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim] + const T *__restrict__ multi_m, // [bsz, num_chunks, num_heads] + const T *__restrict__ multi_d, // [bsz, num_chunks, num_heads] + const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ cu_seqlens_q, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + OutT *__restrict__ out, // [token_num, num_heads, head_dim] + const float in_scale, + const int num_chunks, + const int chunk_size, + const int max_seq_len, + const int num_heads, + const int head_dim) { const int vid = threadIdx.x, ty = threadIdx.y; const int qid = blockIdx.x, hid = blockIdx.y; const int seq_len_q = seq_lens_q[qid]; @@ -68,12 +71,12 @@ __global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < vec_size / 2; ++i) { - *((half2*)(&res_vec) + i) = make_half2(0, 0); + *((half2 *)(&res_vec) + i) = make_half2(0, 0); } } else if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < vec_size / 2; ++i) { - *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); + *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0); } } T m; @@ -92,7 +95,8 @@ __global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi const T m_now = multi_m[offset]; const T d_now = multi_d[offset]; m = m_prev > m_now ? m_prev : m_now; - offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size; + offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + + vid * vec_size; Load(&multi_out[offset], &load_vec); const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m); d = d * scale1 + d_now * scale2; @@ -124,30 +128,47 @@ __global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi for (int i = 0; i < vec_size; ++i) { out_vec[i] = static_cast(st.o[i]); } - Store(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]); + Store( + out_vec, + &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]); } -template -__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim] - CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim] - CacheT * __restrict__ cache_v, - const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int * __restrict__ seq_lens_q, - const int * __restrict__ seq_lens_kv, - const int * __restrict__ cu_seqlens_q, - const int * __restrict__ block_table, // [bsz, block_num_per_seq] - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float in_scale, - const uint32_t chunk_size, - T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim] - T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads] - T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads] - OutT * __restrict__ out) { +template +__global__ void multi_query_decode_attention_kernel( + T *__restrict__ q, // [token_num, num_heads, head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float in_scale, + const uint32_t chunk_size, + T *__restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, + // head_dim] + T *__restrict__ tmp_m, // [batch_size, num_chunks, num_heads] + T *__restrict__ tmp_d, // [batch_size, num_chunks, num_heads] + OutT *__restrict__ out) { const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t bid = bidx, gid = threadIdx.y; const uint32_t tidx = threadIdx.x; @@ -167,9 +188,9 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke if (q_len <= 0) { return; } - uint32_t kv_len = seq_lens_kv[bid]; // !!!!!!!! + uint32_t kv_len = seq_lens_kv[bid]; // !!!!!!!! if (kv_len <= 0) { - return; + return; } kv_len += q_len; const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size); @@ -180,23 +201,24 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke } const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0; - const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; const uint32_t chunk_len = chunk_end - chunk_start; extern __shared__ uint8_t smem[]; const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK; - T *q_smem = reinterpret_cast(smem); // [HEAD_DIM_QK * sizeof(T)] + T *q_smem = reinterpret_cast(smem); // [HEAD_DIM_QK * sizeof(T)] T *cu_q_smem = q_smem + gid * HEAD_DIM_QK; #pragma unroll - for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { - ((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0]; - + for (uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { + ((float4 *)(&cu_q_smem[vid * VEC_SIZE]))[0] = + ((float4 *)(&q_now[vid * VEC_SIZE]))[0]; } __syncthreads(); using VecT = AlignedVector; VecT q_vec; #pragma unroll - for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { + for (uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { Load(cu_q_smem + vid * VEC_SIZE, &q_vec); for (uint32_t i = 0; i < VEC_SIZE; ++i) { q_vec[i] *= scale; @@ -204,8 +226,8 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke Store(q_vec, cu_q_smem + vid * VEC_SIZE); } - - CacheT *kv_smem = reinterpret_cast(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT)); + CacheT *kv_smem = reinterpret_cast(smem + GROUP_SIZE * HEAD_DIM_QK * + sizeof(CacheT)); uint32_t stage_idx = 0; constexpr int loop_times = DEAL_EACH_TIME / bdy; #pragma unroll @@ -214,24 +236,27 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke for (int j = 0; j < loop_times; ++j) { const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid; const uint32_t k_seq_id = chunk_start + k_seq_offset; - produce_kv( - kv_smem, - cache_k, - block_table_now, - k_seq_id, - k_seq_offset, - kv_head_idx, - kv_num_heads, - tidx, - chunk_start, - chunk_end - ); + produce_kv(kv_smem, + cache_k, + block_table_now, + k_seq_id, + k_seq_offset, + kv_head_idx, + kv_num_heads, + tidx, + chunk_start, + chunk_end); } commit_group(); stage_idx = (stage_idx + 1) % NUM_STAGES; } - softmax_state_ts st; float s[DEAL_EACH_TIME]; @@ -240,48 +265,55 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke wait_group(); __syncthreads(); // compute qk - compute_qk( - cu_q_smem, - kv_smem, - chunk_start + iter * DEAL_EACH_TIME, - stage_idx, - iter * DEAL_EACH_TIME, - chunk_len, - tidx, - gid, - scale, - s, - st - ); + compute_qk(cu_q_smem, + kv_smem, + chunk_start + iter * DEAL_EACH_TIME, + stage_idx, + iter * DEAL_EACH_TIME, + chunk_len, + tidx, + gid, + scale, + s, + st); __syncthreads(); // compute sv - compute_sv( - s, - kv_smem, - stage_idx, - iter * DEAL_EACH_TIME, - chunk_len, - tidx, - st - ); + compute_sv( + s, kv_smem, stage_idx, iter * DEAL_EACH_TIME, chunk_len, tidx, st); __syncthreads(); #pragma unroll for (int j = 0; j < loop_times; ++j) { const uint32_t k_seq_offset = j * bdy + gid; - produce_kv( - kv_smem, - cache_k, - block_table_now, - chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME, - stage_idx * DEAL_EACH_TIME + k_seq_offset, - kv_head_idx, - kv_num_heads, - tidx, - chunk_start, - chunk_end - ); + produce_kv( + kv_smem, + cache_k, + block_table_now, + chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME, + stage_idx * DEAL_EACH_TIME + k_seq_offset, + kv_head_idx, + kv_num_heads, + tidx, + chunk_start, + chunk_end); } commit_group(); stage_idx = (stage_idx + 1) % NUM_STAGES; @@ -290,45 +322,59 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke __syncthreads(); // normize if not partition_kv - for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) { + for (uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) { const uint32_t tile_id = vid / bdx; if (!partition_kv || num_chunk_this_seq == 1) { st.normalize(tile_id); } if (partition_kv && num_chunk_this_seq > 1) { - const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx; - Store(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE); + const uint32_t head_idx = + (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx; + Store( + st.o[tile_id], + tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE); tmp_m[head_idx] = st.m; tmp_d[head_idx] = st.d; } else { - Store(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE); + Store( + st.o[tile_id], + out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + + vid * VEC_SIZE); } } } - -template +template void MultiQueryDecoderAttention( - const AppendAttnMetaData& meta_data, - cudaStream_t &stream, - const paddle::Tensor &q, - const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim] - const paddle::Tensor &cache_v, // [num_kv_heads, head_dim] - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - const int max_seq_len, - const int max_dec_len, - const float rope_scale, - const float rope_theta, - const float softmax_scale, - const float in_scale, - paddle::Tensor *out) { + const AppendAttnMetaData &meta_data, + cudaStream_t &stream, + const paddle::Tensor &q, + const paddle::Tensor + &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim] + const paddle::Tensor &cache_v, // [num_kv_heads, head_dim] + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const int max_seq_len, + const int max_dec_len, + const float rope_scale, + const float rope_theta, + const float softmax_scale, + const float in_scale, + paddle::Tensor *out) { using NV_TYPE = typename cascade_attn_type_traits::type; auto num_heads = meta_data.q_num_heads; @@ -338,8 +384,8 @@ void MultiQueryDecoderAttention( auto max_block_num_per_seq = meta_data.max_blocks_per_seq; constexpr int num_stages = NUM_STAGE; - constexpr int vec_size = 16 / sizeof(T); // 8 16 32 - constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32 + constexpr int vec_size = 16 / sizeof(T); // 8 16 32 + constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32 constexpr int blockxc = HEAD_DIM_QK / cache_vec_size; constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size; constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32; @@ -349,12 +395,25 @@ void MultiQueryDecoderAttention( constexpr int num_threads = blockx * blocky; - auto splitkv_kernel = multi_query_decode_attention_kernel; + auto splitkv_kernel = multi_query_decode_attention_kernel; uint32_t cache_smem_bytes = 0; const T *shift_bias_ptr = shift_bias ? shift_bias.get().data() : nullptr; - const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data() : nullptr; + const T *smooth_weight_ptr = + smooth_weight ? smooth_weight.get().data() : nullptr; cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T); const uint32_t chunk_size = get_max_partition_size(bsz); @@ -363,51 +422,64 @@ void MultiQueryDecoderAttention( if (smem_size >= 48 * 1024) { cudaFuncSetAttribute( - splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } const int dev_id = 0; int sm_count; int act_blocks_per_sm; cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, splitkv_kernel, num_threads, smem_size); + &act_blocks_per_sm, splitkv_kernel, num_threads, smem_size); assert(act_blocks_per_sm > 1); const int num_blocks_per_wave = sm_count * act_blocks_per_sm; const int num_blocks_need = gridx * num_chunks * kv_num_heads; const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); - const float ratio = static_cast(num_blocks_need) / static_cast(num_blocks_per_wave); + const float ratio = static_cast(num_blocks_need) / + static_cast(num_blocks_per_wave); dim3 grids(gridx, num_chunks, kv_num_heads); dim3 blocks(blockx, blocky); if (num_chunks <= 1) { - auto no_splitkv_kernel = multi_query_decode_attention_kernel; + auto no_splitkv_kernel = multi_query_decode_attention_kernel; if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + cudaFuncSetAttribute(no_splitkv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); } no_splitkv_kernel<<>>( - reinterpret_cast(const_cast(q.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - reinterpret_cast(const_cast(shift_bias_ptr)), - reinterpret_cast(const_cast(smooth_weight_ptr)), - seq_lens_q.data(), - seq_lens_kv.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - softmax_scale, - in_scale, - chunk_size, - nullptr, - nullptr, - nullptr, - reinterpret_cast(const_cast(out->data())) - ); + reinterpret_cast(const_cast(q.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + reinterpret_cast(const_cast(shift_bias_ptr)), + reinterpret_cast(const_cast(smooth_weight_ptr)), + seq_lens_q.data(), + seq_lens_kv.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + softmax_scale, + in_scale, + chunk_size, + nullptr, + nullptr, + nullptr, + reinterpret_cast(const_cast(out->data()))); // CHECK(cudaGetLastError()); // CHECK(cudaDeviceSynchronize()); @@ -417,34 +489,33 @@ void MultiQueryDecoderAttention( tmp_workspace = allocator->Allocate( phi::SizeOf(q.dtype()) * static_cast(bsz * num_chunks * num_heads * HEAD_DIM_V)); - tmp_m = allocator->Allocate( - phi::SizeOf(q.dtype()) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(q.dtype()) * - static_cast(bsz * num_chunks * num_heads)); + tmp_m = + allocator->Allocate(phi::SizeOf(q.dtype()) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = + allocator->Allocate(phi::SizeOf(q.dtype()) * + static_cast(bsz * num_chunks * num_heads)); splitkv_kernel<<>>( - reinterpret_cast(const_cast(q.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - reinterpret_cast(const_cast(shift_bias_ptr)), - reinterpret_cast(const_cast(smooth_weight_ptr)), - seq_lens_q.data(), - seq_lens_kv.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - softmax_scale, - in_scale, - chunk_size, - reinterpret_cast(tmp_workspace->ptr()), - reinterpret_cast(tmp_m->ptr()), - reinterpret_cast(tmp_d->ptr()), - reinterpret_cast(const_cast(out->data())) - ); + reinterpret_cast(const_cast(q.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + reinterpret_cast(const_cast(shift_bias_ptr)), + reinterpret_cast(const_cast(smooth_weight_ptr)), + seq_lens_q.data(), + seq_lens_kv.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + softmax_scale, + in_scale, + chunk_size, + reinterpret_cast(tmp_workspace->ptr()), + reinterpret_cast(tmp_m->ptr()), + reinterpret_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(out->data()))); // CHECK(cudaGetLastError()); // CHECK(cudaDeviceSynchronize()); @@ -452,23 +523,27 @@ void MultiQueryDecoderAttention( constexpr int bdy = 256 / mblockx; dim3 grids_merge(bsz, num_heads); dim3 blocks_merge(mblockx, bdy); - merge_varlen_multi_chunks_v2_kernel<<>>( - reinterpret_cast(tmp_workspace->ptr()), - reinterpret_cast(tmp_m->ptr()), - reinterpret_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - cu_seqlens_q.data(), - reinterpret_cast(const_cast(shift_bias_ptr)), - reinterpret_cast(const_cast(smooth_weight_ptr)), - reinterpret_cast(const_cast(out->data())), - in_scale, - num_chunks, - chunk_size, - max_seq_len, - num_heads, - HEAD_DIM_V - ); + merge_varlen_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + reinterpret_cast(tmp_m->ptr()), + reinterpret_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + cu_seqlens_q.data(), + reinterpret_cast(const_cast(shift_bias_ptr)), + reinterpret_cast(const_cast(smooth_weight_ptr)), + reinterpret_cast(const_cast(out->data())), + in_scale, + num_chunks, + chunk_size, + max_seq_len, + num_heads, + HEAD_DIM_V); } // CHECK(cudaGetLastError()); // CHECK(cudaDeviceSynchronize()); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h index a1f8aa7cc8..457f383e5e 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h @@ -15,25 +15,34 @@ #include "decode_attention_func.cuh" -template +template void MultiQueryDecoderAttention( - const AppendAttnMetaData& meta_data, - cudaStream_t &stream, - const paddle::Tensor &q, - const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim] - const paddle::Tensor &cache_v, // [num_kv_heads, head_dim] - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - const int max_seq_len, - const int max_dec_len, - const float rope_scale, - const float rope_theta, - const float softmax_scale, - const float in_scale, - paddle::Tensor *out); + const AppendAttnMetaData &meta_data, + cudaStream_t &stream, + const paddle::Tensor &q, + const paddle::Tensor + &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim] + const paddle::Tensor &cache_v, // [num_kv_heads, head_dim] + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const int max_seq_len, + const int max_dec_len, + const float rope_scale, + const float rope_theta, + const float softmax_scale, + const float in_scale, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index 35af00822f..fa5b3bca17 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -27,7 +27,7 @@ struct AppendAttnMetaData { int head_dims; int head_dims_v; int max_blocks_per_seq; - const int *mask_offset = nullptr; + const int* mask_offset = nullptr; }; __forceinline__ __host__ __device__ int div_up(int a, int b) { @@ -110,29 +110,33 @@ __device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, /******************************FASTER CAST*********************************/ -inline __device__ static void convert_fp8(__nv_bfloat16* result, const uint32_t& source) { - +inline __device__ static void convert_fp8(__nv_bfloat16* result, + const uint32_t& source) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) - uint32_t dest0; - uint32_t dest1; - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ - "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ - "}\n" : "=r"(dest0), "=r"(dest1) : "r"(source)); + uint32_t dest0; + uint32_t dest1; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "mov.b32 {lo, hi}, %2;\n" + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" + "}\n" + : "=r"(dest0), "=r"(dest1) + : "r"(source)); - ((nv_bfloat162*)(result))[0] = __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0])); - ((nv_bfloat162*)(result))[1] = __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0])); + ((nv_bfloat162*)(result))[0] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0])); + ((nv_bfloat162*)(result))[1] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0])); #else - printf("Do not support fp8 in arch < 890\n"); - asm("trap;"); + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); #endif - } -inline __device__ static void convert_fp8(half* result, const uint32_t& source) { +inline __device__ static void convert_fp8(half* result, + const uint32_t& source) { printf("Do not support fp8 to half although it's very easy.\n"); } @@ -301,8 +305,8 @@ __forceinline__ __host__ __device__ void vec_cast( #define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ switch (head_dim) { \ - case 64: { \ - constexpr size_t HEAD_DIM = 64; \ + case 64: { \ + constexpr size_t HEAD_DIM = 64; \ __VA_ARGS__ \ break; \ } \ @@ -385,9 +389,8 @@ __forceinline__ __host__ __device__ void vec_cast( PD_THROW("not support the cache_type: ", cache_type); \ } - #define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \ - if (deal_each_time == 32) { \ + if (deal_each_time == 32) { \ constexpr size_t DEAL_EACH_TIME = 32; \ __VA_ARGS__ \ } else if (deal_each_time == 64) { \ @@ -404,7 +407,7 @@ __forceinline__ __host__ __device__ void vec_cast( } else if (num_threads == 256) { \ constexpr size_t NUM_THREADS = 256; \ __VA_ARGS__ \ - } else { \ + } else { \ PD_THROW("not support the num_threads", num_threads); \ } @@ -456,7 +459,7 @@ __forceinline__ __host__ __device__ void vec_cast( } #define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ - if (group_size == 8) { \ + if (group_size == 8) { \ constexpr size_t GROUP_SIZE = 8; \ __VA_ARGS__ \ } else if (group_size == 16) { \ @@ -538,9 +541,11 @@ inline HOSTDEVICE T roundWithTiesToEven(T x) { : xUpper); } - template -__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, const T value, const float max_bound, const float min_bound) { +__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, + const T value, + const float max_bound, + const float min_bound) { uint8_t eight_bits; float quant_value; if constexpr (is_need_kv_quant) { @@ -572,8 +577,8 @@ __host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, const T val return eight_bits; } - -template inline __device__ static void convert_c8(T * result, const uint32_t& source){ +template +inline __device__ static void convert_c8(T* result, const uint32_t& source) { if constexpr (IsFP8) { convert_fp8(result, source); } else { @@ -583,12 +588,12 @@ template inline __device__ static void convert_c8(T * re constexpr int kWarpSize = 32; -template +template inline __device__ void WelfordCombine1(T b_m2, T* m2) { *m2 += b_m2; } -template +template __inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) { *m2 = thread_m2; for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) { @@ -597,7 +602,7 @@ __inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) { } } -template +template __inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) { WelfordWarpReduce(thread_m2, m2); } diff --git a/custom_ops/gpu_ops/beam_search_softmax.cu b/custom_ops/gpu_ops/beam_search_softmax.cu index a12f2ca17b..e78d7000e9 100644 --- a/custom_ops/gpu_ops/beam_search_softmax.cu +++ b/custom_ops/gpu_ops/beam_search_softmax.cu @@ -36,8 +36,8 @@ namespace cub = hipcub; static constexpr int kBlockSizeForSmallBeamWidth = 256; static constexpr int kMaxVocabPartForStage1FastKernel = 128; -#define CASE_K(K) \ - case K: \ +#define CASE_K(K) \ + case K: \ invokeTopKSoftMaxLauncher( \ params, beam_group_idx, stream); \ break @@ -1336,24 +1336,25 @@ adding while op and without affecting the speed. Use a 'fake inplace' method here. Not elegant but useful ︸_︸. *****/ -std::vector BeamSearchSoftmax(const paddle::Tensor &logits, - const paddle::Tensor &seq_lens, - const paddle::Tensor &stop_flags, // inplace - const paddle::Tensor &end_ids, - const paddle::Tensor &step_ids, - const paddle::Tensor &max_dec_lens, - const paddle::Tensor &block_tables, // inplace - const paddle::Tensor &cum_scores, // inplace - const paddle::Tensor &beam_cache_ids, // inplace - const paddle::Tensor &beam_hyps, // inplace - const paddle::Tensor &beam_hyps_score, // inplace - const paddle::Tensor &beam_finished, // inplace - const paddle::Tensor &beam_width, - const paddle::Tensor &beam_group_num, - const paddle::Tensor &length_penalty, - const paddle::Tensor &diversity_penalty, - bool fuse_softmax, - bool early_stop) { +std::vector BeamSearchSoftmax( + const paddle::Tensor &logits, + const paddle::Tensor &seq_lens, + const paddle::Tensor &stop_flags, // inplace + const paddle::Tensor &end_ids, + const paddle::Tensor &step_ids, + const paddle::Tensor &max_dec_lens, + const paddle::Tensor &block_tables, // inplace + const paddle::Tensor &cum_scores, // inplace + const paddle::Tensor &beam_cache_ids, // inplace + const paddle::Tensor &beam_hyps, // inplace + const paddle::Tensor &beam_hyps_score, // inplace + const paddle::Tensor &beam_finished, // inplace + const paddle::Tensor &beam_width, + const paddle::Tensor &beam_group_num, + const paddle::Tensor &length_penalty, + const paddle::Tensor &diversity_penalty, + bool fuse_softmax, + bool early_stop) { std::vector logits_shape = logits.shape(); // logits_shape auto cu_stream = logits.stream(); @@ -1380,43 +1381,43 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, const int end_ids_len = end_ids.dims()[0]; const int beam_group_size = beam_width_scalar / beam_group_num_scalar; - auto next_tokens = paddle::full({logits_shape[0], 1}, 0, end_ids.type(), - paddle::GPUPlace()); + auto next_tokens = + paddle::full({logits_shape[0], 1}, 0, end_ids.type(), paddle::GPUPlace()); - auto parent_ids = paddle::full({logits_shape[0], 1}, 0, end_ids.type(), - paddle::GPUPlace()); + auto parent_ids = + paddle::full({logits_shape[0], 1}, 0, end_ids.type(), paddle::GPUPlace()); - auto cum_scores_ori = paddle::empty(cum_scores.shape(), logits.type(), - paddle::GPUPlace()); + auto cum_scores_ori = + paddle::empty(cum_scores.shape(), logits.type(), paddle::GPUPlace()); - auto beam_cache_ids_ori = paddle::empty(beam_cache_ids.shape(), end_ids.type(), - paddle::GPUPlace()); + auto beam_cache_ids_ori = + paddle::empty(beam_cache_ids.shape(), end_ids.type(), paddle::GPUPlace()); - auto block_tables_ori = paddle::empty(block_tables.shape(), end_ids.type(), - paddle::GPUPlace()); + auto block_tables_ori = + paddle::empty(block_tables.shape(), end_ids.type(), paddle::GPUPlace()); cudaMemcpyAsync(cum_scores_ori.mutable_data(), cum_scores.data(), - sizeof(float)*cum_scores.numel(), + sizeof(float) * cum_scores.numel(), cudaMemcpyDeviceToDevice, cu_stream); cudaMemcpyAsync(beam_cache_ids_ori.mutable_data(), beam_cache_ids.data(), - sizeof(int)*beam_cache_ids.numel(), + sizeof(int) * beam_cache_ids.numel(), cudaMemcpyDeviceToDevice, cu_stream); cudaMemcpyAsync(block_tables_ori.mutable_data(), block_tables.data(), - sizeof(int)*block_tables.numel(), + sizeof(int) * block_tables.numel(), cudaMemcpyDeviceToDevice, cu_stream); const int tmp_size = batch_size * beam_group_size * beam_group_size * 2; - auto tmp_topk_id = paddle::full({tmp_size}, 0, end_ids.type(), - paddle::GPUPlace()); + auto tmp_topk_id = + paddle::full({tmp_size}, 0, end_ids.type(), paddle::GPUPlace()); - auto tmp_topk_val = paddle::full({tmp_size}, 0.0, logits.type(), - paddle::GPUPlace()); + auto tmp_topk_val = + paddle::full({tmp_size}, 0.0, logits.type(), paddle::GPUPlace()); BeamSearchParams params; params.batch_size = batch_size; @@ -1449,7 +1450,8 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, params.block_tables_out = const_cast(block_tables.data()); params.cum_scores_out = const_cast(cum_scores.data()); params.beam_hyps_out = const_cast(beam_hyps.data()); - params.beam_hyps_score_out = const_cast(beam_hyps_score.data()); + params.beam_hyps_score_out = + const_cast(beam_hyps_score.data()); params.beam_finished = const_cast(beam_finished.data()); params.stop_flags = const_cast(stop_flags.data()); @@ -1470,8 +1472,8 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, const int workspace_size = tmp_id_val_size * 2 + tmp_stage1_to_stage2_size; - auto wsp_buffer_tensor = paddle::full({workspace_size}, 0, logits.type(), - paddle::GPUPlace()); + auto wsp_buffer_tensor = + paddle::full({workspace_size}, 0, logits.type(), paddle::GPUPlace()); params.tmp_ids = reinterpret_cast(wsp_buffer_tensor.data()); params.tmp_vals = wsp_buffer_tensor.data() + tmp_id_val_size; @@ -1480,11 +1482,9 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, for (int beam_group_idx = 0; beam_group_idx < beam_group_num_scalar; ++beam_group_idx) { if (beam_group_num_scalar == 1) { - invokeTopkSoftMax( - ¶ms, beam_group_idx, cu_stream); + invokeTopkSoftMax(¶ms, beam_group_idx, cu_stream); } else { - invokeTopkSoftMax( - ¶ms, beam_group_idx, cu_stream); + invokeTopkSoftMax(¶ms, beam_group_idx, cu_stream); } } updateBeamSearchParams(¶ms, cu_stream); @@ -1492,54 +1492,66 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, } std::vector> BeamSearchSoftmaxShape( - const std::vector &logits, - const std::vector &seq_lens, - const std::vector &stop_flags, // inplace - const std::vector &end_ids, - const std::vector &step_ids, - const std::vector &max_dec_lens, - const std::vector &block_tables, // inplace - const std::vector &cum_scores, // inplace - const std::vector &beam_cache_ids, // inplace - const std::vector &beam_hyps, // inplace - const std::vector &beam_hyps_score, // inplace - const std::vector &beam_finished, // inplace - const std::vector &beam_width, - const std::vector &beam_group_num, - const std::vector &length_penalty, - const std::vector &diversity_penalty) { - std::vector next_tokens = {logits[0],1}; - std::vector parent_ids = {logits[0],1}; - return {next_tokens,parent_ids}; + const std::vector &logits, + const std::vector &seq_lens, + const std::vector &stop_flags, // inplace + const std::vector &end_ids, + const std::vector &step_ids, + const std::vector &max_dec_lens, + const std::vector &block_tables, // inplace + const std::vector &cum_scores, // inplace + const std::vector &beam_cache_ids, // inplace + const std::vector &beam_hyps, // inplace + const std::vector &beam_hyps_score, // inplace + const std::vector &beam_finished, // inplace + const std::vector &beam_width, + const std::vector &beam_group_num, + const std::vector &length_penalty, + const std::vector &diversity_penalty) { + std::vector next_tokens = {logits[0], 1}; + std::vector parent_ids = {logits[0], 1}; + return {next_tokens, parent_ids}; } std::vector BeamSearchSoftmaxDtype( - const paddle::DataType &logits, - const paddle::DataType &seq_lens, - const paddle::DataType &stop_flags, // inplace - const paddle::DataType &end_ids, - const paddle::DataType &step_ids, - const paddle::DataType &max_dec_lens, - const paddle::DataType &block_tables, // inplace - const paddle::DataType &cum_scores, // inplace - const paddle::DataType &beam_cache_ids, // inplace - const paddle::DataType &beam_hyps, // inplace - const paddle::DataType &beam_hyps_score, // inplace - const paddle::DataType &beam_finished, // inplace - const paddle::DataType &beam_width, - const paddle::DataType &beam_group_num, - const paddle::DataType &length_penalty, - const paddle::DataType &diversity_penalty) { - return {paddle::DataType::INT32, paddle::DataType::INT32}; + const paddle::DataType &logits, + const paddle::DataType &seq_lens, + const paddle::DataType &stop_flags, // inplace + const paddle::DataType &end_ids, + const paddle::DataType &step_ids, + const paddle::DataType &max_dec_lens, + const paddle::DataType &block_tables, // inplace + const paddle::DataType &cum_scores, // inplace + const paddle::DataType &beam_cache_ids, // inplace + const paddle::DataType &beam_hyps, // inplace + const paddle::DataType &beam_hyps_score, // inplace + const paddle::DataType &beam_finished, // inplace + const paddle::DataType &beam_width, + const paddle::DataType &beam_group_num, + const paddle::DataType &length_penalty, + const paddle::DataType &diversity_penalty) { + return {paddle::DataType::INT32, paddle::DataType::INT32}; } PD_BUILD_STATIC_OP(beam_search_softmax) - .Inputs({"logits", "seq_lens", "stop_flags", "end_ids", "step_ids", "max_dec_lens", "block_tables" - , "cum_scores", "beam_cache_ids", "beam_hyps", "beam_hyps_score", "beam_finished" - , "beam_width", "beam_group_num", "length_penalty", "diversity_penalty"}) + .Inputs({"logits", + "seq_lens", + "stop_flags", + "end_ids", + "step_ids", + "max_dec_lens", + "block_tables", + "cum_scores", + "beam_cache_ids", + "beam_hyps", + "beam_hyps_score", + "beam_finished", + "beam_width", + "beam_group_num", + "length_penalty", + "diversity_penalty"}) .Outputs({"next_tokens", "parent_ids"}) - .Attrs({"fuse_softmax: bool", - "early_stop: bool"}) + .Attrs({"fuse_softmax: bool", "early_stop: bool"}) .SetKernelFn(PD_KERNEL(BeamSearchSoftmax)) .SetInferShapeFn(PD_INFER_SHAPE(BeamSearchSoftmaxShape)) .SetInferDtypeFn(PD_INFER_DTYPE(BeamSearchSoftmaxDtype)); diff --git a/custom_ops/gpu_ops/common/configManager.h b/custom_ops/gpu_ops/common/configManager.h index d0bb751e97..960df1259b 100644 --- a/custom_ops/gpu_ops/common/configManager.h +++ b/custom_ops/gpu_ops/common/configManager.h @@ -22,87 +22,96 @@ #include class ConfigManager { -public: - static ConfigManager& get_instance(const std::string& config_path = "fastdeploy_op_configs.json") { - static ConfigManager instance(config_path); - return instance; + public: + static ConfigManager& get_instance( + const std::string& config_path = "fastdeploy_op_configs.json") { + static ConfigManager instance(config_path); + return instance; + } + + std::string get_best_config(const std::string& op_name, + const size_t m, + const size_t n, + const size_t k) { + initialize(); + std::string mnk_string = op_name + "-" + std::to_string(update_m(m)) + "x" + + std::to_string(n) + "x" + std::to_string(k); + if (configs_.contains(mnk_string)) { + return configs_.at(mnk_string); } + return ""; + } - std::string get_best_config(const std::string& op_name, const size_t m, const size_t n, const size_t k) { - initialize(); - std::string mnk_string = op_name + "-" + - std::to_string(update_m(m)) + "x" + std::to_string(n) + "x" + std::to_string(k); - if (configs_.contains(mnk_string)) { - return configs_.at(mnk_string); - } - return ""; + int64_t update_m(const size_t m) { + size_t new_m = m; + if (m < 4) { + return m; + } else if (m < 16) { + return (m + 3) / 4 * 4; + } else if (m < 64) { + return (m + 15) / 16 * 16; + } else if (m < 256) { + return (m + 31) / 32 * 32; + } else if (m < 512) { + return (m + 63) / 64 * 64; + } else if (m < 1024) { + return (m + 127) / 128 * 128; + } else if (m < 8192) { + return (m + 1023) / 1024 * 1024; + } else if (m < 32768) { + return (m + 4095) / 4096 * 4096; + } else { + return 32768; } + } - int64_t update_m(const size_t m) { - size_t new_m = m; - if (m < 4) { - return m; - } else if (m < 16) { - return (m + 3) / 4 * 4; - } else if (m < 64) { - return (m + 15) / 16 * 16; - } else if (m < 256) { - return (m + 31) / 32 * 32; - } else if (m < 512) { - return (m + 63) / 64 * 64; - } else if (m < 1024) { - return (m + 127) / 128 * 128; - } else if (m < 8192) { - return (m + 1023) / 1024 * 1024; - } else if (m < 32768) { - return (m + 4095) / 4096 * 4096; - } else { - return 32768; - } + void update(const std::string& op_name, + const size_t m, + const size_t n, + const size_t k, + const std::string& config) { + initialize(); + std::string mnk_string = op_name + "-" + std::to_string(update_m(m)) + "x" + + std::to_string(n) + "x" + std::to_string(k); + configs_[mnk_string] = config; + } + + void print() const { + std::cout << configs_.dump(4) << std::endl; // Pretty print with 4 spaces + } + + ~ConfigManager() { + std::ofstream file(config_path_); + if (file.is_open()) { + file << configs_.dump(4); // Pretty print with 4 spaces + file.close(); } + } - void update(const std::string& op_name, const size_t m, const size_t n, const size_t k, const std::string& config) { - initialize(); - std::string mnk_string = op_name + "-" + - std::to_string(update_m(m)) + "x" + std::to_string(n) + "x" + std::to_string(k); - configs_[mnk_string] = config; + private: + void initialize() { + if (initialized_) return; + std::ifstream file(config_path_); + if (file.is_open()) { + try { + file >> configs_; + } catch (const std::exception& e) { + std::cerr << "Error reading configs from " << config_path_ << " : " + << e.what() << std::endl; + configs_ = nlohmann::json::object(); // Create an empty JSON object + } + file.close(); + } else { + configs_ = nlohmann::json::object(); // Create an empty JSON object } + initialized_ = true; + } - void print() const { - std::cout << configs_.dump(4) << std::endl; // Pretty print with 4 spaces - } + ConfigManager(const std::string& config_path) : config_path_(config_path) {} + ConfigManager(const ConfigManager&) = delete; + ConfigManager& operator=(const ConfigManager&) = delete; - ~ConfigManager() { - std::ofstream file(config_path_); - if (file.is_open()) { - file << configs_.dump(4); // Pretty print with 4 spaces - file.close(); - } - } - -private: - void initialize() { - if (initialized_) return; - std::ifstream file(config_path_); - if (file.is_open()) { - try { - file >> configs_; - } catch (const std::exception& e) { - std::cerr << "Error reading configs from " << config_path_ << " : " << e.what() << std::endl; - configs_ = nlohmann::json::object(); // Create an empty JSON object - } - file.close(); - } else { - configs_ = nlohmann::json::object(); // Create an empty JSON object - } - initialized_ = true; - } - - ConfigManager(const std::string& config_path) : config_path_(config_path) {} - ConfigManager(const ConfigManager&) = delete; - ConfigManager& operator=(const ConfigManager&) = delete; - - nlohmann::json configs_; - std::string config_path_; - bool initialized_{false}; + nlohmann::json configs_; + std::string config_path_; + bool initialized_{false}; }; diff --git a/custom_ops/gpu_ops/common/cudaUtils.h b/custom_ops/gpu_ops/common/cudaUtils.h index 9bbd1f6e80..7123e33ebb 100644 --- a/custom_ops/gpu_ops/common/cudaUtils.h +++ b/custom_ops/gpu_ops/common/cudaUtils.h @@ -16,18 +16,18 @@ #include #include "paddle/phi/core/enforce.h" -namespace common -{ +namespace common { -inline int getSMVersion() -{ - int device{-1}; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); - int sm_major = 0; - int sm_minor = 0; - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); - return sm_major * 10 + sm_minor; +inline int getSMVersion() { + int device{-1}; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &sm_major, cudaDevAttrComputeCapabilityMajor, device)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; } -} +} // namespace common diff --git a/custom_ops/gpu_ops/common/quantization.h b/custom_ops/gpu_ops/common/quantization.h index e6a74760bb..433e795311 100644 --- a/custom_ops/gpu_ops/common/quantization.h +++ b/custom_ops/gpu_ops/common/quantization.h @@ -20,312 +20,240 @@ #include #include -namespace common -{ +namespace common { -class QuantMode -{ - // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py -public: - using BaseType = std::uint32_t; +class QuantMode { + // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH + // tensorrt_llm/quantization/mode.py + public: + using BaseType = std::uint32_t; - explicit constexpr QuantMode(BaseType value) noexcept - : mValue{value} - { + explicit constexpr QuantMode(BaseType value) noexcept : mValue{value} {} + + QuantMode() noexcept = default; + + constexpr QuantMode(QuantMode const&) noexcept = default; + + constexpr QuantMode& operator=(QuantMode const& other) noexcept = default; + + static constexpr QuantMode none() noexcept { return QuantMode(BaseType(0)); } + + static constexpr QuantMode int4Weights() noexcept { + return QuantMode(BaseType(1u) << 0); + } + + static constexpr QuantMode int8Weights() noexcept { + return QuantMode(BaseType(1u) << 1); + } + + static constexpr QuantMode activations() noexcept { + return QuantMode(BaseType(1u) << 2); + } + + static constexpr QuantMode perChannelScaling() noexcept { + return QuantMode(BaseType(1u) << 3); + } + + static constexpr QuantMode perTokenScaling() noexcept { + return QuantMode(BaseType(1u) << 4); + } + + static constexpr QuantMode perGroupScaling() noexcept { + return QuantMode(BaseType(1u) << 5); + } + + static constexpr QuantMode int8KvCache() noexcept { + return QuantMode(BaseType(1u) << 6); + } + + static constexpr QuantMode fp8KvCache() noexcept { + return QuantMode(BaseType(1u) << 7); + } + + static constexpr QuantMode fp8Qdq() noexcept { + return QuantMode(BaseType(1u) << 8); + } + + static constexpr QuantMode fp8RowWise() noexcept { + return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9); + } + + constexpr BaseType value() const noexcept { return mValue; } + + constexpr bool isSet(QuantMode const& mode) const noexcept { + return (mValue & mode.value()) == mode.value(); + } + + constexpr bool hasInt4Weights() const noexcept { + return isSet(int4Weights()); + } + + constexpr bool hasInt8Weights() const noexcept { + return isSet(int8Weights()); + } + + constexpr bool hasActivations() const noexcept { + return isSet(activations()); + } + + constexpr bool hasPerChannelScaling() const noexcept { + return isSet(perChannelScaling()); + } + + constexpr bool hasPerTokenScaling() const noexcept { + return isSet(perTokenScaling()); + } + + constexpr bool hasPerGroupScaling() const noexcept { + return isSet(perGroupScaling()); + } + + constexpr bool hasStaticActivationScaling() const noexcept { + return !hasPerTokenScaling(); + } + + constexpr bool hasInt8KvCache() const noexcept { + return isSet(int8KvCache()); + } + + constexpr bool hasFp8KvCache() const noexcept { return isSet(fp8KvCache()); } + + constexpr bool hasFp8Qdq() const noexcept { return isSet(fp8Qdq()); } + + constexpr bool hasFp8RowWise() const noexcept { return isSet(fp8RowWise()); } + + constexpr bool hasKvCacheQuant() const noexcept { + return hasInt8KvCache() || hasFp8KvCache(); + } + + static constexpr QuantMode fromDescription(bool quantizeWeights = false, + bool quantizeActivations = false, + bool perToken = false, + bool perChannel = false, + bool perGroup = false, + bool useInt4Weights = false, + bool useInt8KvCache = false, + bool useFp8KvCache = false, + bool useFp8Qdq = false, + bool useFp8RowWise = false) { + QuantMode quantMode{}; + if (quantizeWeights) { + if (useInt4Weights) + quantMode += int4Weights(); + else + quantMode += int8Weights(); } - QuantMode() noexcept = default; - - constexpr QuantMode(QuantMode const&) noexcept = default; - - constexpr QuantMode& operator=(QuantMode const& other) noexcept = default; - - static constexpr QuantMode none() noexcept - { - return QuantMode(BaseType(0)); + if (quantizeActivations) { + quantMode += activations(); } - static constexpr QuantMode int4Weights() noexcept - { - return QuantMode(BaseType(1u) << 0); + if (perChannel) { + quantMode += QuantMode::perChannelScaling(); + } + if (perToken) { + quantMode += QuantMode::perTokenScaling(); + } + if (perGroup) { + quantMode += QuantMode::perGroupScaling(); } - static constexpr QuantMode int8Weights() noexcept - { - return QuantMode(BaseType(1u) << 1); + if (useInt8KvCache) { + quantMode += int8KvCache(); } - static constexpr QuantMode activations() noexcept - { - return QuantMode(BaseType(1u) << 2); + if (useFp8KvCache) { + quantMode += fp8KvCache(); } - static constexpr QuantMode perChannelScaling() noexcept - { - return QuantMode(BaseType(1u) << 3); + if (useFp8Qdq) { + quantMode += fp8Qdq(); } - static constexpr QuantMode perTokenScaling() noexcept - { - return QuantMode(BaseType(1u) << 4); + if (useFp8RowWise) { + quantMode += fp8RowWise(); } - static constexpr QuantMode perGroupScaling() noexcept - { - return QuantMode(BaseType(1u) << 5); + return quantMode; + } + + static constexpr QuantMode useSmoothQuant(bool perToken = false, + bool perChannel = false) { + return fromDescription(true, true, perToken, perChannel); + } + + static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, + bool perGroup = false) { + return fromDescription(true, false, false, false, perGroup, useInt4Weights); + } + + static const QuantMode fromQuantAlgo( + std::optional quantAlgo = std::nullopt, + std::optional kvCacheQuantAlgo = std::nullopt) { + QuantMode quantMode{}; + if (quantAlgo == "W8A16") { + quantMode = useWeightOnly(false, false); + } else if (quantAlgo == "W4A16") { + quantMode = useWeightOnly(true, false); + } else if (quantAlgo == "W4A16_AWQ") { + quantMode = useWeightOnly(true, true); + } else if (quantAlgo == "W4A8_AWQ") { + quantMode = useWeightOnly(true, true); + } else if (quantAlgo == "W4A16_GPTQ") { + quantMode = useWeightOnly(true, true); + } else if (quantAlgo == "W8A8_SQ_PER_CHANNEL") { + quantMode = useSmoothQuant(false, true); + } else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN") { + quantMode = useSmoothQuant(false, false); + } else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN") { + quantMode = useSmoothQuant(true, true); + } else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN") { + quantMode = useSmoothQuant(false, true); + } else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN") { + quantMode = useSmoothQuant(true, false); + } else if (quantAlgo == "FP8") { + quantMode = fromDescription( + false, false, false, false, false, false, false, false, true); + } else if (quantAlgo == "FP8_ROWWISE") { + quantMode = fromDescription( + false, false, true, true, false, false, false, false, false, true); } - static constexpr QuantMode int8KvCache() noexcept - { - return QuantMode(BaseType(1u) << 6); + if (kvCacheQuantAlgo == "INT8") { + quantMode += int8KvCache(); + } else if (kvCacheQuantAlgo == "FP8") { + quantMode += fp8KvCache(); } - static constexpr QuantMode fp8KvCache() noexcept - { - return QuantMode(BaseType(1u) << 7); - } + return quantMode; + } - static constexpr QuantMode fp8Qdq() noexcept - { - return QuantMode(BaseType(1u) << 8); - } + constexpr QuantMode operator+(QuantMode const& other) const noexcept { + return QuantMode(mValue | other.mValue); + } - static constexpr QuantMode fp8RowWise() noexcept - { - return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9); - } + constexpr QuantMode& operator+=(QuantMode const& other) noexcept { + return *this = *this + other; + } - constexpr BaseType value() const noexcept - { - return mValue; - } + constexpr QuantMode operator-(QuantMode const& other) const noexcept { + return QuantMode(mValue & ~other.mValue); + } - constexpr bool isSet(QuantMode const& mode) const noexcept - { - return (mValue & mode.value()) == mode.value(); - } + constexpr QuantMode& operator-=(QuantMode const& other) noexcept { + return *this = *this - other; + } - constexpr bool hasInt4Weights() const noexcept - { - return isSet(int4Weights()); - } + constexpr bool operator==(QuantMode const& other) const noexcept { + return mValue == other.mValue; + } - constexpr bool hasInt8Weights() const noexcept - { - return isSet(int8Weights()); - } + constexpr bool operator!=(QuantMode const& other) const noexcept { + return !(*this == other); + } - constexpr bool hasActivations() const noexcept - { - return isSet(activations()); - } - - constexpr bool hasPerChannelScaling() const noexcept - { - return isSet(perChannelScaling()); - } - - constexpr bool hasPerTokenScaling() const noexcept - { - return isSet(perTokenScaling()); - } - - constexpr bool hasPerGroupScaling() const noexcept - { - return isSet(perGroupScaling()); - } - - constexpr bool hasStaticActivationScaling() const noexcept - { - return !hasPerTokenScaling(); - } - - constexpr bool hasInt8KvCache() const noexcept - { - return isSet(int8KvCache()); - } - - constexpr bool hasFp8KvCache() const noexcept - { - return isSet(fp8KvCache()); - } - - constexpr bool hasFp8Qdq() const noexcept - { - return isSet(fp8Qdq()); - } - - constexpr bool hasFp8RowWise() const noexcept - { - return isSet(fp8RowWise()); - } - - constexpr bool hasKvCacheQuant() const noexcept - { - return hasInt8KvCache() || hasFp8KvCache(); - } - - static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false, - bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false, - bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false) - { - QuantMode quantMode{}; - if (quantizeWeights) - { - if (useInt4Weights) - quantMode += int4Weights(); - else - quantMode += int8Weights(); - } - - if (quantizeActivations) - { - quantMode += activations(); - } - - if (perChannel) - { - quantMode += QuantMode::perChannelScaling(); - } - if (perToken) - { - quantMode += QuantMode::perTokenScaling(); - } - if (perGroup) - { - quantMode += QuantMode::perGroupScaling(); - } - - if (useInt8KvCache) - { - quantMode += int8KvCache(); - } - - if (useFp8KvCache) - { - quantMode += fp8KvCache(); - } - - if (useFp8Qdq) - { - quantMode += fp8Qdq(); - } - - if (useFp8RowWise) - { - quantMode += fp8RowWise(); - } - - return quantMode; - } - - static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false) - { - return fromDescription(true, true, perToken, perChannel); - } - - static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false) - { - return fromDescription(true, false, false, false, perGroup, useInt4Weights); - } - - static const QuantMode fromQuantAlgo( - std::optional quantAlgo = std::nullopt, std::optional kvCacheQuantAlgo = std::nullopt) - { - QuantMode quantMode{}; - if (quantAlgo == "W8A16") - { - quantMode = useWeightOnly(false, false); - } - else if (quantAlgo == "W4A16") - { - quantMode = useWeightOnly(true, false); - } - else if (quantAlgo == "W4A16_AWQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W4A8_AWQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W4A16_GPTQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL") - { - quantMode = useSmoothQuant(false, true); - } - else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN") - { - quantMode = useSmoothQuant(false, false); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN") - { - quantMode = useSmoothQuant(true, true); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN") - { - quantMode = useSmoothQuant(false, true); - } - else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN") - { - quantMode = useSmoothQuant(true, false); - } - else if (quantAlgo == "FP8") - { - quantMode = fromDescription(false, false, false, false, false, false, false, false, true); - } - else if (quantAlgo == "FP8_ROWWISE") - { - quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true); - } - - if (kvCacheQuantAlgo == "INT8") - { - quantMode += int8KvCache(); - } - else if (kvCacheQuantAlgo == "FP8") - { - quantMode += fp8KvCache(); - } - - return quantMode; - } - - constexpr QuantMode operator+(QuantMode const& other) const noexcept - { - return QuantMode(mValue | other.mValue); - } - - constexpr QuantMode& operator+=(QuantMode const& other) noexcept - { - return *this = *this + other; - } - - constexpr QuantMode operator-(QuantMode const& other) const noexcept - { - return QuantMode(mValue & ~other.mValue); - } - - constexpr QuantMode& operator-=(QuantMode const& other) noexcept - { - return *this = *this - other; - } - - constexpr bool operator==(QuantMode const& other) const noexcept - { - return mValue == other.mValue; - } - - constexpr bool operator!=(QuantMode const& other) const noexcept - { - return !(*this == other); - } - -private: - BaseType mValue{0}; + private: + BaseType mValue{0}; }; -} // namespace common +} // namespace common diff --git a/custom_ops/gpu_ops/cuda_multiprocess.h b/custom_ops/gpu_ops/cuda_multiprocess.h index c4b3c84109..c8a138e133 100644 --- a/custom_ops/gpu_ops/cuda_multiprocess.h +++ b/custom_ops/gpu_ops/cuda_multiprocess.h @@ -52,35 +52,36 @@ namespace cub = hipcub; #define GPU(str) cuda##str #endif -#define checkCudaErrors(call) \ - do { \ - GPU(Error_t) err = call; \ - if (err != GPU(Success)) { \ - printf("CUDA error at %s %d: %s\n", \ - __FILE__, \ - __LINE__, \ - GPU(GetErrorString)(err)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) +#define checkCudaErrors(call) \ + do { \ + GPU(Error_t) err = call; \ + if (err != GPU(Success)) { \ + printf("CUDA error at %s %d: %s\n", \ + __FILE__, \ + __LINE__, \ + GPU(GetErrorString)(err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) typedef struct shmStruct_st { - size_t nprocesses; - GPU(IpcMemHandle_t) memHandle; + size_t nprocesses; + GPU(IpcMemHandle_t) memHandle; } shmStruct; typedef struct sharedMemoryInfo_st { - void *addr; - size_t size; + void *addr; + size_t size; #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) - HANDLE shmHandle; + HANDLE shmHandle; #else - int shmFd; + int shmFd; #endif } sharedMemoryInfo; - -inline int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info) { +inline int sharedMemoryOpen(const char *name, + size_t sz, + sharedMemoryInfo *info) { #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) info->size = sz; diff --git a/custom_ops/gpu_ops/cutlass_extensions/arch/copy_red_global.hpp b/custom_ops/gpu_ops/cutlass_extensions/arch/copy_red_global.hpp index 61a41031bf..4b99d9651b 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/arch/copy_red_global.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/arch/copy_red_global.hpp @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once @@ -38,315 +39,331 @@ // Config -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && \ + (__CUDACC_VER_MAJOR__ >= 10)) #define CUTE_ARCH_RED_F16_SM70_ENABLED #endif -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + (__CUDACC_VER_MAJOR__ >= 12)) #define CUTE_ARCH_RED_VEC_SM90_ENABLED #define CUTE_ARCH_RED_BF16_SM90_ENABLED #endif -namespace cute -{ +namespace cute { ////////////////////////////////// // Wrapper around CUDA's atomicAdd ////////////////////////////////// template -struct TypedAtomicAdd -{ - using SRegisters = T[1]; - using DRegisters = T[1]; +struct TypedAtomicAdd { + using SRegisters = T[1]; + using DRegisters = T[1]; - CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) - { - atomicAdd(&dst, src); - } + CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) { + atomicAdd(&dst, src); + } }; template -struct Copy_Traits> -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits> { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout::value>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout::value>>>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// // F16 ADD PTX ////////////////////////////////// -struct SM70_RED_ADD_NOFTZ_F16 -{ - using SRegisters = uint16_t[1]; - using DRegisters = uint16_t[1]; +struct SM70_RED_ADD_NOFTZ_F16 { + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; - CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) { #if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) - asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); + asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), + "h"(src0)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; -struct SM70_RED_ADD_NOFTZ_F16x2 -{ - using SRegisters = uint32_t[1]; - using DRegisters = uint32_t[1]; +struct SM70_RED_ADD_NOFTZ_F16x2 { + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) { #if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) - asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); + asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), + "r"(src0)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; -struct SM90_RED_ADD_NOFTZ_F16x2_V2 -{ - using SRegisters = uint32_t[2]; - using DRegisters = uint64_t[1]; +struct SM90_RED_ADD_NOFTZ_F16x2_V2 { + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, + uint32_t const& src1, + uint64_t& gmem_dst) { #if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) - asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); + asm volatile( + "red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), + "r"(src0), + "r"(src1)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; -struct SM90_RED_ADD_NOFTZ_F16x2_V4 -{ - using SRegisters = uint32_t[4]; - using DRegisters = uint128_t[1]; +struct SM90_RED_ADD_NOFTZ_F16x2_V4 { + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; - CUTE_HOST_DEVICE static void copy( - uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, + uint32_t const& src1, + uint32_t const& src2, + uint32_t const& src3, + uint128_t& gmem_dst) { #if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) - asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), - "r"(src2), "r"(src3)); + asm volatile( + "red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"( + &gmem_dst), + "r"(src0), + "r"(src1), + "r"(src2), + "r"(src3)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// // BF16 ADD PTX ////////////////////////////////// -struct SM90_RED_ADD_NOFTZ_BF16 -{ - using SRegisters = uint16_t[1]; - using DRegisters = uint16_t[1]; +struct SM90_RED_ADD_NOFTZ_BF16 { + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; - CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) { #if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); + asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), + "h"(src0)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.bf16 without " + "CUTE_ARCH_RED_BF16_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// -struct SM90_RED_ADD_NOFTZ_BF16x2 -{ - using SRegisters = uint32_t[1]; - using DRegisters = uint32_t[1]; +struct SM90_RED_ADD_NOFTZ_BF16x2 { + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) { #if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); + asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), + "r"(src0)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.bf16 without " + "CUTE_ARCH_RED_BF16_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// -struct SM90_RED_ADD_NOFTZ_BF16x2_V2 -{ - using SRegisters = uint32_t[2]; - using DRegisters = uint64_t[1]; +struct SM90_RED_ADD_NOFTZ_BF16x2_V2 { + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, + uint32_t const& src1, + uint64_t& gmem_dst) { #if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); + asm volatile( + "red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), + "r"(src0), + "r"(src1)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.bf16 without " + "CUTE_ARCH_RED_BF16_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// -struct SM90_RED_ADD_NOFTZ_BF16x2_V4 -{ - using SRegisters = uint32_t[4]; - using DRegisters = uint128_t[1]; +struct SM90_RED_ADD_NOFTZ_BF16x2_V4 { + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; - CUTE_HOST_DEVICE static void copy( - uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, + uint32_t const& src1, + uint32_t const& src2, + uint32_t const& src3, + uint128_t& gmem_dst) { #if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), - "r"(src2), "r"(src3)); + asm volatile( + "red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"( + &gmem_dst), + "r"(src0), + "r"(src1), + "r"(src2), + "r"(src3)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.bf16 without " + "CUTE_ARCH_RED_BF16_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// -} // end namespace cute +} // end namespace cute diff --git a/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h b/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h index a9975c0138..0e9fa1b151 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h +++ b/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -59,8 +60,9 @@ template < bool GlobalToShared = true> struct copy; -/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate -/// the entire transfer, zeros are written to SMEM if the guard predicate is false. +/// Initiates an asynchronous copy from global memory to shared memory. Rather +/// than predicate the entire transfer, zeros are written to SMEM if the guard +/// predicate is false. /// /// cp.async /// @@ -72,7 +74,8 @@ template < bool GlobalToShared = true> struct copy_zfill; -/// Blocks until all but previous cp.async.commit_group operations have committed. +/// Blocks until all but previous cp.async.commit_group operations have +/// committed. /// /// cp.async /// @@ -86,11 +89,11 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy { - /// Copy CUTLASS_DEVICE copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - cp_async(smem_ptr, global_ptr, pred_guard); + cp_async( + smem_ptr, global_ptr, pred_guard); } }; @@ -99,15 +102,15 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy { - /// Copy CUTLASS_DEVICE copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - using AccessType = Array; + using AccessType = Array; - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } + if (pred_guard) { + *static_cast(smem_ptr) = + *static_cast(global_ptr); + } } }; @@ -116,11 +119,11 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy_zfill { - /// Copy with zero fill CUTLASS_DEVICE copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { - cp_async_zfill(smem_ptr, global_ptr, pred_guard); + cp_async_zfill( + smem_ptr, global_ptr, pred_guard); } }; @@ -129,20 +132,19 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy_zfill { - /// Copy with zero fill CUTLASS_DEVICE copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { - using AccessType = Array; + using AccessType = Array; - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } - else { - AccessType zeros; - zeros.clear(); - *static_cast(smem_ptr) = zeros; - } + if (pred_guard) { + *static_cast(smem_ptr) = + *static_cast(global_ptr); + } else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } } }; @@ -153,11 +155,11 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy { - /// Copy CUTLASS_DEVICE copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - cp_async(smem_ptr, global_ptr, pred_guard); + cp_async( + smem_ptr, global_ptr, pred_guard); } }; @@ -166,15 +168,15 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy { - /// Copy CUTLASS_DEVICE copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - using AccessType = Array; + using AccessType = Array; - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } + if (pred_guard) { + *static_cast(smem_ptr) = + *static_cast(global_ptr); + } } }; @@ -183,11 +185,11 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy_zfill { - /// Copy with zero fill CUTLASS_DEVICE copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - cp_async_zfill(smem_ptr, global_ptr, pred_guard); + cp_async_zfill( + smem_ptr, global_ptr, pred_guard); } }; @@ -196,31 +198,29 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy_zfill { - /// Copy with zero fill CUTLASS_DEVICE copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - using AccessType = Array; + using AccessType = Array; - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } - else { - AccessType zeros; - zeros.clear(); - *static_cast(smem_ptr) = zeros; - } + if (pred_guard) { + *static_cast(smem_ptr) = + *static_cast(global_ptr); + } else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } } }; -/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does +/// not block. template -CUTLASS_DEVICE -void copy_fence() {} +CUTLASS_DEVICE void copy_fence() {} template <> -CUTLASS_DEVICE -void copy_fence() { +CUTLASS_DEVICE void copy_fence() { cp_async_fence(); } @@ -229,7 +229,6 @@ void copy_fence() { /// Partial specialization template struct copy_wait { - CUTLASS_DEVICE copy_wait() {} }; @@ -237,7 +236,6 @@ struct copy_wait { /// Partial specialization template struct copy_wait { - CUTLASS_DEVICE copy_wait() { cp_async_wait(); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/arch/mma.h b/custom_ops/gpu_ops/cutlass_extensions/arch/mma.h index 2362da4f7f..2ab2981518 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/arch/mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/arch/mma.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -37,10 +38,8 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace arch -{ +namespace cutlass { +namespace arch { // Tag which triggers MMA which will trigger struct OpMultiplyAddDequantizeInterleavedBToA; @@ -52,8 +51,8 @@ struct OpMultiplyAddDequantizeInterleavedBToA; split out the template below into OpMultiplyAddDequantizeInterleavedBToA along with the quantization op before instantiating the GEMM pieces. - Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of - code we need to duplicate. + Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount + of code we need to duplicate. */ struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; @@ -61,60 +60,59 @@ struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; // The default just forwards the original operator template -struct TagOperator -{ - using TaggedOperator = MmaOp; +struct TagOperator { + using TaggedOperator = MmaOp; }; // Specializations below attach more information to the operator template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; }; template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; }; template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; }; -// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original -// operator + the extra information. If no extra info was tagged, the dequant op per column scaling -// as a default. +// Here we instantiate some structs to "detag" the tagged operator. It splits it +// back to the original operator + the extra information. If no extra info was +// tagged, the dequant op per column scaling as a default. template -struct DetagOperator -{ - using Operator = TaggedMmaOp; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr WeightOnlyQuantOp QuantOp = + WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = + WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = + WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = + WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; }; -} // namespace arch -} // namespace cutlass +} // namespace arch +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/compute_occupancy.h b/custom_ops/gpu_ops/cutlass_extensions/compute_occupancy.h index 29ee976691..7f7cce414e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/compute_occupancy.h +++ b/custom_ops/gpu_ops/cutlass_extensions/compute_occupancy.h @@ -20,66 +20,65 @@ #include "cutlass/device_kernel.h" #include "common/cudaUtils.h" -namespace cutlass_extensions -{ +namespace cutlass_extensions { template -inline int compute_occupancy_for_kernel() -{ +inline int compute_occupancy_for_kernel() { + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size > (48 << 10)) - { - cudaFuncAttributes attr; - int device = 0; - int max_smem_per_block = 0; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); - PADDLE_ENFORCE_GPU_SUCCESS( - cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); - if constexpr (enable_cutlass_3x) - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); - } - else - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncGetAttributes(&attr, cutlass::Kernel)); - } - if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) - { - // This should mean that - // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) - // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this - // configuration. - return 0; - } - - if constexpr (enable_cutlass_3x) - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncSetAttribute( - cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - else - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncSetAttribute( - cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } + if (smem_size > (48 << 10)) { + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + if constexpr (enable_cutlass_3x) { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaFuncGetAttributes(&attr, cutlass::device_kernel)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaFuncGetAttributes(&attr, cutlass::Kernel)); + } + if (smem_size + attr.sharedSizeBytes >= + static_cast(max_smem_per_block)) { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, + // cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) wouldn't work. + // In that case, we return an occupancy of 0. This will cause the + // heuristic to ignore this configuration. + return 0; } - int max_active_blocks = -1; - if constexpr (enable_cutlass_3x) - { - PADDLE_ENFORCE_GPU_SUCCESS( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, - 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); - } - else - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); + if constexpr (enable_cutlass_3x) { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaFuncSetAttribute(cutlass::device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); } + } - return max_active_blocks; + int max_active_blocks = -1; + if constexpr (enable_cutlass_3x) { + PADDLE_ENFORCE_GPU_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + cutlass::device_kernel, + 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), + smem_size)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + cutlass::Kernel, + GemmKernel::kThreadCount, + smem_size)); + } + + return max_active_blocks; } -} // namespace cutlass_extensions +} // namespace cutlass_extensions diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp index 3e5aa4b038..60c0dc6302 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -47,7 +47,8 @@ // breaks when moving scales to the CPU. // -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp #pragma once diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp index fa1df1fb1e..dbc6a718f3 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -47,7 +47,8 @@ // breaks when moving scales to the CPU. // -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp #pragma once diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index 7b56c3c1ac..aee2fbaeb2 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -47,7 +47,8 @@ // breaks when moving scales to the CPU. // -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp #pragma once diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp index 513d3741fb..42db1bd158 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp #pragma once @@ -24,31 +25,41 @@ using namespace cute; */ template struct ScaledEpilogueBase { -protected: + protected: using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; template using ColOrScalarLoad = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + OutputTileThreadMap, + T, + Stride, Int<0>, Int<0>>>; template using RowOrScalarLoad = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + OutputTileThreadMap, + T, + Stride, Int<1>, Int<0>>>; template using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + OutputTileThreadMap, + T, + Stride, Int<0>, Int<0>>>; template using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + OutputTileThreadMap, + T, + Stride, Int<1>, Int<0>>>; template using RowOrZeroLoad = cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + OutputTileThreadMap, + T, + Stride, Int<1>, Int<0>>>; // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or @@ -56,15 +67,11 @@ protected: template static auto args_from_tensor(paddle::Tensor const &tensor) { using Arguments = typename Descriptor::Arguments; - auto *data_ptr = static_cast(const_cast( - tensor.data())); - if constexpr (std::is_same_v> || - std::is_same_v>) { + auto *data_ptr = static_cast(const_cast(tensor.data())); + if constexpr (std::is_same_v> || + std::is_same_v>) { return Arguments{data_ptr, tensor.numel() != 1}; - } - else { + } else { // it would technically work but no use case as data_ptr is never nullptr static_assert(!std::is_same_v>); return Arguments{data_ptr}; @@ -102,24 +109,28 @@ protected: template struct ScaledEpilogue : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT; using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, + cutlass::multiplies, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; @@ -146,26 +157,30 @@ public: template struct ScaledEpilogueBias : protected ScaledEpilogueBase { -protected: + protected: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; using Bias = typename SUPER::template RowLoad; using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT; using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + public: + using EVTCompute = cutlass::epilogue::threadblock:: + Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; static ArgumentType prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, @@ -190,7 +205,7 @@ public: template struct ScaledEpilogueBiasAzp : protected ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -202,35 +217,40 @@ private: // Compute float(accum - azp_adj), both operands are int32_t using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, + cutlass::minus, + float, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAzp = cutlass::epilogue::threadblock::Sm80EVT; using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; + using EVTComputeScaleB = cutlass::epilogue::threadblock:: + Sm80EVT; using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; + public: + using EVTCompute = cutlass::epilogue::threadblock:: + Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType - prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, - paddle::Tensor const &azp_adj, - paddle::optional const &bias) { + static ArgumentType prepare_args( + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -257,7 +277,7 @@ public: template struct ScaledEpilogueBiasAzpToken : protected ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -272,7 +292,9 @@ private: // Compute azp * azp_adj using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, int32_t, int32_t, + cutlass::multiplies, + int32_t, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAzp = @@ -280,35 +302,41 @@ private: // Compute float(accum - azp*azp_adj), all operands are int32_t using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, + cutlass::minus, + float, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAcc = cutlass::epilogue::threadblock::Sm80EVT; using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; + using EVTComputeScaleB = cutlass::epilogue::threadblock:: + Sm80EVT; using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; + public: + using EVTCompute = cutlass::epilogue::threadblock:: + Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType - prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, - paddle::Tensor const &azp_adj, paddle::Tensor const &azp, - paddle::optional const &bias) { + static ArgumentType prepare_args( + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::Tensor const &azp, + paddle::optional const &bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -324,4 +352,4 @@ public: } }; -}; // namespace fastdeploy::c2x +}; // namespace fastdeploy::c2x diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 90a5e342b5..abb73ce84c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp #pragma once @@ -24,24 +25,28 @@ namespace fastdeploy::c3x { using namespace cute; -template struct identity { +template +struct identity { CUTLASS_HOST_DEVICE T operator()(T lhs) const { return lhs; } }; template struct TrivialEpilogue { -private: + private: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using Compute = cutlass::epilogue::fusion::Sm90Compute< - cutlass::epilogue::thread::Identity, ElementD, ElementAcc, + cutlass::epilogue::thread::Identity, + ElementD, + ElementAcc, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - template static ArgumentType prepare_args(Args... args) { + template + static ArgumentType prepare_args(Args... args) { return {}; } }; @@ -52,38 +57,60 @@ public: */ template struct ScaledEpilogueBase { -protected: + protected: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; template using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + 0 /*Stages*/, + TileShape, + T, + Stride, Int<0>, Int<0>>>; template using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, + TileShape, + T, + Stride, Int<1>, Int<0>>>; // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>, - 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, + TileShape, + T, + T, + Stride, Int<0>, Int<0>>, + 128 / sizeof_bits_v, + EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, - 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, + TileShape, + T, + T, + Stride, Int<1>, Int<0>>, + 128 / sizeof_bits_v, + EnableNullPtr>; template using ColOrScalarLoadArray = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< - 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + 0 /*Stages*/, + TileShape, + T, + Stride, Int<0>, Int<0>>>; template using RowOrScalarLoadArray = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< - 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, + TileShape, + T, + Stride, Int<1>, Int<0>>>; // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or @@ -142,24 +169,28 @@ protected: template struct ScaledEpilogue : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, + cutlass::multiplies, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; @@ -186,7 +217,7 @@ public: template struct ScaledEpilogueBias : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -194,17 +225,21 @@ private: using Bias = typename SUPER::template RowLoad; using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; @@ -229,7 +264,7 @@ public: template struct ScaledEpilogueColumnBias : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -237,17 +272,21 @@ private: using Bias = typename SUPER::template ColLoad; using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; @@ -275,7 +314,7 @@ public: template struct ScaledEpilogueBiasAzp : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -287,33 +326,39 @@ private: // Compute float(accum - azp_adj), both operands are int32_t using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, + cutlass::minus, + float, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAzp = cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeScaleB = cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; + public: + using EVTCompute = cutlass::epilogue::fusion:: + Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType - prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, - paddle::Tensor const &azp_adj, - paddle::optional const &bias) { + static ArgumentType prepare_args( + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -340,7 +385,7 @@ public: template struct ScaledEpilogueBiasAzpToken : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -355,7 +400,9 @@ private: // Compute azp * azp_adj using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, int32_t, int32_t, + cutlass::multiplies, + int32_t, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAzp = @@ -363,33 +410,40 @@ private: // Compute float(accum - azp*azp_adj), all operands are int32_t using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, + cutlass::minus, + float, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAcc = cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeScaleB = cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; + public: + using EVTCompute = cutlass::epilogue::fusion:: + Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType - prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, - paddle::Tensor const &azp_adj, paddle::Tensor const &azp, - paddle::optional const &bias) { + static ArgumentType prepare_args( + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::Tensor const &azp, + paddle::optional const &bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -414,24 +468,28 @@ public: template struct ScaledEpilogueArray : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoadArray; using ScaleB = typename SUPER::template RowOrScalarLoadArray; using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, + cutlass::multiplies, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; @@ -441,7 +499,8 @@ public: static ArgumentType prepare_args(float const *const *a_scales_ptr, float const *const *b_scales_ptr, - bool a_col_broadcast, bool b_row_broadcast) { + bool a_col_broadcast, + bool b_row_broadcast) { auto a_args = SUPER::template args_from_tensor( a_scales_ptr, a_col_broadcast); auto b_args = SUPER::template args_from_tensor( @@ -452,4 +511,4 @@ public: } }; -}; // namespace fastdeploy::c3x +}; // namespace fastdeploy::c3x diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/thread/fused_activations.h b/custom_ops/gpu_ops/cutlass_extensions/epilogue/thread/fused_activations.h index f3c622b88a..0aef590a10 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/thread/fused_activations.h +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/thread/fused_activations.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Functor performing linear combination with a maximum operation used by epilogues. + \brief Functor performing linear combination with a maximum operation used by + epilogues. */ #pragma once @@ -46,60 +48,53 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace epilogue -{ -namespace thread -{ +namespace cutlass { +namespace epilogue { +namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// -__forceinline__ __device__ float copysignf_pos(float a, float b) -{ - float r; - r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); - return r; +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; } -__forceinline__ __device__ float tanh_opt(float x) -{ +__forceinline__ __device__ float tanh_opt(float x) { #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) - float const exp_val = -1.f * fabs(2 * x); - return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); #else - return fast_tanh(x); + return fast_tanh(x); #endif } ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -struct GELU_taylor -{ - static bool const kIsHeavy = true; +struct GELU_taylor { + static bool const kIsHeavy = true; - CUTLASS_DEVICE - float operator()(float const& z) const - { + CUTLASS_DEVICE + float operator()(float const& z) const { + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); - float k0 = float(0.7978845608028654); - float k1 = float(0.044715); + return float( + cutlass::constants::half() * z * + (cutlass::constants::one() + + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } - return float(cutlass::constants::half() * z - * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } + using Params = LinearCombinationGenericParams; - using Params = LinearCombinationGenericParams; - - CUTLASS_DEVICE - float operator()(float const& scalar, Params const& params_) const - { - return this->operator()(scalar); - } + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const { + return this->operator()(scalar); + } }; -} // namespace thread -} // namespace epilogue -} // namespace cutlass +} // namespace thread +} // namespace epilogue +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h index aeec5e5d0b..031f3e7048 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,20 +18,23 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one + scaling factor per row, and one per column. - original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + original file: + 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h */ @@ -46,305 +49,312 @@ #include "cutlass/numeric_conversion.h" #include "common/quantization.h" -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ +namespace cutlass { +namespace epilogue { +namespace threadblock { -template -class EpilogueVisitorPerRowPerCol -{ -public: - using ThreadblockShape = ThreadblockShape_; - static int const kThreadCount = ThreadCount; +template +class EpilogueVisitorPerRowPerCol { + public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; - using ScaleTileIterator = ScaleTileIterator_; - using OutputTileIterator = OutputTileIterator_; - using ElementwiseFunctor = ElementwiseFunctor_; + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; - static int const kIterations = OutputTileIterator::kIterations; - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - using ElementOutput = typename OutputTileIterator::Element; - using LayoutOutput = cutlass::layout::RowMajor; - using ElementAccumulator = ElementAccumulator_; + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; - using AlphaScaleElementType = typename ScaleTileIterator::Element; + using AlphaScaleElementType = typename ScaleTileIterator::Element; - using ElementCompute = ElementCompute_; - using AccumulatorFragment = Array; - using ComputeFragment = Array; - using OutputVector = Array; + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; - static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; - static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + static int const kThreadsPerRow = + OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = + (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); - /// Argument structure - struct Arguments - { + /// Argument structure + struct Arguments { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; + // + // Methods + // + Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} - // - // Methods - // - Arguments() - : batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } + Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_), + batch_stride_alpha(0), + batch_stride_C(0), + batch_stride_D(0) {} - Arguments(typename ElementwiseFunctor::Params elementwise_) - : elementwise(elementwise_) - , batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } + Arguments(typename ElementwiseFunctor::Params elementwise_, + int64_t batch_stride_alpha_, + int64_t batch_stride_C_, + int64_t batch_stride_D_) + : elementwise(elementwise_), + batch_stride_alpha(batch_stride_alpha_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_) {} + }; - Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, - int64_t batch_stride_C_, int64_t batch_stride_D_) - : elementwise(elementwise_) - , batch_stride_alpha(batch_stride_alpha_) - , batch_stride_C(batch_stride_C_) - , batch_stride_D(batch_stride_D_) - { - } - }; + struct Params { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; - struct Params - { + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; + CUTLASS_HOST_DEVICE + Params(Arguments const& args) + : elementwise(args.elementwise), + batch_stride_alpha(args.batch_stride_alpha), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D) {} + }; - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() {} + /// Shared storage + struct SharedStorage {}; - CUTLASS_HOST_DEVICE - Params(Arguments const& args) - : elementwise(args.elementwise) - , batch_stride_alpha(args.batch_stride_alpha) - , batch_stride_C(args.batch_stride_C) - , batch_stride_D(args.batch_stride_D) - { - } - }; + private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; - /// Shared storage - struct SharedStorage - { - }; + bool const per_token_quant_; + bool const per_channel_quant_; -private: - Params const& params_; - SharedStorage& shared_storage_; - MatrixCoord extent_; - MatrixCoord extent_real_; - ElementwiseFunctor elementwise_; + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; - bool const per_token_quant_; - bool const per_channel_quant_; + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; - AlphaScaleElementType* ptr_alpha_row_; - AlphaScaleElementType* ptr_alpha_col_; - ScaleTileIterator iterator_alpha_col_; - OutputTileIterator iterator_C_; - OutputTileIterator iterator_D_; + ElementAccumulator beta_; - AlphaScaleElementType element_alpha_row_ = 1.0f; - AlphaScaleElementType element_alpha_col_ = 1.0f; - typename ScaleTileIterator::Fragment fragment_alpha_col_; - typename OutputTileIterator::Fragment fragment_C_; - typename OutputTileIterator::Fragment fragment_D_; + int column_offset_; - ElementAccumulator beta_; + MatrixCoord thread_offset_; - int column_offset_; + public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol( + Params const& params, + SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, + int thread_idx, + int warp_idx, + int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, + common::QuantMode quant_option, + AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, + typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, + 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, + 0)) + : params_(params), + shared_storage_(shared_storage), + extent_(problem_size), + elementwise_(params.elementwise), + per_token_quant_(quant_option.hasPerTokenScaling()), + per_channel_quant_(quant_option.hasPerChannelScaling()), + ptr_alpha_row_(ptr_alpha_row), + ptr_alpha_col_(ptr_alpha_col), + iterator_alpha_col_(params_alpha_col, + ptr_alpha_col, + problem_size, + thread_idx, + threadblock_offset), + iterator_C_( + params_C, ptr_C, problem_size, thread_idx, threadblock_offset), + iterator_D_( + params_D, ptr_D, problem_size, thread_idx, threadblock_offset), + extent_real_(problem_size_real) { + beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr + : params.elementwise.beta); - MatrixCoord thread_offset_; - -public: - CUTLASS_DEVICE - EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, - cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, - typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, - typename OutputTileIterator::Params params_D, common::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, - AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, - typename OutputTileIterator::Element* ptr_D, - cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, - cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) - : params_(params) - , shared_storage_(shared_storage) - , extent_(problem_size) - , elementwise_(params.elementwise) - , per_token_quant_(quant_option.hasPerTokenScaling()) - , per_channel_quant_(quant_option.hasPerChannelScaling()) - , ptr_alpha_row_(ptr_alpha_row) - , ptr_alpha_col_(ptr_alpha_col) - , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset) - , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset) - , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) - , extent_real_(problem_size_real) - { - beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); - - if (beta_ == ElementAccumulator()) - { - iterator_C_.clear_mask(); - } - - if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) - { - element_alpha_col_ = *ptr_alpha_col_; - } - - if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) - { - element_alpha_row_ = *ptr_alpha_row_; - } + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); } - /// Helper to indicate split-K behavior - CUTLASS_DEVICE - void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) - { ///< Total number of split-K slices + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) { + element_alpha_col_ = *ptr_alpha_col_; } - /// Called to set the batch index - CUTLASS_DEVICE - void set_batch_index(int batch_idx) - { - iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); - iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); - iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) { + element_alpha_row_ = *ptr_alpha_row_; + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K + ///< partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_alpha_col_.add_pointer_offset(batch_idx * + params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator + /// slices + CUTLASS_DEVICE + void begin_epilogue() { + if (per_channel_quant_) { + iterator_alpha_col_.load(fragment_alpha_col_); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) { + int thread_offset_row = + iterator_D_.thread_start_row() + + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, + ptr_alpha_row_ + thread_offset_row, + thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorFragment const& accum) { + NumericArrayConverter + source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) { + ComputeFragment alpha_col = + reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_( + result, alpha_col, element_alpha_row_); + } else { + result = per_token_scale_accumulator_( + result, element_alpha_col_, element_alpha_row_); } - /// Called at the start of the epilogue just before iterating over accumulator slices - CUTLASS_DEVICE - void begin_epilogue() - { - if (per_channel_quant_) - { - iterator_alpha_col_.load(fragment_alpha_col_); - } + // Convert to the output + NumericArrayConverter + output_converter; + OutputVector& output = + reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + + private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, + ComputeFragment const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col[i] * scale_row); } - /// Called at the start of one step before starting accumulator exchange - CUTLASS_DEVICE - void begin_step(int step_idx) - { - fragment_D_.clear(); - fragment_C_.clear(); + return result; + } - if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) - { - iterator_C_.load(fragment_C_); - ++iterator_C_; - } + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, + AlphaScaleElementType const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col * scale_row); } - /// Called at the start of a row - CUTLASS_DEVICE - void begin_row(int row_idx) - { - // load alpha_row in begin_step only when per token(row) scaling is used - if (per_token_quant_) - { - int thread_offset_row - = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); - - arch::global_load( - element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); - } - } - - /// Called after accumulators have been exchanged for each accumulator vector - CUTLASS_DEVICE - void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) - { - - NumericArrayConverter source_converter; - - ComputeFragment result = source_converter(accum); - if (per_channel_quant_) - { - ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; - result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); - } - else - { - result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); - } - - // Convert to the output - NumericArrayConverter output_converter; - OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; - output = output_converter(result); - } - - /// Called at the end of a row - CUTLASS_DEVICE - void end_row(int row_idx) {} - - /// Called after all accumulator elements have been visited - CUTLASS_DEVICE - void end_step(int step_idx) - { - - iterator_D_.store(fragment_D_); - ++iterator_D_; - } - - /// Called after all steps have been completed - CUTLASS_DEVICE - void end_epilogue() {} - -private: - CUTLASS_DEVICE - ComputeFragment per_token_channel_scale_accumulator_( - ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col[i] * scale_row); - } - - return result; - } - - CUTLASS_DEVICE - ComputeFragment per_token_scale_accumulator_( - ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col * scale_row); - } - - return result; - } + return result; + } }; -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h index 6f26d79017..2f89cb3f21 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,23 +18,26 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. - original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + original file: + 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h */ @@ -80,35 +83,45 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ +namespace cutlass { +namespace epilogue { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -namespace detail -{ +namespace detail { -/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. -template -struct DefaultIteratorsTensorOp -{ - using WarpTileIterator - = cutlass::epilogue::warp::TileIteratorTensorOpMixed; +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared +/// memory bank conflicts. +template +struct DefaultIteratorsTensorOp { + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOpMixed; - using SharedLoadIterator - = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; + using SharedLoadIterator = cutlass::epilogue::threadblock:: + SharedLoadIteratorMixed; - static int const kFragmentsPerIteration = 2; + static int const kFragmentsPerIteration = 2; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace detail +} // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -116,167 +129,159 @@ struct DefaultIteratorsTensorOp -class SharedLoadIteratorMixed -{ -public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; +template +class SharedLoadIteratorMixed { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; - using Element = int32_t; + using Element = int32_t; - using Layout = layout::RowMajor; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + static int const kAlignment = + ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; - static int const kThreads = ThreadMap::kThreads; + static int const kThreads = ThreadMap::kThreads; - /// Fragment object - using Fragment = Array; + /// Fragment object + using Fragment = + Array; - /// Memory access size - using AccessType = AlignedArray; + /// Memory access size + using AccessType = + AlignedArray; - /// Vector type used for SMEM loads - using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), - const_min(16, kAlignment)>; + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, + ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; - static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + static int const kLoadsPerAccess = + AccessType::kElements / LoadType::kElements; -private: - // - // Data members - // + private: + // + // Data members + // - /// Byte-level pointer - LoadType const* pointers_[kLoadsPerAccess]; + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; - /// Stride along adjacent rows in units of LoadType - int stride_; + /// Stride along adjacent rows in units of LoadType + int stride_; -public: - // - // Methods - // + public: + // + // Methods + // - /// Constructor - CUTLASS_DEVICE - SharedLoadIteratorMixed(TensorRef ref, int thread_idx) - : stride_((ref.stride(0) / LoadType::kElements)) - { + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) + : stride_((ref.stride(0) / LoadType::kElements)) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] = reinterpret_cast(ref.data()); - // Initialize pointers - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] = reinterpret_cast(ref.data()); + int col_idx = + (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = + (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; - int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; - int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; + col_idx += (bank_offset + i) % kLoadsPerAccess; - col_idx += (bank_offset + i) % kLoadsPerAccess; - - pointers_[i] += thread_offset.row() * stride_ + col_idx; - } + pointers_[i] += thread_offset.row() * stride_ + col_idx; } + } - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] += pointer_offset / LoadType::kElements; - } + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; } + } - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] - += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; - } + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += offset.row() * Shape::kRow * stride_ + + offset.column() * Shape::kColumn / LoadType::kElements; } + } - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const - { - + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) - { + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup * stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = + frag_row_idx * ThreadMap::Iterations::kColumn + column; CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) - { + for (int v = 0; v < kLoadsPerAccess; ++v) { + int vector_idx = (column * ThreadMap::Delta::kColumn / + kElementsPerAccess * kLoadsPerAccess); - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) - { + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; - int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ - + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ - + pointer_offset / LoadType::kElements; - - int frag_row_idx - = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - LoadType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) - { - - int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kLoadsPerAccess; ++v) - { - - int vector_idx - = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); - - LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; - - frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; - } - } - } + frag_ptr[frag_idx * kLoadsPerAccess + v] = + memory_pointer[vector_idx]; } + } } + } } + } - /// Loads a fragment - CUTLASS_DEVICE - void load(Fragment& frag) const - { - - load_with_pointer_offset(frag, 0); - } + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue_helpers.h b/custom_ops/gpu_ops/cutlass_extensions/epilogue_helpers.h index 6ed5b9b920..ca8209d59c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue_helpers.h +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue_helpers.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,10 @@ /** * @file epilogue_helpers.h * - * This file includes types for the epilogues. The empty structs exist so we can signal to template - * code the type of epilogue we want to run, and let the underlying code specify the details such as - * element types, accumulator type and elements per vector access. + * This file includes types for the epilogues. The empty structs exist so we can + * signal to template code the type of epilogue we want to run, and let the + * underlying code specify the details such as element types, accumulator type + * and elements per vector access. * */ @@ -33,107 +34,161 @@ // #include "cutlass/epilogue/fusion/operations.hpp" -namespace cutlass_extensions -{ +namespace cutlass_extensions { -struct EpilogueOpBiasSilu -{ +struct EpilogueOpBiasSilu {}; + +struct EpilogueOpBiasReLU {}; + +struct EpilogueOpBiasFtGelu {}; + +struct EpilogueOpBias {}; + +struct EpilogueOpDefaultSilu {}; + +struct EpilogueOpDefaultReLU {}; + +struct EpilogueOpDefaultFtGelu {}; + +struct EpilogueOpDefault {}; + +template +struct Epilogue { + static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); }; -struct EpilogueOpBiasReLU -{ +constexpr auto BiasScaleMode = + cutlass::epilogue::thread::ScaleType::NoBetaScaling; + +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombinationSilu; }; -struct EpilogueOpBiasFtGelu -{ +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombinationRelu; }; -struct EpilogueOpBias -{ +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + BiasScaleMode, + cutlass::FloatRoundStyle::round_to_nearest, + true>; }; -struct EpilogueOpDefaultSilu -{ -}; - -struct EpilogueOpDefaultReLU -{ -}; - -struct EpilogueOpDefaultFtGelu -{ -}; - -struct EpilogueOpDefault -{ -}; - -template -struct Epilogue -{ - static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); -}; - -constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationSilu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationRelu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationGeneric; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombination; +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombination; }; constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationSilu; +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombinationSilu; }; -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationRelu; +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombinationRelu; }; -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + DefaultScaleMode, + cutlass::FloatRoundStyle::round_to_nearest, + true>; }; -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombination; +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombination; }; -} // namespace cutlass_extensions +} // namespace cutlass_extensions diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder.hpp index d327eb18ae..5d5b99b5b2 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder.hpp @@ -21,7 +21,6 @@ #include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { @@ -29,19 +28,17 @@ namespace cutlass::gemm::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// // GMMA_TMA_WS_SS (BlockScaled Builders) -template < - class ElementA, - class GmemLayoutATag, - int AlignmentA, - class ElementB, - class GmemLayoutBTag, - int AlignmentB, - class ElementAccumulator, - class TileShape_MNK, - class ClusterShape_MNK, - class StageCountType, - int ScaleGranularityM -> +template struct CollectiveBuilder< arch::Sm90, arch::OpClassTensorOp, @@ -55,82 +52,124 @@ struct CollectiveBuilder< TileShape_MNK, ClusterShape_MNK, StageCountType, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum, - cute::enable_if_t< - not detail::is_use_rmem_A()> -> { - using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>, + cute::enable_if_t()>> { + using KernelScheduleType = + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>; static_assert(is_static::value); static_assert(is_static::value); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, + "Unsupported Toolkit for SM90 Collective Builder\n"); #endif - static_assert(detail::is_aligned(), + static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); - static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); + static constexpr bool IsArrayOfPointersGemm = + (cute::is_any_of_v); static constexpr bool IsFP8Input = detail::is_input_fp8(); static_assert((!IsFP8Input || !IsArrayOfPointersGemm), - "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."); + "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is " + "only compatible with FP8 Blocked Scaled version right now."); // For fp32 types, map to tf32 MMA value type - using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; - using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + using ElementAMma = cute:: + conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute:: + conditional_t, tfloat32_t, ElementB>; - static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + static constexpr cute::GMMA::Major GmmaMajorA = + detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = + detail::gmma_ss_tag_to_major_B(); - static constexpr bool IsCooperative = cute::is_any_of_v>; + static constexpr bool IsCooperative = cute::is_any_of_v< + KernelScheduleType, + KernelTmaWarpSpecializedCooperative, + KernelPtrArrayTmaWarpSpecializedCooperative, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>>; using AtomLayoutMNK = cute::conditional_t>, Layout>>; + Layout>, + Layout>>; - using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< - ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); - using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom( + shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom( + shape<0>(ClusterShape_MNK{}))); - using SmemLayoutAtomA = decltype(detail::ss_smem_selector< - GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomB = decltype(detail::ss_smem_selector< - GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomA = + decltype(detail::ss_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>( + TileShape_MNK{}))>()); + using SmemLayoutAtomB = + decltype(detail::ss_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>( + TileShape_MNK{}))>()); - static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + static constexpr size_t TensorMapStorage = + IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ + : 0; static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + static constexpr int PipelineStages = + detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = + MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8< + PipelineStages, + ClusterShape_MNK, + KernelScheduleType, + ScaleGranularityM>; using SmemCopyAtomA = void; using SmemCopyAtomB = void; - using CollectiveOp = CollectiveMma< - DispatchPolicy, - TileShape_MNK, - ElementA, - TagToStrideA_t, - ElementB, - TagToStrideB_t, - TiledMma, - GmemTiledCopyA, - SmemLayoutAtomA, - SmemCopyAtomA, - cute::identity, - GmemTiledCopyB, - SmemLayoutAtomB, - SmemCopyAtomB, - cute::identity - >; + using CollectiveOp = CollectiveMma, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity>; }; - ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm::collective diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp index 227aee50fe..e3f7e03ef9 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp @@ -39,12 +39,23 @@ namespace cutlass::gemm::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// -template class Activation, - bool SwapAB = false, class Enable = void> +template + class Activation, + bool SwapAB = false, + class Enable = void> struct CollectiveBuilderGated { static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); @@ -52,7 +63,7 @@ struct CollectiveBuilderGated { ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp index 56849ee56f..d2d06fed31 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp @@ -39,12 +39,23 @@ namespace cutlass::gemm::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// -template class Activation, + template + class Activation, bool SwapAB = false> struct CollectiveMmaGated { static_assert(cutlass::detail::dependent_false, @@ -53,7 +64,7 @@ struct CollectiveMmaGated { ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/fp8_accumulation.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/fp8_accumulation.hpp index 0a530e5c14..b22492fd9e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/fp8_accumulation.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/fp8_accumulation.hpp @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp +// adapted from: +// https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -34,14 +35,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -61,24 +63,26 @@ namespace cutlass::gemm::collective { -template < - class EngineAccum, - class LayoutAccum> +template struct GmmaFP8AccumulationWithScale { using TensorAccum = cute::Tensor; using ElementAccumulator = typename EngineAccum::value_type; - static_assert(is_static::value, "Accumulator Layout should be static"); - static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); + static_assert(is_static::value, + "Accumulator Layout should be static"); + static_assert(is_rmem::value, + "Accumulator tensor must be rmem resident."); -private: - TensorAccum& accum_; + private: + TensorAccum &accum_; TensorAccum accum_temp_; - uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. - uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop - uint32_t mma_count_; // current executed MMAs - uint32_t reset_accum_flag_; // accum needs to be zeroed or not. + uint32_t accum_promotion_interval_; // defines the max num of executed MMAs + // after which accum should be promoted. + uint32_t + mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop + uint32_t mma_count_; // current executed MMAs + uint32_t reset_accum_flag_; // accum needs to be zeroed or not. // promote or `add` the partial accumulators to main accumulator (FADD). CUTLASS_DEVICE @@ -90,18 +94,20 @@ private: } } - // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_core(const cute::Tensor &scale) { + // `multiply` scale the partial accumulators and `add` to main accumulator + // (FFMA). + template + CUTLASS_DEVICE void scale_core( + const cute::Tensor &scale) { using TensorScale = cute::Tensor; - static_assert(is_static::value, "Scale Layout should be static"); - static_assert(is_rmem::value , "Scale tensor must be rmem resident."); + static_assert(is_static::value, + "Scale Layout should be static"); + static_assert(is_rmem::value, + "Scale tensor must be rmem resident."); - static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); + static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), + "Accumulator and scale must have same shape."); warpgroup_wait<0>(); CUTLASS_PRAGMA_UNROLL @@ -110,18 +116,16 @@ private: } } -public: + public: CUTLASS_DEVICE - GmmaFP8AccumulationWithScale( - TensorAccum &accum, - uint32_t accum_promotion_interval, - uint32_t mma_count_per_mainloop_iteration) + GmmaFP8AccumulationWithScale(TensorAccum &accum, + uint32_t accum_promotion_interval, + uint32_t mma_count_per_mainloop_iteration) : accum_(accum), accum_promotion_interval_(accum_promotion_interval), mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), mma_count_(0), - reset_accum_flag_(0) - { + reset_accum_flag_(0) { accum_temp_ = cute::make_fragment_like(accum); } @@ -130,32 +134,31 @@ public: // CUTLASS_DEVICE - TensorAccum& operator()() { - return accum_temp_; - } + TensorAccum &operator()() { return accum_temp_; } /// prepare the MMA accumulators when initialization or zeroing is required. CUTLASS_DEVICE - bool prepare_if_needed() { - return reset_accum_flag_; - } + bool prepare_if_needed() { return reset_accum_flag_; } // // Methods (for FADD version) // - /// promote (add) the results from the MMA accumulators to main accumulator if needed. + /// promote (add) the results from the MMA accumulators to main accumulator if + /// needed. CUTLASS_DEVICE void promote_if_needed() { mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + reset_accum_flag_ = + __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); if (reset_accum_flag_) { promote_core(); mma_count_ = 0; } } - /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. + /// promote (add) the residue results from the MMA accumulators to main + /// accumulator if needed. CUTLASS_DEVICE void promote_residue_if_needed() { if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { @@ -167,30 +170,29 @@ public: // Methods (for FFMA version) // - /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_if_needed(const cute::Tensor &scale) { + /// scale (multiply_add) the results from the MMA accumulators to main + /// accumulator if needed. + template + CUTLASS_DEVICE void scale_if_needed( + const cute::Tensor &scale) { mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + reset_accum_flag_ = + __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); if (reset_accum_flag_) { scale_core(scale); mma_count_ = 0; } } - /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_residue_if_needed(const cute::Tensor &scale) { + /// scale (multiply_add) the residue results from the MMA accumulators to main + /// accumulator if needed. + template + CUTLASS_DEVICE void scale_residue_if_needed( + const cute::Tensor &scale) { if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { scale_core(scale); } } }; -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp index 8ff14a2a49..f335ec2d39 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp @@ -53,18 +53,43 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// // WarpSpecialized Mainloop -template class Activation_, bool SwapAB_> + template + class Activation_, + bool SwapAB_> struct CollectiveMmaGated< MainloopSm90TmaGmmaWarpSpecialized, - TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, - GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, - GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_, + Activation_, SwapAB_> { static constexpr bool isGated = true; static constexpr bool SwapAB = SwapAB_; @@ -93,7 +118,8 @@ struct CollectiveMmaGated< using Activation = Activation_; using ElementAux = cute::conditional_t; - using ValTypeAux = cute::conditional_t; using MainloopPipeline = cutlass::PipelineTmaAsync; @@ -118,16 +144,20 @@ struct CollectiveMmaGated< // Tile along modes in a way that maximizes the TMA box size. using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), + make_shape(shape<0>(TileShape{}), + shape<2>(TileShape{}), Int{}), conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + make_shape(shape<1>(TileShape{}), + shape<2>(TileShape{}), Int{}), conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutAux = cute::conditional_t; static_assert(DispatchPolicy::Stages >= 2, @@ -151,10 +181,12 @@ struct CollectiveMmaGated< static constexpr bool ConvertF32toTF32A = cute::is_same_v; static constexpr bool ConvertF32toTF32B = cute::is_same_v; using InternalElementA = - cute::conditional_t>>; using InternalElementB = - cute::conditional_t>>; using InternalElementAux = cute::conditional_t; @@ -195,18 +227,22 @@ struct CollectiveMmaGated< using TMA_A = decltype(make_tma_copy( GmemTiledCopyA{}, make_tensor(static_cast(nullptr), - repeat_like(StrideA{}, int32_t(0)), StrideA{}), + repeat_like(StrideA{}, int32_t(0)), + StrideA{}), SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + size<1>( + ClusterShape{}))); // mcast along N mode for this M load, if any // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, make_tensor(static_cast(nullptr), - repeat_like(StrideB{}, int32_t(0)), StrideB{}), + repeat_like(StrideB{}, int32_t(0)), + StrideB{}), SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + size<0>( + ClusterShape{}))); // mcast along M mode for this N load, if any using TMA_Aux = cute::conditional_t; TMA_A tma_load_a; TMA_B tma_load_b; @@ -220,9 +256,10 @@ struct CollectiveMmaGated< // template - static constexpr Params - to_underlying_arguments(ProblemShape const &problem_shape, - Arguments const &args, void *workspace) { + static constexpr Params to_underlying_arguments( + ProblemShape const &problem_shape, + Arguments const &args, + void *workspace) { (void)workspace; // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is @@ -238,36 +275,44 @@ struct CollectiveMmaGated< Tensor tensor_b = make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB)); typename Params::TMA_A tma_load_a = make_tma_copy( - GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any typename Params::TMA_B tma_load_b = make_tma_copy( - GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any if constexpr (SwapAB) { auto ptr_Aux = reinterpret_cast(args.ptr_B1); Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); typename Params::TMA_Aux tma_load_aux = make_tma_copy( - GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}), + GmemTiledCopyA{}, + tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>( - ClusterShape{})); // mcast along N mode for this M load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, - args.scale_d1}; + ClusterShape{})); // mcast along N mode for this M load, if any + return { + tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; } else { auto ptr_Aux = reinterpret_cast(args.ptr_B1); Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); typename Params::TMA_Aux tma_load_aux = make_tma_copy( - GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}), + GmemTiledCopyB{}, + tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>( - ClusterShape{})); // mcast along M mode for this N load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, - args.scale_d1}; + ClusterShape{})); // mcast along M mode for this N load, if any + return { + tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; } } @@ -293,8 +338,9 @@ struct CollectiveMmaGated< cute::make_shape(N, K, L), StrideB{}); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the " - "minimum alignment requirements for TMA.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the " + "minimum alignment requirements for TMA.\n"); } return implementable; } @@ -342,49 +388,64 @@ struct CollectiveMmaGated< // TMA requires special handling of strides to deal with coord codomain // mapping Represent the full tensors -- get these from TMA Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor( - make_shape(M, K, L)); // (m,k,l) + make_shape(M, K, L)); // (m,k,l) Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor( - make_shape(N, K, L)); // (n,k,l) + make_shape(N, K, L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), - Step{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, + TileShape{}, + make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) if constexpr (SwapAB) { Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( - make_shape(M, K, L)); // (m,k,l) - Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); } else { Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( - make_shape(N, K, L)); // (n,k,l) - Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), - Step{}); // (BLK_N,BLK_K,n,k,l) + make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, + TileShape{}, + make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); } } /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective - template - CUTLASS_DEVICE void - load(Params const &mainloop_params, MainloopPipeline pipeline, - PipelineState smem_pipe_write, - cute::tuple const &load_inputs, - BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage &shared_tensors) { + CUTLASS_DEVICE void load( + Params const &mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const &load_inputs, + BlockCoord const &blk_coord, + KTileIterator k_tile_iter, + int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage &shared_tensors) { int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), - SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), - SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); @@ -412,17 +473,17 @@ struct CollectiveMmaGated< cluster_local_block_id.x); // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) Tensor tAuxgAux = block_tma_aux.partition_S(gAux); Tensor tAuxsAux = block_tma_aux.partition_D(sAux); @@ -435,18 +496,18 @@ struct CollectiveMmaGated< // Maps the tile -> block, value if constexpr (cute::is_same_v) { auto block_layout = - Layout{}; // (m,n) -> - // block_id + Layout{}; // (m,n) -> + // block_id for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, - n, Int<0>{})); + mcast_mask_a |= (uint16_t(1) << block_layout( + cluster_local_block_id.x, n, Int<0>{})); } } if constexpr (cute::is_same_v) { auto block_layout = - Layout{}; // (m,n) -> - // block_id + Layout{}; // (m,n) -> + // block_id for (int m = 0; m < size<0>(block_layout); ++m) { mcast_mask_b |= (uint16_t(1) << block_layout( m, cluster_local_block_id.y, Int<0>{})); @@ -475,11 +536,14 @@ struct CollectiveMmaGated< int write_stage = smem_pipe_write.index(); copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), - tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage)); + tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), - tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage)); + tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), - tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage)); + tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); ++k_tile_iter; // Advance smem_pipe_write @@ -508,10 +572,14 @@ struct CollectiveMmaGated< /// Perform a collective-scoped matrix multiply-accumulate /// Consumer Perspective template - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, - FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx, - TensorStorage &shared_tensors, Params const &mainloop_params) { + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC &accum0, + FrgTensorC &accum1, + int k_tile_count, + int thread_idx, + TensorStorage &shared_tensors, + Params const &mainloop_params) { static_assert(is_rmem::value, "C tensor must be rmem resident."); static_assert(cute::rank(SmemLayoutA{}) == 3, @@ -528,9 +596,9 @@ struct CollectiveMmaGated< "smem sourced instructions."); Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), - SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), - SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); @@ -541,12 +609,12 @@ struct CollectiveMmaGated< TiledMma tiled_mma; auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) auto tCsAux = [&]() -> auto { if constexpr (SwapAB) { @@ -554,34 +622,36 @@ struct CollectiveMmaGated< } else { return thread_mma.partition_B(sAux); } - }(); + } + (); auto tCrAux = [&]() -> auto { if constexpr (SwapAB) { return thread_mma.make_fragment_A(tCsAux); } else { return thread_mma.make_fragment_B(tCsAux); } - }(); - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - if constexpr (SwapAB) { - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE - } else { - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE } - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + (); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } else { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE CUTE_STATIC_ASSERT_V(Int{} == - size<2>(sAux)); // PIPE + size<2>(sAux)); // PIPE // // PIPELINED MAIN LOOP @@ -613,14 +683,20 @@ struct CollectiveMmaGated< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accum0); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accum0); if constexpr (SwapAB) { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accum1); + cute::gemm(tiled_mma, + tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accum1); } else { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrAux(_, _, k_block, read_stage), accum1); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), + accum1); } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } @@ -654,14 +730,20 @@ struct CollectiveMmaGated< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accum0); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accum0); if constexpr (SwapAB) { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accum1); + cute::gemm(tiled_mma, + tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accum1); } else { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrAux(_, _, k_block, read_stage), accum1); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), + accum1); } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } @@ -699,8 +781,9 @@ struct CollectiveMmaGated< warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, - // done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, + // done _computing_ on it ++smem_pipe_release; } } @@ -708,6 +791,6 @@ struct CollectiveMmaGated< ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp index 76ffbdb2e6..c34ad242e2 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp @@ -55,18 +55,43 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// // WarpSpecialized Mainloop -template class Activation_, bool SwapAB_> + template + class Activation_, + bool SwapAB_> struct CollectiveMmaGated< MainloopSm90TmaGmmaWarpSpecializedFP8, - TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, - GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, - GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_, + Activation_, SwapAB_> { static constexpr bool isGated = true; static constexpr bool SwapAB = SwapAB_; @@ -74,9 +99,9 @@ struct CollectiveMmaGated< // // Type Aliases // - using DispatchPolicy = - MainloopSm90TmaGmmaWarpSpecializedFP8; + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; @@ -96,7 +121,8 @@ struct CollectiveMmaGated< using Activation = Activation_; using ElementAux = cute::conditional_t; - using ValTypeAux = cute::conditional_t; using MainloopPipeline = cutlass::PipelineTmaAsync; @@ -121,16 +147,20 @@ struct CollectiveMmaGated< // Tile along modes in a way that maximizes the TMA box size. using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), + make_shape(shape<0>(TileShape{}), + shape<2>(TileShape{}), Int{}), conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + make_shape(shape<1>(TileShape{}), + shape<2>(TileShape{}), Int{}), conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutAux = cute::conditional_t; static_assert(DispatchPolicy::Stages >= 2, @@ -184,18 +214,22 @@ struct CollectiveMmaGated< using TMA_A = decltype(make_tma_copy( GmemTiledCopyA{}, make_tensor(static_cast(nullptr), - repeat_like(StrideA{}, int32_t(0)), StrideA{}), + repeat_like(StrideA{}, int32_t(0)), + StrideA{}), SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + size<1>( + ClusterShape{}))); // mcast along N mode for this M load, if any // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, make_tensor(static_cast(nullptr), - repeat_like(StrideB{}, int32_t(0)), StrideB{}), + repeat_like(StrideB{}, int32_t(0)), + StrideB{}), SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + size<0>( + ClusterShape{}))); // mcast along M mode for this N load, if any using TMA_Aux = cute::conditional_t; TMA_A tma_load_a; TMA_B tma_load_b; @@ -210,9 +244,10 @@ struct CollectiveMmaGated< // template - static constexpr Params - to_underlying_arguments(ProblemShape const &problem_shape, - Arguments const &args, void *workspace) { + static constexpr Params to_underlying_arguments( + ProblemShape const &problem_shape, + Arguments const &args, + void *workspace) { (void)workspace; // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is @@ -228,35 +263,51 @@ struct CollectiveMmaGated< Tensor tensor_b = make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB)); typename Params::TMA_A tma_load_a = make_tma_copy( - GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any typename Params::TMA_B tma_load_b = make_tma_copy( - GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any if constexpr (SwapAB) { auto ptr_Aux = reinterpret_cast(args.ptr_B1); Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); typename Params::TMA_Aux tma_load_aux = make_tma_copy( - GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}), + GmemTiledCopyA{}, + tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>( - ClusterShape{})); // mcast along N mode for this M load, if any - return {tma_load_a, tma_load_b, tma_load_aux, - args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, + tma_load_b, + tma_load_aux, + args.scale_d0, + args.scale_d1, + args.mma_promotion_interval}; } else { auto ptr_Aux = reinterpret_cast(args.ptr_B1); Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); typename Params::TMA_Aux tma_load_aux = make_tma_copy( - GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}), + GmemTiledCopyB{}, + tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>( - ClusterShape{})); // mcast along M mode for this N load, if any - return {tma_load_a, tma_load_b, tma_load_aux, - args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, + tma_load_b, + tma_load_aux, + args.scale_d0, + args.scale_d1, + args.mma_promotion_interval}; } } @@ -285,8 +336,9 @@ struct CollectiveMmaGated< implementable = implementable && (args.mma_promotion_interval % 4 == 0); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the " - "minimum alignment requirements for TMA.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the " + "minimum alignment requirements for TMA.\n"); } return implementable; } @@ -333,49 +385,64 @@ struct CollectiveMmaGated< // TMA requires special handling of strides to deal with coord codomain // mapping Represent the full tensors -- get these from TMA Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor( - make_shape(M, K, L)); // (m,k,l) + make_shape(M, K, L)); // (m,k,l) Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor( - make_shape(N, K, L)); // (n,k,l) + make_shape(N, K, L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), - Step{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, + TileShape{}, + make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) if constexpr (SwapAB) { Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( - make_shape(M, K, L)); // (m,k,l) - Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); } else { Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( - make_shape(N, K, L)); // (n,k,l) - Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), - Step{}); // (BLK_N,BLK_K,n,k,l) + make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, + TileShape{}, + make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); } } /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective - template - CUTLASS_DEVICE void - load(Params const &mainloop_params, MainloopPipeline pipeline, - PipelineState smem_pipe_write, - cute::tuple const &load_inputs, - BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage &shared_tensors) { + CUTLASS_DEVICE void load( + Params const &mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const &load_inputs, + BlockCoord const &blk_coord, + KTileIterator k_tile_iter, + int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage &shared_tensors) { int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), - SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), - SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); @@ -403,17 +470,17 @@ struct CollectiveMmaGated< // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) Tensor tAuxgAux = block_tma_aux.partition_S(gAux); Tensor tAuxsAux = block_tma_aux.partition_D(sAux); @@ -426,18 +493,18 @@ struct CollectiveMmaGated< // Maps the tile -> block, value if constexpr (cute::is_same_v) { auto block_layout = - Layout{}; // (m,n) -> - // block_id + Layout{}; // (m,n) -> + // block_id for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, - n, Int<0>{})); + mcast_mask_a |= (uint16_t(1) << block_layout( + cluster_local_block_id.x, n, Int<0>{})); } } if constexpr (cute::is_same_v) { auto block_layout = - Layout{}; // (m,n) -> - // block_id + Layout{}; // (m,n) -> + // block_id for (int m = 0; m < size<0>(block_layout); ++m) { mcast_mask_b |= (uint16_t(1) << block_layout( m, cluster_local_block_id.y, Int<0>{})); @@ -466,11 +533,14 @@ struct CollectiveMmaGated< int write_stage = smem_pipe_write.index(); copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), - tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage)); + tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), - tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage)); + tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), - tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage)); + tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); ++k_tile_iter; // Advance smem_pipe_write @@ -499,11 +569,14 @@ struct CollectiveMmaGated< /// Perform a collective-scoped matrix multiply-accumulate /// Consumer Perspective template - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, - FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx, - TensorStorage &shared_tensors, Params const &mainloop_params) { - + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC &accum0, + FrgTensorC &accum1, + int k_tile_count, + int thread_idx, + TensorStorage &shared_tensors, + Params const &mainloop_params) { static_assert(is_rmem::value, "C tensor must be rmem resident."); static_assert(cute::rank(SmemLayoutA{}) == 3, @@ -518,9 +591,9 @@ struct CollectiveMmaGated< "smem sourced instructions."); Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), - SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), - SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); @@ -531,12 +604,12 @@ struct CollectiveMmaGated< TiledMma tiled_mma; auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) auto tCsAux = [&]() -> auto { if constexpr (SwapAB) { @@ -544,34 +617,36 @@ struct CollectiveMmaGated< } else { return thread_mma.partition_B(sAux); } - }(); + } + (); auto tCrAux = [&]() -> auto { if constexpr (SwapAB) { return thread_mma.make_fragment_A(tCsAux); } else { return thread_mma.make_fragment_B(tCsAux); } - }(); - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - if constexpr (SwapAB) { - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE - } else { - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE } - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + (); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } else { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE CUTE_STATIC_ASSERT_V(Int{} == - size<2>(sAux)); // PIPE + size<2>(sAux)); // PIPE // // PIPELINED MAIN LOOP @@ -611,14 +686,20 @@ struct CollectiveMmaGated< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accumulation0()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation0()); if constexpr (SwapAB) { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accumulation1()); + cute::gemm(tiled_mma, + tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation1()); } else { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrAux(_, _, k_block, read_stage), accumulation1()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), + accumulation1()); } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } @@ -659,14 +740,20 @@ struct CollectiveMmaGated< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accumulation0()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation0()); if constexpr (SwapAB) { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accumulation1()); + cute::gemm(tiled_mma, + tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation1()); } else { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrAux(_, _, k_block, read_stage), accumulation1()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), + accumulation1()); } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } @@ -681,8 +768,9 @@ struct CollectiveMmaGated< accumulation0.promote_if_needed(); accumulation1.promote_if_needed(); - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, - // done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, + // done _computing_ on it // Advance smem_pipe_read and smem_pipe_release ++smem_pipe_read; @@ -710,8 +798,9 @@ struct CollectiveMmaGated< warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, - // done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, + // done _computing_ on it ++smem_pipe_release; } } @@ -719,6 +808,6 @@ struct CollectiveMmaGated< ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index be1f9747e7..837d65ae54 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -12,18 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.2/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp -// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +// Adapt from +// https://github.com/vllm-project/vllm/blob/v0.7.2/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +// Adapted (Heavily) from: +// https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -35,14 +37,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -73,46 +76,52 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// // WarpSpecialized Mainloop -template < - int Stages, - class ClusterShape, - class KernelSchedule, - int ScaleGranularityM_, - class TileShape_, - class ElementA_, - class StrideA_, - class ElementB_, - class StrideB_, - class TiledMma_, - class GmemTiledCopyA_, - class SmemLayoutAtomA_, - class SmemCopyAtomA_, - class TransformA_, - class GmemTiledCopyB_, - class SmemLayoutAtomB_, - class SmemCopyAtomB_, - class TransformB_> -struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8, - TileShape_, - ElementA_, - StrideA_, - ElementB_, - StrideB_, - TiledMma_, - GmemTiledCopyA_, - SmemLayoutAtomA_, - SmemCopyAtomA_, - TransformA_, - GmemTiledCopyB_, - SmemLayoutAtomB_, - SmemCopyAtomB_, - TransformB_> -{ +template +struct CollectiveMma, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { // // Type Aliases // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + using DispatchPolicy = + MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8< + Stages, + ClusterShape, + KernelSchedule, + ScaleGranularityM_>; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; @@ -139,55 +148,91 @@ struct CollectiveMma< // Two threads per CTA are producers (1 for operand tile and 32 for scales) static constexpr int NumProducerThreadEvents = 33; - static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; - static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleGranularityM = + ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; + static constexpr int ScaleMsPerTile = + size<0>(TileShape{}) / ScaleGranularityM; - static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); - static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + static_assert( + (size<0>(TileShape{}) % ScaleGranularityM) == 0, + "FP8 scaling granularity must evenly divide tile shape along M."); // Tile along modes in a way that maximizes the TMA box size. using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + make_shape(shape<0>(TileShape{}), + shape<2>(TileShape{}), + Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + make_shape(shape<1>(TileShape{}), + shape<2>(TileShape{}), + Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); // Block scaling gmem-to-smem copy atom - using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; - using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomA = + Copy_Atom, + ElementBlockScale>; + using SmemBlockScalingCopyAtomB = + Copy_Atom, + ElementBlockScale>; // Block scaling smem layout - using SmemLayoutScaleA = Layout, Int>>; - using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. + using SmemLayoutScaleA = + Layout, Int>>; + using SmemLayoutScaleB = Layout>, + Stride<_1>>; // `ScaleNsPerTile` is always 1. - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value && - cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v, - "ElementAccumulator and ElementBlockScale should be same datatype"); + static_assert(DispatchPolicy::Stages >= 2, + "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for " + "this mainloop."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); - struct SharedStorage - { + struct SharedStorage { struct TensorStorage : cute::aligned_struct<128> { - cute::array_aligned> smem_A; // mxk - cute::array_aligned> smem_B; // nxk - cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k - cute::array_aligned> smem_scale_B; // 1xk + cute::array_aligned> + smem_A; // mxk + cute::array_aligned> + smem_B; // nxk + cute::array_aligned> + smem_scale_A; // ScaleMsPerTile x k + cute::array_aligned> + smem_scale_B; // 1xk } tensors; using PipelineStorage = typename MainloopPipeline::SharedStorage; @@ -211,15 +256,19 @@ struct CollectiveMma< // Assumption: StrideA is congruent with Problem_MK using TMA_A = decltype(make_tma_copy_A_sm90( GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_,_,0), + make_tensor(static_cast(nullptr), + repeat_like(StrideA{}, int32_t(0)), + StrideA{}), + SmemLayoutA{}(_, _, 0), TileShape{}, ClusterShape{})); // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy_B_sm90( GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_,_,0), + make_tensor(static_cast(nullptr), + repeat_like(StrideB{}, int32_t(0)), + StrideB{}), + SmemLayoutB{}(_, _, 0), TileShape{}, ClusterShape{})); TMA_A tma_load_a; @@ -237,103 +286,128 @@ struct CollectiveMma< // template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - (void) workspace; + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + (void)workspace; - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is + // only rank-3 (MNK) auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; + auto [M, N, K, L] = problem_shape_MNKL; auto ptr_A = reinterpret_cast(args.ptr_A); auto ptr_B = reinterpret_cast(args.ptr_B); - Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( - GmemTiledCopyA{}, - tensor_a, - SmemLayoutA{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); - typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); + Tensor tensor_a = + make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = + make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = + make_tma_copy_A_sm90(GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = + make_tma_copy_B_sm90(GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), + TileShape{}, + ClusterShape{}); uint32_t transaction_bytes_mk = TmaTransactionBytesMK; uint32_t transaction_bytes_nk = TmaTransactionBytesNK; uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; - return { - tma_load_a, - tma_load_b, - transaction_bytes, - transaction_bytes_mk, - transaction_bytes_nk, - args.ptr_scale_A, - args.ptr_scale_B - }; + return {tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.ptr_scale_A, + args.ptr_scale_B}; } - template - static bool - can_implement( - ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { + template + static bool can_implement(ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; + auto [M, N, K, L] = problem_shape_MNKL; bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + constexpr int min_tma_aligned_elements_A = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(N, K, L), StrideB{}); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " + "requirements for TMA.\n"); } return implementable; } static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytesMK = - cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytesNK = - cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + static constexpr uint32_t TmaTransactionBytesMK = cutlass::bits_to_bytes( + size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * + static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = cutlass::bits_to_bytes( + size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * + static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = + TmaTransactionBytesMK + TmaTransactionBytesNK; - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best + /// performance CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_b.get_tma_descriptor()); } /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) - /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// Returns a tuple of tensors. The collective and the kernel layer have the + /// contract Returned tuple must contain at least two elements, with the first + /// two elements being: gA_mkl - The tma tensor, A after a local tile so it + /// has shape (BLK_M,BLK_K,m,k,l) gB_nkl - The tma tensor, B after a local + /// tile so it has shape (BLK_N,BLK_K,n,k,l) template - CUTLASS_DEVICE auto - load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params) const { using X = Underscore; // Separate out problem shape for convenience - auto [M,N,K,L] = problem_shape_MNKL; + auto [M, N, K, L] = problem_shape_MNKL; - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + // TMA requires special handling of strides to deal with coord codomain + // mapping Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor( + make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor( + make_shape(N, K, L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, + TileShape{}, + make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) constexpr auto scales_m = Int{}; auto tM = get<2>(gA_mkl.shape()); @@ -341,84 +415,103 @@ struct CollectiveMma< auto tK = get<3>(gA_mkl.shape()); // Make the tiled views of scale tensors - auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) - auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); - auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) + auto scaleA_shape = + make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) + auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); + auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); - // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and - // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. - Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) - Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) + // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the + // `m` host and gScaleA_mkl and gScaleB_nkl in `g` global memory are same as + // mScaleA_mkl and mScaleB_nkl. + Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), + scaleA_layout); // (scale_m,k,l) + Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), + scaleB_layout); // (n,k,l) return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); } /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective - template < - class TensorA, class TensorB, - class TensorScaleA, class TensorScaleB, - class KTileIterator, class BlockCoord - > - CUTLASS_DEVICE void - load( + template + CUTLASS_DEVICE void load( Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& load_inputs, + cute::tuple const& + load_inputs, BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, + KTileIterator k_tile_iter, + int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { int lane_predicate = cute::elect_one_sync(); // Blockscaling: Tma loads for load_input and CpAsync for load_scale - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sScaleA = + make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + SmemLayoutScaleA{}); // (ScaleMsPerTile,k) + Tensor sScaleB = + make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + SmemLayoutScaleB{}); // (k) // // Prepare the TMA loads for A and B // constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, + block_rank_in_cluster / cluster_shape_x}; Tensor gA_mkl = get<0>(load_inputs); Tensor gB_nkl = get<1>(load_inputs); - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_a = + mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = + mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) - - // Block scaling: load_scale has scaling tensors in global memory which are not tiled + // Block scaling: load_scale has scaling tensors in global memory which are + // not tiled Tensor mScaleA_mkl = get<2>(load_inputs); Tensor mScaleB_nkl = get<3>(load_inputs); auto scales_m = get<0>(mScaleA_mkl.shape()); Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); - Tensor gScaleA = local_tile( - mScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) - Tensor cScaleA = local_tile( - cScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); - Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) + Tensor gScaleA = + local_tile(mScaleA_mkl, + make_tile(Int{}), + make_coord(m_coord, _, l_coord)); // (ScaleMsPerTile,k,1) + Tensor cScaleA = local_tile(cScaleA_mkl, + make_tile(Int{}), + make_coord(m_coord, _, l_coord)); + Tensor gScaleB = mScaleB_nkl(n_coord, _, l_coord); // (1,k,1) // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 - TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, - Layout>{}, Layout>{}); // (1,1,1) + TiledCopy scale_copy_a = + make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, + Layout>{}); // (1,1,1) TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, - Layout>{}, Layout>{}); // (1,1,1) + Layout>{}, + Layout>{}); // (1,1,1) ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); @@ -430,11 +523,11 @@ struct CollectiveMma< Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) uint16_t mcast_mask_a = 0; uint16_t mcast_mask_b = 0; @@ -442,30 +535,34 @@ struct CollectiveMma< // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors // Maps the tile -> block, value if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id + auto block_layout = + Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + mcast_mask_a |= (uint16_t(1) << block_layout( + cluster_local_block_id.x, n, Int<0>{})); } } if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id + auto block_layout = + Layout{}; // (m,n) -> block_id for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + mcast_mask_b |= (uint16_t(1) << block_layout( + m, cluster_local_block_id.y, Int<0>{})); } } // Allocate predicate tensors for a_scales (since we can't guarantee that // all scales are valid, since we could have a partial tiles along M) - Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); - #pragma unroll + Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_, _, 0))); +#pragma unroll for (int i = 0; i < size(tApA_ScaleA); ++i) { tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; } // Mainloop CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { + for (; k_tile_count > 0; --k_tile_count) { // LOCK smem_pipe_write for _writing_ pipeline.producer_acquire(smem_pipe_write); @@ -477,13 +574,25 @@ struct CollectiveMma< BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); // Copy operands A and B from global memory to shared memory - if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + if (lane_predicate) + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), + tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + if (lane_predicate) + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), + tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); // Copy scale tensors from global memory to shared memory - copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); - copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); - pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + copy_if(scale_copy_a, + tApA_ScaleA, + tAgA_ScaleA(_, _, *k_tile_iter), + tAsA_ScaleA(_, _, write_stage)); + copy(scale_copy_b, + tBgB_ScaleB(_, *k_tile_iter), + tBsB_ScaleB(_, write_stage)); + pipeline.producer_commit(smem_pipe_write, + cutlass::arch::cpasync_barrier_arrive_noinc); ++k_tile_iter; @@ -493,10 +602,8 @@ struct CollectiveMma< } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void - load_tail( - MainloopPipeline pipeline, - PipelineState smem_pipe_write) { + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_write) { int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits @@ -513,37 +620,46 @@ struct CollectiveMma< /// Perform a collective-scoped matrix multiply-accumulate /// Consumer Perspective - template < - class FrgTensorC - > - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, - PipelineState smem_pipe_read, - FrgTensorC& accum, - int k_tile_count, - int thread_idx, - TensorStorage& shared_tensors, - Params const& mainloop_params) { - - - static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, + "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "Smem layout must be rank 3."); static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // Block scaling - Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), - Layout< - Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, - Stride, _0, Int> - >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + Tensor sScaleAViewAsC = make_tensor( + cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, + cute::tuple_element_t<1, TileShape>, + Int>, + Stride< + Stride<_0, _1>, + _0, + Int>>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) + Tensor sScaleB = + make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + SmemLayoutScaleB{}); // (k) // // Define C accumulators and A/B partitioning @@ -551,52 +667,68 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and - stride<0>(typename TiledMma::BLayout{}) == 0 and - size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, - "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + static_assert( + stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be " + "NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, - Int{}); + Layout warp_group_thread_layout = + make_layout(Int{}, Int{}); - int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + int warp_group_idx = + __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + auto thread_mma = + tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); - Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. + Tensor tCsScaleAViewAsC = + tiled_mma.get_slice(thread_idx) + .partition_C( + sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above + // is correct when partitioning A and B, but + // it is not correct when partitioning C. - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE // // PIPELINED MAIN LOOP // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), - "ERROR : Incorrect number of MMAs in flight"); + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); // We release buffers to producer warps(dma load) with some mmas in flight PipelineState smem_pipe_release = smem_pipe_read; // Per block scale values for operand A and B - using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. - using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above + using RegLayoutScaleAViewAsC = decltype(make_layout_like( + tCsScaleAViewAsC(_, _, _, 0) + .layout())); // `make_layout_like` makes a compact layout. + using RegLayoutScaleAEssential = decltype(filter_zeros( + RegLayoutScaleAViewAsC{}.stride(), + RegLayoutScaleAViewAsC{} + .shape())); // an interface to traverse the underlying storage for + // the compact layout mentioned above - Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) + Tensor tCrScaleAViewAsC = make_tensor( + RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) ElementBlockScale scale_b; // Prologue GMMAs @@ -604,12 +736,16 @@ struct CollectiveMma< tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA)); + GmmaFP8AccumulationWithScale accumulation( + accum, + size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), + size<2>(tCrA)); warpgroup_fence_operand(accumulation()); CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; + --k_tile_prologue) { + // WAIT on smem_pipe_read until its data are available (phase bit flips + // from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -623,11 +759,16 @@ struct CollectiveMma< scale_b = sScaleB[read_stage]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC( + _, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); } if constexpr (ScaleMsPerTile == 1) { static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. + tCrScaleAViewAsC.data()[0] = + __shfl_sync(0xffffffff, + tCrScaleAViewAsC.data()[0] * scale_b, + 0); // `tCrScaleAViewAsC.data()[0]` are all same in a + // warp group when `ScaleMsPerTile == 1`. } else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { @@ -640,7 +781,10 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation()); tiled_mma.accumulate_ = GMMA::ScaleOut::One; } warpgroup_commit_batch(); @@ -656,9 +800,9 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + for (; k_tile_count > 0; --k_tile_count) { + // WAIT on smem_pipe_read until its data are available (phase bit flips + // from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -668,15 +812,21 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); - // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) + // Load per block scale values from shared memory to registers (at most + // twice per block along M and exactly once per block along N) scale_b = sScaleB[read_stage]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC( + _, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); } if constexpr (ScaleMsPerTile == 1) { static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. + tCrScaleAViewAsC.data()[0] = + __shfl_sync(0xffffffff, + tCrScaleAViewAsC.data()[0] * scale_b, + 0); // `tCrScaleAViewAsC.data()[0]` are all same in a + // warp group when `ScaleMsPerTile == 1`. } else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { @@ -694,19 +844,25 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation()); tiled_mma.accumulate_ = GMMA::ScaleOut::One; } warpgroup_commit_batch(); - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to + /// ensure smem_pipe_write is consumed warpgroup_wait(); warpgroup_fence_operand(accumulation()); // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` accumulation.scale_if_needed(tCrScaleAViewAsC); - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on + // it // Advance smem_pipe_read and smem_pipe_release ++smem_pipe_read; @@ -719,8 +875,9 @@ struct CollectiveMma< } /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void - mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_release, + int k_tile_count) { // Prologue GMMAs int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); k_tile_count -= prologue_mma_count; @@ -731,7 +888,9 @@ struct CollectiveMma< warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on + // it ++smem_pipe_release; } } @@ -739,6 +898,6 @@ struct CollectiveMma< ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/device/gemm_universal_base_compat.h index 2edd5a228b..ae5ff6decc 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/device/gemm_universal_base_compat.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,20 +18,21 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and - batched array variants. + \brief The universal GEMM accommodates serial reductions, parallel reductions, + batched strided, and batched array variants. */ #pragma once @@ -54,385 +55,363 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace device -{ +namespace cutlass { +namespace gemm { +namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// /* - This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) - It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs - and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + This is the device layer from CUTLASS 2.10 (SHA - + cc85b64cf676c45f98a17e3a47c0aafcf817f088) It is replicated here since we + needed to duplicate kernel level APIs for mixed dtype GEMMs and SmoothQuant. + The newer device layer is not compatible with these older kernel level APIs. - Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support - that feature at the moment. + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the + extensions folder support that feature at the moment. */ template -class GemmUniversalBaseCompat -{ -public: - using GemmKernel = GemmKernel_; - using ThreadblockShape = typename GemmKernel::Mma::Shape; +class GemmUniversalBaseCompat { + public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; - using ElementA = typename GemmKernel::ElementA; - using LayoutA = typename GemmKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = GemmKernel::kTransformA; + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; - using ElementB = typename GemmKernel::ElementB; - using LayoutB = typename GemmKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = GemmKernel::kTransformB; + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; - using ElementC = typename GemmKernel::ElementC; - using LayoutC = typename GemmKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; - using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + using ElementAccumulator = + typename GemmKernel::Mma::Policy::Operator::ElementC; - using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - using Operator = typename GemmKernel::Operator; + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; - /// Argument structure - using Arguments = typename GemmKernel::Arguments; + /// Argument structure + using Arguments = typename GemmKernel::Arguments; -protected: - /// Kernel parameters object - typename GemmKernel::Params params_; + protected: + /// Kernel parameters object + typename GemmKernel::Params params_; -protected: - /// Private helper to obtain the grid dimensions with fix-up for split-K - static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) - { + protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, + int& gemm_k_size, + Arguments const& args) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + gemm_k_size = args.problem_size.k(); - gemm_k_size = args.problem_size.k(); + if (args.mode == GemmUniversalMode::kGemm || + args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, + 128 / sizeof_bits::value), + 1); - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) - { + gemm_k_size = + round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - int const kAlignK - = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } - gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} - if (gemm_k_size) - { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { + return Status::kErrorInvalidProblem; } -public: - /// Constructs the GEMM. - GemmUniversalBaseCompat() {} + return GemmKernel::can_implement(args); + } - /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const& args) - { + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; + size_t workspace_bytes = 0; - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - ThreadblockSwizzle threadblock_swizzle; - dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); - - if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) - { - - return Status::kErrorInvalidProblem; - } - - return GemmKernel::can_implement(args); + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * + size_t(grid_tiled_shape.k()); + } else if (args.mode == GemmUniversalMode::kGemm && + grid_tiled_shape.k() > 1) { + // Serial split-K only requires a temporary workspace if the number of + // partitions along the GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * + size_t(grid_tiled_shape.n()); } - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + workspace_bytes += + GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); - size_t workspace_bytes = 0; + return workspace_bytes; + } - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + ThreadblockSwizzle threadblock_swizzle; - if (args.mode == GemmUniversalMode::kGemmSplitKParallel) - { + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - // Split-K parallel always requires a temporary workspace - workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); - } - else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) - { + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - // Serial split-K only requires a temporary workspace if the number of partitions along the - // GEMM K dimension is greater than one. - workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); - } + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result + << "}"); - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + return result; + } - workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); - return workspace_bytes; - } + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + if (smem_size <= (48 << 10)) { + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + Kernel, + GemmKernel::kThreadCount, + smem_size); - ThreadblockSwizzle threadblock_swizzle; + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } else { + // Query assuming zero shared memory then compute occupancy limit based on + // SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - - CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" - << " result = {" << result << "}"); - - return result; - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); - - int max_active_blocks = -1; - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - if (smem_size <= (48 << 10)) - { - - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); - - if (result == cudaSuccess) - { - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - } - else - { - - // Query assuming zero shared memory then compute occupancy limit based on SMEM - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); - - if (result != cudaSuccess) - { - - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - - return -1; - } - - if (smem_capacity < 0) - { - int device_idx = 0; - result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) - { - return -1; - } - - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) - { - return -1; - } - - smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); - } - - int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); - - CUTLASS_TRACE_HOST(" occupancy: " << occupancy); - - return occupancy; - } - - CUTLASS_TRACE_HOST(" returning internal error"); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); return -1; - } + } - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); - - size_t workspace_bytes = get_workspace_size(args); - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - if (workspace_bytes) - { - - if (!workspace) - { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - - return Status::kErrorWorkspaceNull; - } - - if (args.mode == GemmUniversalMode::kGemm) - { - CUTLASS_TRACE_HOST(" clearing device workspace"); - cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); - - return Status::kErrorInternal; - } - } + if (result != cudaSuccess) { + return -1; } - // Get CUDA grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - // Initialize the Params structure - params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); - - // Specify shared memory capacity for kernel. - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } + if (result != cudaSuccess) { + return -1; } - return Status::kSuccess; + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; } - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { + CUTLASS_TRACE_HOST(" returning internal error"); - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + return -1; + } - size_t workspace_bytes = get_workspace_size(args); + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace + << ", stream: " << (stream ? "non-null" : "null")); - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = + cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " + << cudaGetErrorString(result)); + + return Status::kErrorInternal; } - - params_.update(args, workspace); - - return Status::kSuccess; + } } - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - // - // Configure grid and block dimensions - // + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - ThreadblockSwizzle threadblock_swizzle; + // Initialize the Params structure + params_ = typename GemmKernel::Params( + args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); - dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); - dim3 block(GemmKernel::kThreadCount, 1, 1); + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); - // - // Launch kernel - // - - CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); - - // Launch - cutlass::Kernel<<>>(params_); - - // - // Query for errors - // - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - - return Status::kSuccess; + if (result != cudaSuccess) { + return Status::kErrorInternal; + } } - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST( + "GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; } - /// Runs the kernel using initialized state. - Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { + params_.update(args, workspace); - Status status = initialize(args, workspace, stream); + return Status::kSuccess; + } - if (status == Status::kSuccess) - { - status = run(stream); - } + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); - return status; + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block + << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); + return Status::kErrorInternal; } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Runs the kernel using initialized state. + Status operator()(Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/device/splitk_gemm_grouped.h index bfd3666b9c..02859ee0d1 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/device/splitk_gemm_grouped.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/device/splitk_gemm_grouped.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! @@ -55,488 +56,479 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace device -{ +namespace cutlass { +namespace gemm { +namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, - int64_t* splitk_buffer_offsets) -{ - // in_tensor: [problem_idx, k_partition, hidden_size] - // Note that different requests of in_tensor might have different hidden_size (=m*n) - // so, we need to use splitk_buffer_offsets. - // out_tensor: problem_idx * [hidden_size] +__global__ void splitkReduction(T_OUT** out_tensor, + const T_IN* in_tensor, + GemmCoord const* problem_sizes, + int splitk, + int64_t* splitk_buffer_offsets) { + // in_tensor: [problem_idx, k_partition, hidden_size] + // Note that different requests of in_tensor might have different + // hidden_size (=m*n) so, we need to use splitk_buffer_offsets. + // out_tensor: problem_idx * [hidden_size] - int const problem_idx = blockIdx.y; - GemmCoord problem = problem_sizes[problem_idx]; - int const hidden_size = problem.m() * problem.n(); - const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; - T_OUT* out_tensor_ = out_tensor[problem_idx]; + int const problem_idx = blockIdx.y; + GemmCoord problem = problem_sizes[problem_idx]; + int const hidden_size = problem.m() * problem.n(); + const T_IN* in_tensor_ = + in_tensor + splitk_buffer_offsets[problem_idx] * splitk; + T_OUT* out_tensor_ = out_tensor[problem_idx]; - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) - { - float sum = 0.0f; - for (int k_idx = 0; k_idx < splitk; k_idx++) - { - sum += (float) in_tensor_[k_idx * hidden_size + i]; - } - out_tensor_[i] = (T_OUT) (sum); + for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; + i += blockDim.x * gridDim.x) { + float sum = 0.0f; + for (int k_idx = 0; k_idx < splitk; k_idx++) { + sum += (float)in_tensor_[k_idx * hidden_size + i]; } + out_tensor_[i] = (T_OUT)(sum); + } } /// GEMM Grouped template -class BaseSplitkGrouped -{ -public: - using BaseKernel = BaseKernel_; +class BaseSplitkGrouped { + public: + using BaseKernel = BaseKernel_; - using ElementA = typename BaseKernel::ElementA; - using LayoutA = typename BaseKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = BaseKernel::kTransformA; - static int const kAlignmentA = BaseKernel::kAlignmentA; + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; - using ElementB = typename BaseKernel::ElementB; - using LayoutB = typename BaseKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = BaseKernel::kTransformB; - static int const kAlignmentB = BaseKernel::kAlignmentB; + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; - using ElementC = typename BaseKernel::ElementC; - using LayoutC = typename BaseKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - static int const kAlignmentC = BaseKernel::kAlignmentC; + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; - using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + using ElementAccumulator = + typename BaseKernel::Mma::Policy::Operator::ElementC; - using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = + typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; - using Operator = typename BaseKernel::Operator; - using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename WarpMmaOperator::MathOperator; - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - using ThreadblockShape = typename BaseKernel::Mma::Shape; - using WarpShape = typename BaseKernel::WarpShape; - using InstructionShape = typename BaseKernel::InstructionShape; - static int const kStages = BaseKernel::Mma::kStages; + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; - /// Argument structure - using Arguments = typename BaseKernel::Arguments; + /// Argument structure + using Arguments = typename BaseKernel::Arguments; - using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; -protected: - /// Kernel parameters object - typename BaseKernel::Params gemm_params_; + protected: + /// Kernel parameters object + typename BaseKernel::Params gemm_params_; -private: - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) - { - int32_t tiles = 0; - for (int32_t i = 0; i < problem_count; ++i) - { - cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; - BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); - tiles += problem_tile_count(problem); - } - return tiles; + private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count( + cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) { + cudaError_t cuda_error = + cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " + << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; } - /// Copy from `data` to `workspace` - Status copy_to_workspace(void* workspace, void* data, size_t bytes) - { - cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); - if (cuda_error != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - cuda_error = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); - return Status::kErrorInternal; - } + return Status::kSuccess; + } - return Status::kSuccess; + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, + int32_t tile_count, + void* workspace) { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, + args.problem_count, + args.threadblock_count, + (void*)host_workspace.data()); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, std::vector const& indices) { + // For now, simply create a copy of the data and then copy over to the + // original. + std::vector copy(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + copy.at(i) = data[indices[i]]; } - /// Precomputes scheduling information for the grouped GEMM - Status precompute(Arguments const& args, int32_t tile_count, void* workspace) - { - size_t workspace_bytes = get_workspace_size(args); - std::vector host_workspace(workspace_bytes); - BaseKernel::ProblemVisitor::host_precompute( - args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data()); - return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + + public: + /// Constructs the GEMM. + BaseSplitkGrouped() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) { + if (args.host_problem_sizes == nullptr) { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; } - /// Reorder `data` according to `indices` - template - static void reorder_array(T* data, std::vector const& indices) - { - // For now, simply create a copy of the data and then copy over to the original. - std::vector copy(indices.size()); - for (size_t i = 0; i < indices.size(); ++i) - { - copy.at(i) = data[indices[i]]; - } + return group_tile_count(args.host_problem_sizes, args.problem_count); + } - memcpy(data, copy.data(), indices.size() * sizeof(T)); + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + size_t total_mn = 0; + for (int i = 0; i < args.problem_count; i++) { + total_mn += + args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); + } + size_t workSpaceSize = + total_mn * sizeof(ElementAccumulator) * args.split_k_slices; + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( + args.host_problem_sizes, args.problem_count, args.threadblock_count); + } + return workSpaceSize; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(result)); + return -1; + } } -public: - /// Constructs the GEMM. - BaseSplitkGrouped() {} + int max_active_blocks = -1; + result = + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + Kernel, + BaseKernel::kThreadCount, + smem_size); - /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const& args) - { - - return BaseKernel::can_implement(args); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; } - /// Get the number of tiles in a problem - static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) - { - auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); - return BaseKernel::ProblemVisitor::tile_count(grid); + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. + static void sort_problems(int problem_count, + cutlass::gemm::GemmCoord* problem_sizes_ptr, + int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, + int64_t* ldc_host_ptr, + int64_t* ldd_host_ptr, + int64_t* offset_A_ptr, + int64_t* offset_B_ptr, + int64_t* offset_C_ptr, + int64_t* offset_D_ptr) { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), + indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { + return problem_sizes_ptr[i].k() > + problem_sizes_ptr[j].k(); + }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient( + cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, + int problem_count = 0, + int available_sm_count = -1) { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(result)); + return 0; } - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(Arguments const& args) - { - if (args.host_problem_sizes == nullptr) - { - CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); - return -1; - } - - return group_tile_count(args.host_problem_sizes, args.problem_count); + int multiprocessor_count; + result = cudaDeviceGetAttribute( + &multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " + << cudaGetErrorString(result)); + return 0; } - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { - size_t total_mn = 0; - for (int i = 0; i < args.problem_count; i++) - { - total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); - } - size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( - args.host_problem_sizes, args.problem_count, args.threadblock_count); - } - return workSpaceSize; + bool override_sm_count = + (available_sm_count < 0 || available_sm_count > multiprocessor_count); + if (override_sm_count) { + available_sm_count = multiprocessor_count; } - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { - - return dim3(args.threadblock_count, 1, 1); + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) { + return 0; } - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { + int occupancy_based_block_count = available_sm_count * max_active_blocks; - CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - cudaError_t result; - if (smem_size > (48 << 10)) - { - result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); - return -1; - } - } - - int max_active_blocks = -1; - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, BaseKernel::kThreadCount, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; + if (problem_sizes_ptr == nullptr || problem_count == 0) { + return occupancy_based_block_count; } - /// Sorts each pointer passed in according to the indices that sort - /// `problem_sizes_ptr` in descending order of problem-K dimension. - static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, - int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, - int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) - { - std::vector indices(problem_count); - std::iota(indices.begin(), indices.end(), 0); - std::stable_sort(indices.begin(), indices.end(), - [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); }); + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); - reorder_array(problem_sizes_ptr, indices); - reorder_array(lda_host_ptr, indices); - reorder_array(ldb_host_ptr, indices); - reorder_array(ldc_host_ptr, indices); - reorder_array(ldd_host_ptr, indices); - reorder_array(offset_A_ptr, indices); - reorder_array(offset_B_ptr, indices); - reorder_array(offset_C_ptr, indices); - reorder_array(offset_D_ptr, indices); + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return + // total_tiles unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) { + return total_tiles; } - /// Computes the number of threadblocks to launch for the grouped kernel - static int sufficient( - cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) - { - // Determine the number of blocks that would be launched to fill up a single - // wave on the GPU with each SM having maximum occupancy. - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); - return 0; - } + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating + // through problem sizes to determine that they have no work to do. This + // competes for cycles with those threadblocks that are assigned tiles to + // compute. + return std::min(total_tiles, occupancy_based_block_count); + } - int multiprocessor_count; - result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); - return 0; - } + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " + << workspace + << ", stream: " << (stream ? "non-null" : "null")); - bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); - if (override_sm_count) - { - available_sm_count = multiprocessor_count; - } + // Workspace + size_t workspace_bytes = get_workspace_size(args); - int max_active_blocks = maximum_active_blocks(); - if (max_active_blocks <= 0) - { - return 0; - } - - int occupancy_based_block_count = available_sm_count * max_active_blocks; - - if (problem_sizes_ptr == nullptr || problem_count == 0) - { - return occupancy_based_block_count; - } - - int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); - - // If the group contains a single problem, launching the exact number of - // threadblocks needed to cover the problem minimizes the work performed - // per threadblock in finding the next tile to compute. We return total_tiles - // unless the user has provided the SM count. - if (problem_count == 1 && override_sm_count) - { - return total_tiles; - } - - // Choose between the full wave of threadblocks and the tile count. If there - // are fewer tiles in the group than threadblocks in the full wave, only - // some threadblocks will be assigned tiles. Those threadblocks - // which are not assigned tiles still need to perform the work of iterating through - // problem sizes to determine that they have no work to do. This competes for cycles - // with those threadblocks that are assigned tiles to compute. - return std::min(total_tiles, occupancy_based_block_count); + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; } - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { - - CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); - - // Workspace - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); - } - else - { - gemm_params_ = typename BaseKernel::Params(args, workspace); - } - - // Specify shared memory capacity for kernel. - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { - - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_.update(args, workspace, tile_count); - } - else - { - gemm_params_.update(args, workspace); - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { - if (!gemm_params_.problem_visitor.problem_count) - { - return Status::kSuccess; - } - - // - // Launch kernel - // - - // Launch splitk grouped gemm - { - dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); - dim3 block(BaseKernel::kThreadCount, 1, 1); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - cutlass::Kernel<<>>(gemm_params_); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - // Launch splitkReduction - { - dim3 grid(32, gemm_params_.problem_visitor.problem_count); - dim3 block(256); - splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, - gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices, - gemm_params_.splitk_buffer_offsets); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); - } - - /// Initializes and runs the kernel. - Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) - { - - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) - { - status = run(stream); - } - + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { return status; + } + + gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); + } else { + gemm_params_ = typename BaseKernel::Params(args, workspace); } + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + gemm_params_.update(args, workspace, tile_count); + } else { + gemm_params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + if (!gemm_params_.problem_visitor.problem_count) { + return Status::kSuccess; + } + + // + // Launch kernel + // + + // Launch splitk grouped gemm + { + dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + cutlass::Kernel + <<>>(gemm_params_); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + // Launch splitkReduction + { + dim3 grid(32, gemm_params_.problem_visitor.problem_count); + dim3 block(256); + splitkReduction<<>>( + gemm_params_.ptr_D, + gemm_params_.ptr_D_split, + gemm_params_.problem_visitor.problem_sizes, + gemm_params_.split_k_slices, + gemm_params_.splitk_buffer_offsets); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Initializes and runs the kernel. + Status operator()(Arguments const& args, + void* workspace, + cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM Grouped template -class SplitkGemmGrouped : public BaseSplitkGrouped -{ -public: - using GemmKernel = GemmKernel_; +class SplitkGemmGrouped : public BaseSplitkGrouped { + public: + using GemmKernel = GemmKernel_; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/dispatch_policy.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/dispatch_policy.hpp index f4cf0bf420..c276c4e9c0 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/dispatch_policy.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/dispatch_policy.hpp @@ -30,7 +30,8 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp // specialized dynamic schedule For FP8 kernels with Block Scaling -template , +template , class KernelSchedule = KernelTmaWarpSpecialized, int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, @@ -38,7 +39,8 @@ template , // granularity is `size<0>(TileShape_MNK{})` along M. > struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 - : MainloopSm90TmaGmmaWarpSpecialized { static_assert( cute::is_same_v< diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index 3f83470270..0d25460e66 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,164 +27,190 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { template -struct MixedGemmArchTraits -{ - static_assert(dependent_false, "Unrecognised parameterization"); +struct MixedGemmArchTraits { + static_assert(dependent_false, "Unrecognised parameterization"); }; template -struct MixedGemmArchTraits -{ - static constexpr int Stages = 2; - using OperatorClass = cutlass::arch::OpClassSimt; - using AccType = float; - using LayoutB = cutlass::layout::ColumnMajor; +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::ColumnMajor; - static constexpr int ElementsPerAccessA = 1; - static constexpr int ElementsPerAccessB = 1; - static constexpr int ElementsPerAccessC = 1; - static constexpr int ThreadblockK = 8; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using Operator = cutlass::arch::OpMultiplyAdd; + using Operator = cutlass::arch::OpMultiplyAdd; }; // ========================= Volta Traits =========================== // Volta will always dequantize after the global memory load. // This will instantiate any HMMA tensorcore kernels for Volta. -// Note that volta does not have native bfloat support so weights and activations will be casted to fp16 -// and compute will happen in fp16 then will be converted for bf16 output. +// Note that volta does not have native bfloat support so weights and +// activations will be casted to fp16 and compute will happen in fp16 then will +// be converted for bf16 output. template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm70, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; - using Operator = typename LayoutDetails::Operator; + using Operator = typename LayoutDetails::Operator; }; // ======================= Turing Traits ============================== -// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 -// and compute will happen in fp16 then will be converted for bf16 output. +// Note that turing does not have native bfloat support so weights and +// activations will be casted to fp16 and compute will happen in fp16 then will +// be converted for bf16 output. template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm75, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - using Operator = typename LayoutDetails::Operator; + using Operator = typename LayoutDetails::Operator; }; // ======================= Ampere Traits ============================== template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm80, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; - using Operator = typename LayoutDetails::Operator; + using Operator = typename LayoutDetails::Operator; }; // ======================= Ada Traits ============================== template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm89, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = + cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; - using Operator = typename LayoutDetails::Operator; + using Operator = typename LayoutDetails::Operator; }; // FP8 A/B = fp8, C/D = fp32 template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm89, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - // be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t - using TypeC = __nv_bfloat16; - using LayoutB = typename LayoutDetails::Layout; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + // be careful, TypeC should align with + // HopperGroupedGemmInput::OutputTypeAdaptor_t + using TypeC = __nv_bfloat16; + using LayoutB = typename LayoutDetails::Layout; - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = + cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; - using Operator = typename LayoutDetails::Operator; + using Operator = typename LayoutDetails::Operator; }; -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_int8_traits.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_int8_traits.h index 3fd722994e..beab7fd253 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_int8_traits.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,36 +22,30 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/layout/matrix.h" -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { template -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassSimt; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; }; // ======================= Turing Traits ============================== template <> -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; }; // ======================= Ampere Traits ============================== template <> -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; }; -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h index b83f6d76e0..f8ee5bdab9 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,21 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. + \brief Template for a pipelined GEMM kernel. Does not compute batching or + support split-K. */ #pragma once @@ -46,523 +48,560 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace detail -{ +namespace detail { template inline constexpr bool dependent_false_v = false; } -template -struct GemmFpAIntB -{ +template +struct GemmFpAIntB { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Mma::LayoutC; - using ElementScale = ElementC; + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformA; + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; - // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; + static constexpr int kInterleave = + Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; - /// Parameters structure - struct Arguments - { - GemmUniversalMode mode = GemmUniversalMode::kGemm; + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; - cutlass::gemm::GemmCoord problem_size; - int group_size; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::TensorRef ref_D; + // Control serial split-k + int batch_count; - // Control serial split-k - int batch_count; + typename EpilogueOutputOp::Params output_op; - typename EpilogueOutputOp::Params output_op; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; - // For gather+scatter operations - int const* gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // Included so we can use Gemm Universal - int batch_stride_D = 0; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Arguments() {} - - CUTLASS_HOST_DEVICE - Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, - typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, - typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), - int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, - int const* scatter_D_indices = nullptr) - : problem_size(problem_size) - , group_size(group_size) - , ref_A(ref_A) - , ref_B(ref_B) - , ref_scale(ref_scale) - , ref_zero(ref_zero) - , ref_C(ref_C) - , ref_D(ref_D) - , batch_count(serial_split_k_factor) - , output_op(output_op) - , gather_A_indices(gather_A_indices) - , gather_B_indices(gather_B_indices) - , scatter_D_indices(scatter_D_indices) - { - } - }; - - /// Parameters structure - struct Params - { - cutlass::gemm::GemmCoord problem_size; - int group_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::Params params_B; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::Params params_scale; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - typename EpilogueOutputOp::Params output_op; - int* semaphore; - int gemm_k_size; - // For gather+scatter operations - int const* gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0) - , semaphore(0) - , gemm_k_size(0) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, - void* workspace = nullptr) - : problem_size(args.problem_size) - , group_size(args.group_size) - , grid_tiled_shape(grid_tiled_shape) - , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)) - , params_A(args.ref_A.layout()) - , ref_A(args.ref_A) - , params_B(args.ref_B.layout()) - , ref_B(args.ref_B) - , params_scale(args.ref_scale.layout()) - , ref_scale(args.ref_scale) - , ref_zero(args.ref_zero) - , params_C(args.ref_C.layout()) - , ref_C(args.ref_C) - , params_D(args.ref_D.layout()) - , ref_D(args.ref_D) - , output_op(args.output_op) - , semaphore(static_cast(workspace)) - , gemm_k_size(gemm_k_size) - , gather_A_indices(args.gather_A_indices) - , gather_B_indices(args.gather_B_indices) - , scatter_D_indices(args.scatter_D_indices) - { - } - }; - - /// Shared memory storage structure - union SharedStorage - { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; + // Included so we can use Gemm Universal + int batch_stride_D = 0; // // Methods // CUTLASS_HOST_DEVICE - GemmFpAIntB() {} + Arguments() {} - /// Determines whether kernel satisfies alignment - static Status can_implement(Arguments const& args) - { - static int const kAlignmentA - = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, + int const group_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, + typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = + typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, + int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size), + group_size(group_size), + ref_A(ref_A), + ref_B(ref_B), + ref_scale(ref_scale), + ref_zero(ref_zero), + ref_C(ref_C), + ref_D(ref_D), + batch_count(serial_split_k_factor), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) {} + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, + cutlass::gemm::GemmCoord const& grid_tiled_shape, + int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size), + group_size(args.group_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.ref_A.layout()), + ref_A(args.ref_A), + params_B(args.ref_B.layout()), + ref_B(args.ref_B), + params_scale(args.ref_scale.layout()), + ref_scale(args.ref_scale), + ref_zero(args.ref_zero), + params_C(args.ref_C.layout()), + ref_C(args.ref_C), + params_D(args.ref_D.layout()), + ref_D(args.ref_D), + output_op(args.output_op), + semaphore(static_cast(workspace)), + gemm_k_size(gemm_k_size), + gather_A_indices(args.gather_A_indices), + gather_B_indices(args.gather_B_indices), + scatter_D_indices(args.scatter_D_indices) {} + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const& args) { + static int const kAlignmentA = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) ? 64 : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB - = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) + static int const kAlignmentB = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) ? 64 : Mma::IteratorB::AccessType::kElements; - static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + static int const kAlignmentScale = + Mma::IteratorScale::AccessType::kElements; - static int const kAlignmentC = (platform::is_same>::value) + static int const kAlignmentC = + (platform::is_same>::value) ? 32 - : (platform::is_same>::value) + : (platform::is_same>::value) ? 64 : Epilogue::OutputTileIterator::kElementsPerAccess; - if (!TensorRef_aligned(args.ref_A, kAlignmentA)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_B, kAlignmentB)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_C, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_D, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!args.ref_scale.good()) - { - return Status::kErrorNotSupported; - } - - if constexpr (hasZero(Mma::QuantOp)) - { - if (!args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - else - { - if (args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - - if constexpr (isFinegrained(Mma::QuantOp)) - { - if (args.group_size != 64 && args.group_size != 128) - { - return Status::kErrorNotSupported; - } - } - - return Status::kSuccess; + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; } - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; } - // Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator - // has a different constructor signature than a regular cutlass iterator - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; } - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; } - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { - using LayoutB = typename Mma::IteratorB::Layout; - static_assert(platform::is_same::value && kInterleave == 1 - || platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) - { - - return; - } - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size, - }; - - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, - threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; - - typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; - typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; - cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), - {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); - - typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), - {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, - params.gather_B_indices); - - typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; - typename Mma::IteratorScale iterator_scale = initialize_scale( - params.params_scale, params.ref_scale.data(), params.ref_zero.data(), - {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - if (!kSplitKSerial || gemm_k_iterations > 0) - { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); - } - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) - { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // - // Release the semaphore - // - - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) - { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else - { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; } - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; } - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { + if (!args.ref_scale.good()) { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) { + if (!args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } else { + if (args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) { + if (args.group_size != 64 && args.group_size != 128) { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + // Initializes the fine grained scale+bias iterator. Needed since the fine + // grained iterator has a different constructor signature than a regular + // cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale(params, + pointer_scale, + pointer_zero, + extent, + thread_id, + threadblock_offset, + group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale( + params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = + threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = + isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{ + scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = + min(params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / + Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = + isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = + initialize_scale( + params.params_scale, + params.ref_scale.data(), + params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, + thread_idx, + tb_offset_scale, + params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, + params.group_size, + thread_idx, + warp_idx, + lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_scale, + accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, + SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the + CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel + operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { #if defined(__CUDA_ARCH__) #if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ == 890) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 900) - CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use + // CUTLASS 3.x kernels. #else - static_assert( - false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); + static_assert(false, + "Invalid architecture being compiled. Only Volta+ supported " + "in weight-only quantization kernels."); #endif #else - CUTLASS_NOT_IMPLEMENTED(); + CUTLASS_NOT_IMPLEMENTED(); #endif - } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp index 15faad26ee..cc143751bc 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp @@ -54,15 +54,16 @@ namespace cutlass::gemm::kernel { * 2.x API type argument order. Template arguments without two names * belong to the 3.x API only. **/ -template + class TileScheduler_ = void, + class Enable = void> class GemmUniversalGated; //////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::kernel +} // namespace cutlass::gemm::kernel //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h index 6c4c578c9c..be80a3feec 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,25 +18,27 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief GEMM kernel to support the epilogue visitor model for customized softmax partial reduction epilogue fusion. - This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once - its usage has been stabilized. For now, it is included in this example to demonstrate - some basic output fusion options. + This source file will likely be moved to `include/cutlass/gemm/kernel/` in + the future once its usage has been stabilized. For now, it is included in + this example to demonstrate some basic output fusion options. - original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h + original file: + 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h */ #pragma once @@ -55,533 +57,527 @@ namespace tk = common; ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct GemmWithEpilogueVisitor -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueVisitor = typename Epilogue::Visitor; - using ThreadblockSwizzle = ThreadblockSwizzle_; +template +struct GemmWithEpilogueVisitor { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using TensorRefA = TensorRef; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using TensorRefB = TensorRef; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; - using ElementCompute = typename EpilogueVisitor::ElementCompute; - using LayoutAlphaCol = cutlass::layout::RowMajor; - using LayoutAlphaRow = cutlass::layout::ColumnMajor; - using TensorRefAlphaCol = TensorRef; - using TensorRefAlphaRow = TensorRef; + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; - using ElementC = typename EpilogueVisitor::ElementOutput; - using LayoutC = typename Epilogue::Layout; - using TensorRefC = TensorRef; + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformB; - using Operator = typename Mma::Operator; + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - using EpilogueOutputOp = - typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = typename Epilogue::Visitor:: + ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; - /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment - = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + // + // Structures + // + + /// Argument structure + struct Arguments { // - // Structures + // Data members // - /// Argument structure - struct Arguments - { + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; - // - // Data members - // + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; - TensorRefA ref_A; - TensorRefB ref_B; - tk::QuantMode quant_option; - TensorRefAlphaCol ref_alpha_col; - TensorRefAlphaRow ref_alpha_row; - TensorRefC ref_C; - TensorRefC ref_D; + typename EpilogueVisitor::Arguments epilogue_visitor; - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_D; - - typename EpilogueVisitor::Arguments epilogue_visitor; - - // - // Methods - // - - Arguments() - : mode(GemmUniversalMode::kGemm) - , batch_count(1) - { - } - - /// constructs an arguments structure - Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, - TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, - TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, - int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) - : mode(mode_) - , problem_size(problem_size_) - , batch_count(batch_count_) - , ref_A(ref_A_) - , ref_B(ref_B_) - , quant_option(quant_option_) - , ref_alpha_col(ref_alpha_col_) - , ref_alpha_row(ref_alpha_row_) - , ref_C(ref_C_) - , ref_D(ref_D_) - , batch_stride_A(batch_stride_A_) - , batch_stride_B(batch_stride_B_) - , batch_stride_D(0) - , epilogue_visitor(epilogue_visitor_) - { - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params - { - - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; - typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; - typename EpilogueVisitor::OutputTileIterator::Params params_C; - typename EpilogueVisitor::OutputTileIterator::Params params_D; - - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - - void* ptr_A; - void* ptr_B; - tk::QuantMode quant_option; - typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; - typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; - ElementC* ptr_C; - ElementC* ptr_D; - - int64_t batch_stride_A; - int64_t batch_stride_B; - - typename EpilogueVisitor::Params epilogue_visitor; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0) - , params_A(0) - , params_B(0) - , params_alpha_col(0) - , params_C(0) - , params_D(0) - , batch_count(0) - , gemm_k_size(0) - , mode(cutlass::gemm::GemmUniversalMode::kGemm) - , ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_alpha_col(nullptr) - , ptr_alpha_row(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , batch_stride_A(0) - , batch_stride_B(0) - { - } - - Params( - Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) - : problem_size(args.problem_size) - , swizzle_log_tile(0) - , params_A(args.ref_A.layout()) - , params_B(args.ref_B.layout()) - , params_alpha_col(args.ref_alpha_col.layout()) - , params_alpha_row(args.ref_alpha_col.layout()) - , params_C(args.ref_C.layout()) - , params_D(args.ref_D.layout()) - , mode(args.mode) - , batch_count(args.batch_count) - , gemm_k_size(args.problem_size.k()) - , ptr_A(args.ref_A.data()) - , ptr_B(args.ref_B.data()) - , quant_option(args.quant_option) - , ptr_alpha_col(args.ref_alpha_col.data()) - , ptr_alpha_row(args.ref_alpha_row.data()) - , ptr_C(args.ref_C.data()) - , ptr_D(args.ref_D.data()) - , batch_stride_A(args.batch_stride_A) - , batch_stride_B(args.batch_stride_B) - , epilogue_visitor(args.epilogue_visitor) - { - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); - - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - int const kAlignK - = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); - - gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - - if (gemm_k_size) - { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - } - }; - - /// Shared memory storage structure - union SharedStorage - { - - typename Mma::SharedStorage main_loop; - - struct - { - typename Epilogue::SharedStorage epilogue; - typename EpilogueVisitor::SharedStorage visitor; - } epilogue; - }; - -public: // // Methods // - CUTLASS_DEVICE - GemmWithEpilogueVisitor() {} + Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {} - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, + GemmCoord problem_size_, + int batch_count_, + TensorRefA ref_A_, + TensorRefB ref_B_, + tk::QuantMode quant_option_, + TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, + TensorRefC ref_C_, + TensorRefC ref_D_, + int64_t batch_stride_A_, + int64_t batch_stride_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_), + problem_size(problem_size_), + batch_count(batch_count_), + ref_A(ref_A_), + ref_B(ref_B_), + quant_option(quant_option_), + ref_alpha_col(ref_alpha_col_), + ref_alpha_row(ref_alpha_row_), + ref_C(ref_C_), + ref_D(ref_D_), + batch_stride_A(batch_stride_A_), + batch_stride_B(batch_stride_B_), + batch_stride_D(0), + epilogue_visitor(epilogue_visitor_) {} + }; - CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + // + // Structure for precomputing values in host memory and passing to kernels + // - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; - bool isAMisaligned = false; - bool isBMisaligned = false; - bool isCMisaligned = false; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; - if (platform::is_same::value) - { - isAMisaligned = problem_size.k() % kAlignmentA; - } - else if (platform::is_same::value) - { - isAMisaligned = problem_size.m() % kAlignmentA; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isAMisaligned = problem_size.k() % kAlignmentA; - } + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; - if (platform::is_same::value) - { - isBMisaligned = problem_size.n() % kAlignmentB; - } - else if (platform::is_same::value) - { - isBMisaligned = problem_size.k() % kAlignmentB; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isBMisaligned = problem_size.k() % kAlignmentB; - } + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; - if (platform::is_same::value) - { - isCMisaligned = problem_size.n() % kAlignmentC; - } - else if (platform::is_same::value) - { - isCMisaligned = problem_size.m() % kAlignmentC; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isCMisaligned = problem_size.n() % kAlignmentC; - } + int64_t batch_stride_A; + int64_t batch_stride_B; - if (isAMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); - return Status::kErrorMisalignedOperand; + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), + params_A(0), + params_B(0), + params_alpha_col(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_alpha_col(nullptr), + ptr_alpha_row(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0) {} + + Params(Arguments const& args, + cutlass::gemm::GemmCoord const& grid_tiled_shape_, + int gemm_k_size_, + int* workspace_) + : problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + params_alpha_col(args.ref_alpha_col.layout()), + params_alpha_row(args.ref_alpha_col.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + quant_option(args.quant_option), + ptr_alpha_col(args.ref_alpha_col.data()), + ptr_alpha_row(args.ref_alpha_row.data()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + epilogue_visitor(args.epilogue_visitor) { + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || + args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, + 128 / sizeof_bits::value), + 1); + + gemm_k_size = round_up( + ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); } + } - if (isBMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); - return Status::kErrorMisalignedOperand; - } + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; - if (isCMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); - return Status::kErrorMisalignedOperand; - } + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; - CUTLASS_TRACE_HOST(" returning kSuccess"); + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; - return Status::kSuccess; + public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; } - static Status can_implement(Arguments const& args) - { - return can_implement(args.problem_size); + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; } - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; } + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + #define SPLIT_K_ENABLED 1 - /// Executes one GEMM - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) - { + int offset_k = 0; + int problem_size_k = params.problem_size.k(); - return; - } - - int offset_k = 0; - int problem_size_k = params.problem_size.k(); - - ElementA* ptr_A = static_cast(params.ptr_A); - ElementB* ptr_B = static_cast(params.ptr_B); + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); #if SPLIT_K_ENABLED - // - // Fetch pointers based on mode. - // - if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) - { + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } - if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) - { - - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } - else if (params.mode == GemmUniversalMode::kBatched) - { - ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; - ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; - } - else if (params.mode == GemmUniversalMode::kArray) - { - ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; - } + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast( + params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast( + params.ptr_B)[threadblock_tile_offset.k()]; + } #endif - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - offset_k, - }; + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; - cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + cutlass::MatrixCoord tb_offset_B{ + offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; - // Compute position within threadblock - int thread_idx = threadIdx.x; + // Compute position within threadblock + int thread_idx = threadIdx.x; - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); - typename Mma::IteratorB iterator_B( - params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; + int lane_idx = threadIdx.x % 32; - // - // Main loop - // + // + // Main loop + // - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - typename Mma::FragmentC accumulators; + typename Mma::FragmentC accumulators; - accumulators.clear(); + accumulators.clear(); - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - // - // Masked tile iterators constructed from members - // + // + // Masked tile iterators constructed from members + // - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - // - // Construct the epilogue visitor - // + // + // Construct the epilogue visitor + // - EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, - params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, - params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, - params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params.params_alpha_col, + params.params_C, + params.params_D, + params.quant_option, + params.ptr_alpha_row, + params.ptr_alpha_col, + params.ptr_C, + params.ptr_D, + threadblock_offset, + blockIdx.y * params.problem_size.m()); - if (params.mode == GemmUniversalMode::kGemm) - { - // Indicate which position in a serial reduction the output operator is currently updating - epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) - { - epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); - } - - // Construct the epilogue - Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - epilogue(epilogue_visitor, accumulators); + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is + // currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } else if (params.mode == GemmUniversalMode::kBatched || + params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); } - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } - } + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, + SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the + CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel + operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { #if defined(__CUDA_ARCH__) #if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 720) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 900) - // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. - run_kernel(params, shared_storage); + // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); #else - static_assert( - false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); + static_assert(false, + "Invalid architecture being compiled. Only Volta+ supported " + "in weight-only quantization kernels."); #endif #else - CUTLASS_NOT_IMPLEMENTED(); + CUTLASS_NOT_IMPLEMENTED(); #endif - } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 8f61c6d9c4..56140c81d6 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,10 @@ * limitations under the License. */ /* - This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is - quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices - to be consumed by CUTLASS. + This file exists so that we use the same weight layout for MoE grouped gemm + and regular gemm when the weight is quantized. The preprocessing code reads + this template to know how to organize the quantized weight matrices to be + consumed by CUTLASS. Note that for int4, ThreadBlockK MUST be 64. @@ -35,136 +36,172 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/tile_interleaved_layout.h" -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { template -struct LayoutDetailsB -{ -}; +struct LayoutDetailsB {}; -// Volta specialiations. Volta will dequantize before STS, so we need a different operator +// Volta specialiations. Volta will dequantize before STS, so we need a +// different operator template -struct LayoutDetailsB -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 8; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; }; -// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. -// TODO - Switch this to column major for weights since gemms should be more performant. +// Specializations for Turing+ when B is FP16. These are currently only used for +// MoE networks. +// TODO - Switch this to column major for weights since gemms should be more +// performant. template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB< + TypeA, + half_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB< + TypeA, + bfloat16_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; template -struct LayoutDetailsB -{ - static constexpr int ThreadblockK = 64; +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; -public: - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; - // for fast accumulation - // using Operator = cutlass::arch::OpMultiplyAddFastAccum; + public: + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; + // for fast accumulation + // using Operator = cutlass::arch::OpMultiplyAddFastAccum; }; -// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, -// which signals that we want to dequantize after loading from smem. +// Specializations for Turing+ when B is quantized. These can use the operator +// OpMultiplyAddDequantizeInterleavedBToA, which signals that we want to +// dequantize after loading from smem. template struct LayoutDetailsB < TypeA, uint8_t, Arch, - typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + typename platform::enable_if= 75 && + Arch::kMinComputeCapability<90>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template struct LayoutDetailsB < TypeA, uint4b_t, Arch, - typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + typename platform::enable_if= 75 && + Arch::kMinComputeCapability<90>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; // 64 +struct LayoutDetailsB< + TypeA, + uint2b_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; // 64 -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8 + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = + ElementsPerCacheLine / ThreadblockK; // 8 -public: - // using Layout = layout::ColumnMajor; - // static constexpr int ElementsPerAccess = 16; // at least 4-bytes - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; // 64 - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; + public: + // using Layout = layout::ColumnMajor; + // static constexpr int ElementsPerAccess = 16; // at least 4-bytes + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; // 64 + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template -struct LayoutDetailsB= 90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB< + TypeA, + uint8_t, + Arch, + typename platform::enable_if= 90>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; template -struct LayoutDetailsB= 90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB< + TypeA, + uint4b_t, + Arch, + typename platform::enable_if= 90>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/moe_problem_visitor.h index b9126e3500..a2376e6b8b 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/moe_problem_visitor.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -160,11 +160,11 @@ struct BaseMoeProblemVisitor { CUTLASS_HOST_DEVICE cutlass::gemm::GemmCoord problem_size(int idx) const { - int64_t gemm_m = 0; if (params.total_rows < 0) { - const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t prev_problem_row = + idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; const int64_t current_problem_row = params.last_row_for_problem[idx]; gemm_m = current_problem_row - prev_problem_row; } else { diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp index 843529cde5..efb90a5e7c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp @@ -53,15 +53,20 @@ namespace cutlass::gemm::kernel { /////////////////////////////////////////////////////////////////////////////// -template +template class GemmUniversalGated< - ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, cute::enable_if_t && CollectiveMainloop_::isGated>> { -public: + public: // // Type Aliases // @@ -98,7 +103,9 @@ public: using TileSchedulerTag = TileScheduler_; using TileScheduler = - typename detail::TileSchedulerSelector::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; @@ -222,9 +229,14 @@ public: // used, therefore separate reduction will not be enabled. constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, - args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + TileSchedulerParams scheduler = + TileScheduler::to_underlying_arguments(problem_shape_MNKL, + TileShape{}, + ClusterShape{}, + hw_info, + args.scheduler, + scheduler_workspace, + NumEpilogueSubTiles); return {args.mode, problem_shape, @@ -242,8 +254,9 @@ public: (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't " - "meet the requirements.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Arguments or Problem Shape don't " + "meet the requirements.\n"); return implementable; } implementable &= @@ -262,7 +275,10 @@ public: workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, + args.scheduler, + args.problem_shape, + args.hw_info, + NumMmaWarpGroups, NumEpilogueSubTiles); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); @@ -273,10 +289,11 @@ public: return workspace_size; } - static cutlass::Status - initialize_workspace(Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { + static cutlass::Status initialize_workspace( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { Status status = Status::kSuccess; uint8_t *workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; @@ -285,13 +302,20 @@ public: status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, - args.problem_shape, args.hw_info, NumMmaWarpGroups, + args.scheduler, + workspace_ptr + workspace_offset, + stream, + args.problem_shape, + args.hw_info, + NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, + args.scheduler, + args.problem_shape, + args.hw_info, + NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { @@ -299,8 +323,11 @@ public: } status = CollectiveEpilogue::initialize_workspace( - args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, - stream, cuda_adapter); + args.problem_shape, + args.epilogue, + workspace_ptr + workspace_offset, + stream, + cuda_adapter); workspace_offset += CollectiveEpilogue::get_workspace_size( args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -323,9 +350,12 @@ public: params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, - TileShape{}, ClusterShape{}, - params.hw_info, args); + return TileScheduler::get_grid_shape(params.scheduler, + params.problem_shape, + TileShape{}, + ClusterShape{}, + params.hw_info, + args); } static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } @@ -337,8 +367,9 @@ public: // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if !defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting " - "sm90a compute capability. Aborting.\n"); + printf( + "ERROR : Arch conditional MMA instruction used without targeting " + "sm90a compute capability. Aborting.\n"); #else // Preconditions @@ -469,7 +500,7 @@ public: return []() { cute::cluster_wait(); }; } else { __syncthreads(); - return []() {}; // do nothing + return []() {}; // do nothing } }(); @@ -480,7 +511,7 @@ public: // Get the appropriate blocks for this thread block -- potential for thread // block locality TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) TileScheduler scheduler{params.scheduler}; auto work_tile_info = scheduler.get_current_work(); @@ -540,10 +571,16 @@ public: auto k_tile_iter = cute::make_coord_iterator( idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - collective_mainloop.load( - params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, - block_rank_in_cluster, shared_storage.tensors.mainloop); + collective_mainloop.load(params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + blk_coord, + k_tile_iter, + work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop); // Update starting pipeline state for the next tile mainloop_pipe_producer_state.advance(work_k_tile_count); @@ -555,12 +592,12 @@ public: // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End + } // Mainloop Producer Warp End // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && @@ -579,21 +616,26 @@ public: auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, epi_load_pipe_producer_state, - problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx()); } // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End + } // Epilogue Producer Warp End + } // Producer Warp Group End else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { @@ -618,14 +660,18 @@ public: // // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. auto accumulators0 = partition_fragment_C( - tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) auto accumulators1 = partition_fragment_C( - tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - collective_mainloop.mma( - mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, - accumulators1, work_k_tile_count, mma_thread_idx, - shared_storage.tensors.mainloop, params.mainloop); + collective_mainloop.mma(mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators0, + accumulators1, + work_k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop); // Make sure the math instructions are done and free buffers before // entering the epilogue @@ -641,10 +687,16 @@ public: canonical_warp_group_idx() - NumLoadWarpGroups; // Perform reduction across splits, if needed - TileScheduler::fixup(params.scheduler, work_tile_info, accumulators0, - NumMmaWarpGroups, consumer_warp_group_idx); - TileScheduler::fixup(params.scheduler, work_tile_info, accumulators1, - NumMmaWarpGroups, consumer_warp_group_idx); + TileScheduler::fixup(params.scheduler, + work_tile_info, + accumulators0, + NumMmaWarpGroups, + consumer_warp_group_idx); + TileScheduler::fixup(params.scheduler, + work_tile_info, + accumulators1, + NumMmaWarpGroups, + consumer_warp_group_idx); Activation elt_op; CUTLASS_PRAGMA_UNROLL @@ -657,12 +709,18 @@ public: // Epilogue and write to gD auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, epi_load_pipe_consumer_state, - epi_store_pipeline, epi_store_pipe_producer_state, - problem_shape_MNKL, blk_shape, blk_coord, accumulators0, - tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, - work_tile_info.reduction_subtile_idx()); + collective_epilogue.store(epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators0, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx()); epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; do_store_tail = true; @@ -670,23 +728,24 @@ public: // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop + } // Scheduler work fetch loop if (do_store_tail) { - collective_epilogue.store_tail( - epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, - epi_store_pipe_producer_state); + collective_epilogue.store_tail(epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state); } - } // Consumer Warp Groups End + } // Consumer Warp Groups End #endif } -private: + private: // Kernel helper function to get next work unit CUTLASS_DEVICE - typename TileScheduler::WorkTileInfo - fetch_next_work(typename TileScheduler::WorkTileInfo &work_tile_info, - TileScheduler &scheduler) const { + typename TileScheduler::WorkTileInfo fetch_next_work( + typename TileScheduler::WorkTileInfo &work_tile_info, + TileScheduler &scheduler) const { // Check whether we should continue on with the current work unit. If this // is the case, the work unit will have been updated in // continue_current_work to reflect the new tile to be computed. @@ -702,4 +761,4 @@ private: /////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::kernel +} // namespace cutlass::gemm::kernel diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp index e6cc7de5c6..9609adc32a 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp @@ -56,15 +56,20 @@ namespace cutlass::gemm::kernel { /////////////////////////////////////////////////////////////////////////////// -template +template class GemmUniversalGated< - ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, cute::enable_if_t && CollectiveMainloop_::isGated>> { -public: + public: // // Type Aliases // @@ -103,7 +108,9 @@ public: "Ping-pong kernel does not currently support stream-K scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = - typename detail::TileSchedulerSelector::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; @@ -236,9 +243,12 @@ public: CollectiveEpilogue::to_underlying_arguments( args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, - args.scheduler, scheduler_workspace)}; + TileScheduler::to_underlying_arguments(problem_shape_MNKL, + TileShape{}, + ClusterShape{}, + hw_info, + args.scheduler, + scheduler_workspace)}; } static bool can_implement(Arguments const &args) { @@ -246,8 +256,9 @@ public: (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't " - "meet the requirements.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Arguments or Problem Shape don't " + "meet the requirements.\n"); return implementable; } implementable &= @@ -273,18 +284,23 @@ public: return workspace_size; } - static cutlass::Status - initialize_workspace(Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { + static cutlass::Status initialize_workspace( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { Status status = Status::kSuccess; uint8_t *workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, - args.problem_shape, args.hw_info, NumMmaWarpGroups); + args.scheduler, + workspace_ptr + workspace_offset, + stream, + args.problem_shape, + args.hw_info, + NumMmaWarpGroups); workspace_offset += TileScheduler::template get_workspace_size( @@ -295,8 +311,11 @@ public: } status = CollectiveEpilogue::initialize_workspace( - args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, - stream, cuda_adapter); + args.problem_shape, + args.epilogue, + workspace_ptr + workspace_offset, + stream, + cuda_adapter); workspace_offset += CollectiveEpilogue::get_workspace_size( args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -319,9 +338,12 @@ public: params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, - TileShape{}, ClusterShape{}, - params.hw_info, args); + return TileScheduler::get_grid_shape(params.scheduler, + params.problem_shape, + TileShape{}, + ClusterShape{}, + params.hw_info, + args); } static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } @@ -333,8 +355,9 @@ public: // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if !defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting " - "sm90a compute capability. Aborting.\n"); + printf( + "ERROR : Arch conditional MMA instruction used without targeting " + "sm90a compute capability. Aborting.\n"); #else // Preconditions @@ -437,7 +460,7 @@ public: params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); params_math_wg_order_barrier.group_size = - NumThreadsPerWarpGroup; // Number of threads / participants in a group + NumThreadsPerWarpGroup; // Number of threads / participants in a group MathWarpGroupOrderBarrier math_wg_order_barrier( shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); @@ -464,7 +487,7 @@ public: return []() { cute::cluster_wait(); }; } else { __syncthreads(); - return []() {}; // do nothing + return []() {}; // do nothing } }(); @@ -476,7 +499,7 @@ public: // Get the appropriate blocks for this thread block -- potential for thread // block locality TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) // In a warp specialized kernel, collectives expose data movement and // compute operations separately @@ -535,10 +558,16 @@ public: auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); - collective_mainloop.load( - params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, - block_rank_in_cluster, shared_storage.tensors.mainloop); + collective_mainloop.load(params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + blk_coord, + k_tile_iter, + k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop); // Update starting pipeline state for the next tile mainloop_pipe_producer_state.advance(k_tile_count); @@ -551,12 +580,12 @@ public: // Get next work tile scheduler.advance_to_next_work(); work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End + } // Mainloop Producer Warp End // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && @@ -570,21 +599,26 @@ public: auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, epi_load_pipe_producer_state, - problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, - shared_storage.tensors.epilogue); + epi_load_pipe_producer_state = + collective_epilogue.load(epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue); // Get next work tile scheduler.advance_to_next_work(); work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End + } // Epilogue Producer Warp End + } // Producer Warp Group End else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { @@ -602,17 +636,21 @@ public: // Allocate the accumulators for the (M,N) blk_shape Tensor accumulators0 = partition_fragment_C( - tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) Tensor accumulators1 = partition_fragment_C( - tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) // Order two Math WG's MMA one after the other, helps hide Epilogue math_wg_order_barrier.wait(); - collective_mainloop.mma( - mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, - accumulators1, k_tile_count, warp_group_thread_idx, - shared_storage.tensors.mainloop, params.mainloop); + collective_mainloop.mma(mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators0, + accumulators1, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop); // Cue for next Math WG's MMA to start math_wg_order_barrier.arrive(); @@ -637,12 +675,17 @@ public: // Epilogue and write to gD auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, epi_load_pipe_consumer_state, - epi_store_pipeline, epi_store_pipe_producer_state, - problem_shape_MNKL, blk_shape, blk_coord, accumulators0, - tiled_mma, warp_group_thread_idx, - shared_storage.tensors.epilogue); + collective_epilogue.store(epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators0, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue); // TMA store pipeline wait is only visible to TMA-issuing warp, so for // multiple-consumer kernels we need to wait for all TMA stores to @@ -651,9 +694,10 @@ public: // current consumer. auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] = - collective_epilogue.store_tail( - epi_load_pipeline, epi_load_pipe_consumer_state_next, - epi_store_pipeline, epi_store_pipe_producer_state_next); + collective_epilogue.store_tail(epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next); // Update starting load/store pipeline states for the next tile // state has already been incremented by 1 tile in collective calls, @@ -669,12 +713,12 @@ public: // Get next work tile scheduler.advance_to_next_work(NumMmaWarpGroups); work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - } // Consumer Warp Groups End + } // Scheduler work fetch loop + } // Consumer Warp Groups End #endif } }; /////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::kernel +} // namespace cutlass::gemm::kernel diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h index 5e3531f093..5d68ea26d9 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -49,446 +50,471 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct SplitkGemmGrouped -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = Transposed; +template +struct SplitkGemmGrouped { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; - // Optional transpose - using MapArguments = kernel::detail::MapArguments; + // Optional transpose + using MapArguments = + kernel::detail::MapArguments; - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; + // Public-facing type definitions related to operand element type, layout, and + // complex conjugate operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; - using ElementFinalOutput = typename MapArguments::ElementA; + using ElementFinalOutput = typename MapArguments::ElementA; - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; - // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; - using ProblemVisitor - = GemmGroupedProblemVisitor; + using ProblemVisitor = GemmGroupedProblemVisitor; + // + // Structures + // + + /// Argument structure + struct Arguments { // - // Structures + // Data members // - /// Argument structure - struct Arguments - { + GemmCoord* problem_sizes; + int problem_count; + int threadblock_count; - // - // Data members - // + typename EpilogueOutputOp::Params output_op; - GemmCoord* problem_sizes; - int problem_count; - int threadblock_count; + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; - typename EpilogueOutputOp::Params output_op; + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; + // Only used by device-level operator + GemmCoord* host_problem_sizes; - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; + // splitK + int split_k_slices; + int64_t* splitk_buffer_offsets; - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // splitK - int split_k_slices; - int64_t* splitk_buffer_offsets; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0) - , threadblock_count(0) - , ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, - typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C, - ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, - typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, - typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, - int64_t* splitk_buffer_offsets) - : problem_sizes(problem_sizes) - , problem_count(problem_count) - , threadblock_count(threadblock_count) - , output_op(output_op) - , ptr_A(ptr_A) - , ptr_B(ptr_B) - , ptr_C(ptr_C) - , ptr_D(ptr_D) - , lda(lda) - , ldb(ldb) - , ldc(ldc) - , ldd(ldd) - , host_problem_sizes(host_problem_sizes) - , split_k_slices(split_k_slices) - , splitk_buffer_offsets(splitk_buffer_offsets) - { - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params - { - - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; - ElementC* ptr_C_split; - ElementC* ptr_D_split; - - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; - - // - // Methods - // - - // splitk - GemmCoord grid_tiled_shape; - int swizzle_log_tile; - int gemm_k_size; - GemmCoord* host_problem_sizes; - int split_k_slices; - int64_t* splitk_buffer_offsets; - - CUTLASS_HOST_DEVICE - Params() - : ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , ptr_C_split(nullptr) - , ptr_D_split(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , swizzle_log_tile(0) - , gemm_k_size(0) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count) - , host_problem_sizes(args.host_problem_sizes) - , threadblock_count(args.threadblock_count) - , output_op(args.output_op) - , ptr_A(args.ptr_A) - , ptr_B(args.ptr_B) - , ptr_C(args.ptr_C) - , ptr_D(args.ptr_D) - , ptr_C_split((ElementC*) workspace) - , ptr_D_split((ElementC*) workspace) - , lda(args.lda) - , ldb(args.ldb) - , ldc(args.ldc) - , ldd(args.ldd) - , split_k_slices(args.split_k_slices) - , splitk_buffer_offsets(args.splitk_buffer_offsets) - { - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0], - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); - swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); - - // only support same k - int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; - int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); - - gemm_k_size = gemm_k_iterations * Mma::Shape::kK; - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - { - - problem_visitor = - typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - ptr_C_split = workspace; - ptr_D_split = workspace; - - lda = args.lda; - ldb = args.ldb; - ldc = args.ldc; - ldd = args.ldd; - } - }; - - /// Shared memory storage structure - struct SharedStorage - { - union - { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - } kernel; - - // ProblemVisitor shared storage can't be overlapped with others - typename ProblemVisitor::SharedStorage problem_visitor; - }; - -public: // // Methods // - CUTLASS_DEVICE - SplitkGemmGrouped() {} + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + host_problem_sizes(nullptr), + split_k_slices(1), + splitk_buffer_offsets(nullptr) {} - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - return Status::kSuccess; + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord* problem_sizes, + int problem_count, + int threadblock_count, + typename EpilogueOutputOp::Params output_op, + ElementA** ptr_A, + ElementB** ptr_B, + ElementFinalOutput** ptr_C, + ElementFinalOutput** ptr_D, + typename LayoutA::Stride::LongIndex* lda, + typename LayoutB::Stride::LongIndex* ldb, + typename LayoutC::Stride::LongIndex* ldc, + typename LayoutC::Stride::LongIndex* ldd, + GemmCoord* host_problem_sizes, + int split_k_slices, + int64_t* splitk_buffer_offsets) + : problem_sizes(problem_sizes), + problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + lda(lda), + ldb(ldb), + ldc(ldc), + ldd(ldd), + host_problem_sizes(host_problem_sizes), + split_k_slices(split_k_slices), + splitk_buffer_offsets(splitk_buffer_offsets) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + ElementC* ptr_C_split; + ElementC* ptr_D_split; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // + // Methods + // + + // splitk + GemmCoord grid_tiled_shape; + int swizzle_log_tile; + int gemm_k_size; + GemmCoord* host_problem_sizes; + int split_k_slices; + int64_t* splitk_buffer_offsets; + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_C_split(nullptr), + ptr_D_split(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + swizzle_log_tile(0), + gemm_k_size(0), + host_problem_sizes(nullptr), + split_k_slices(1), + splitk_buffer_offsets(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor( + args.problem_sizes, args.problem_count, workspace, tile_count), + host_problem_sizes(args.host_problem_sizes), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D), + ptr_C_split((ElementC*)workspace), + ptr_D_split((ElementC*)workspace), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldd(args.ldd), + split_k_slices(args.split_k_slices), + splitk_buffer_offsets(args.splitk_buffer_offsets) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.host_problem_sizes[0], + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + + // only support same k + int full_gemm_k_iterations = + args.host_problem_sizes[0].k() / Mma::Shape::kK; + int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; } - static Status can_implement(Arguments const& args) - { - return Status::kSuccess; + CUTLASS_HOST_DEVICE + void update(Arguments const& args, + void* workspace = nullptr, + int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params( + args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_C_split = workspace; + ptr_D_split = workspace; + + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; } + }; - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; - // - // Problem visitor. - // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + public: + // + // Methods + // - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) - { + CUTLASS_DEVICE + SplitkGemmGrouped() {} - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + return Status::kSuccess; + } - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + static Status can_implement(Arguments const& args) { + return Status::kSuccess; + } - // Load element pointers. Exchange pointers and strides if working on the transpose - ElementA* ptr_A - = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); - typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to + // implement a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; - ElementB* ptr_B - = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); - typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + // + // Problem visitor. + // + ProblemVisitor problem_visitor( + params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); - cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, - int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - threadblock_tile_offset.k() * params.gemm_k_size, - }; + // Load element pointers. Exchange pointers and strides if working on the + // transpose + ElementA* ptr_A = reinterpret_cast(( + kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = + (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; + ElementB* ptr_B = reinterpret_cast(( + kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = + (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); - // Problem size is a function of threadblock index in the K dimension - int problem_size_k; - if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) - { - problem_size_k = problem_size.k(); - } - else - { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + cutlass::gemm::GemmCoord threadblock_offset( + int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, + 0); - // Compute position within threadblock - int thread_idx = threadIdx.x; + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + threadblock_tile_offset.k() * params.gemm_k_size, + }; - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_offset.n()}; - typename Mma::IteratorB iterator_B( - LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B); + // Problem size is a function of threadblock index in the K dimension + int problem_size_k; + if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) { + problem_size_k = problem_size.k(); + } else { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } - typename Mma::FragmentC accumulators; + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / + Mma::Shape::kK; - accumulators.clear(); + // Compute position within threadblock + int thread_idx = threadIdx.x; - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); - int lane_idx = threadIdx.x % 32; + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), + ptr_B, + {problem_size_k, problem_size.n()}, + thread_idx, + tb_offset_B); - // - // Matrix multiply phase - // + typename Mma::FragmentC accumulators; - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + accumulators.clear(); - // Wait for all threads to finish their epilogue phases from the previous tile. - __syncthreads(); + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + int lane_idx = threadIdx.x % 32; - // - // Epilogue - // + // + // Matrix multiply phase + // - EpilogueOutputOp output_op(params.output_op); + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); - ElementC* ptr_C = params.ptr_C_split; - ElementC* ptr_D = params.ptr_D_split; + // Wait for all threads to finish their epilogue phases from the previous + // tile. + __syncthreads(); - LayoutC layout_C(params.ldc[problem_idx]); - LayoutC layout_D(params.ldd[problem_idx]); + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); + // + // Epilogue + // - // assume identity swizzle - MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); + EpilogueOutputOp output_op(params.output_op); - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); + ElementC* ptr_C = params.ptr_C_split; + ElementC* ptr_D = params.ptr_D_split; - iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); - iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); - Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + // assume identity swizzle + MatrixCoord threadblock_offset_C(threadblock_offset.m(), + threadblock_offset.n()); - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); - // Next tile - problem_visitor.advance(gridDim.x); - } + iterator_C.add_pointer_offset( + problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); + iterator_D.add_pointer_offset( + problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + Epilogue epilogue( + shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma.h index ed5e3e4daf..a268861a0d 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,18 +19,15 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/interleaved_numeric_conversion.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -// We need to distinguish here, since we want volta support. It is too much effort -// to write shared memory iterators that are probably needed for volta to function -// properly. As a result, we allow converters both after the LDG (for volta) and after -// the LDS for Turing+. +// We need to distinguish here, since we want volta support. It is too much +// effort to write shared memory iterators that are probably needed for volta to +// function properly. As a result, we allow converters both after the LDG (for +// volta) and after the LDS for Turing+. template < /// Iterator for B matrix in global memory typename IteratorB, @@ -38,9 +35,7 @@ template < typename MmaOperator, /// Math operation perform by warp level operator typename MathOperator> -struct SetConverters -{ -}; +struct SetConverters {}; // Dequantize after LDG, so set transforms accordingly template < @@ -48,14 +43,16 @@ template < typename IteratorB, /// Mma Policy typename MmaOperator> -struct SetConverters -{ - using TransformAfterLDG - = FastInterleavedAndBiasedNumericArrayConverter; +struct SetConverters { + using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter< + typename MmaOperator::ArchMmaOperator::ElementB, + typename IteratorB::Element, + IteratorB::Fragment::kElements>; - using TransformAfterLDS = NumericArrayConverter; + using TransformAfterLDS = + NumericArrayConverter; }; // Dequantize after LDS, so set transforms accordingly @@ -65,14 +62,18 @@ template < typename IteratorB, /// Mma Policy typename MmaOperator> -struct SetConverters -{ - using TransformAfterLDG = NumericArrayConverter; +struct SetConverters { + using TransformAfterLDG = + NumericArrayConverter; - using TransformAfterLDS - = FastInterleavedAndBiasedNumericArrayConverter; + using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter< + typename MmaOperator::ArchMmaOperator::ElementB, + typename TransformAfterLDG::result_type::Element, + MmaOperator::FragmentB::kElements>; }; //////////////////////////////////////////////////////////////////////////////// @@ -120,6 +121,6 @@ template < typename Enable = void> struct DqMma; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h index 17c6346553..566cd379b2 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,49 +27,77 @@ #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" #include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -template +template struct DefaultScaleIteratorsMultistage; // Fine grained iterators -template -struct DefaultScaleIteratorsMultistage> -{ - using IteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, - Layout, 0, Alignment>; +template +struct DefaultScaleIteratorsMultistage< + MmaShape, + Element, + Layout, + QuantOp, + Alignment, + std::enable_if_t> { + using IteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + Alignment>; - using SmemIteratorScale = IteratorScale; + using SmemIteratorScale = IteratorScale; }; // Per column iterators -template -struct DefaultScaleIteratorsMultistage> -{ - // ThreadMap for scale iterator - static_assert((MmaShape::kN % Alignment) == 0, ""); +template +struct DefaultScaleIteratorsMultistage< + MmaShape, + Element, + Layout, + QuantOp, + Alignment, + std::enable_if_t> { + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); -private: - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, - MmaShape::kN / Alignment, Alignment>; + private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaShape::kN / Alignment, + Alignment>; -public: - // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, - Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + IteratorScaleThreadMap, + Alignment>; - using SmemIteratorScale = IteratorScale; + using SmemIteratorScale = IteratorScale; }; //////////////////////////////////////////////////////////////////////////////// @@ -111,69 +139,133 @@ template < typename Operator_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear> -struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> -{ +struct DqMma= 80 && + !layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value || + platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); - static_assert(platform::is_same::value || platform::is_same::value - || platform::is_same::value, - "Element A must be fp16, fp8 or bf16"); + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be uint8 or uint4"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, - AccessTypeA>; + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + ThreadMapB, + AccessTypeB>; - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using ScaleIterators = + DefaultScaleIteratorsMultistage; - using ScaleIterators = DefaultScaleIteratorsMultistage; + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + using Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementScale, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + Converter, + OperatorInfo::QuantOp, + SharedMemoryClear>; }; // Specialization to handle column major interleave B @@ -214,89 +306,159 @@ template < typename Operator_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear> -struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> -{ +struct DqMma= 80 && + layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value || + platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); - static_assert(platform::is_same::value || platform::is_same::value - || platform::is_same::value, - "Element A must be fp16, fp8 or bf16"); + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be uint8 or uint4"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, - AccessTypeA>; + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); -private: - static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; - static constexpr int RowsPerTile = LayoutB::kRowsPerTile; - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = + typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape< + OriginalWarpArrangement::kContiguous * ColumnsInterleaved, + OriginalWarpArrangement::kStrided / ColumnsInterleaved>, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; - using GmemIteratorShape - = MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + GmemIteratorShape, + ElementB, + layout::ColumnMajor, + 0, + GmemThreadMapB, + AccessTypeB>; -public: - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; + using ScaleIterators = + DefaultScaleIteratorsMultistage; - using ScaleIterators = DefaultScaleIteratorsMultistage; + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + using Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementScale, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + Converter, + OperatorInfo::QuantOp, + SharedMemoryClear>; }; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h index 345cd2eec9..ba7f2863e0 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,58 +27,95 @@ #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" #include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -template +template struct DefaultScaleIteratorsPipelined; // Fine grained iterators -template -struct DefaultScaleIteratorsPipelined> -{ -private: - using SmemScaleType = half_t; +template +struct DefaultScaleIteratorsPipelined< + MmaShape, + Element, + Layout, + QuantOp, + Alignment, + std::enable_if_t> { + private: + using SmemScaleType = half_t; -public: - using IteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, - Layout, 0, Alignment>; + public: + using IteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + Alignment>; - using SmemIteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, - SmemScaleType, Layout, 0, Alignment>; + using SmemIteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + SmemScaleType, + Layout, + 0, + Alignment>; }; // Per column iterators -template -struct DefaultScaleIteratorsPipelined> -{ - static_assert((MmaShape::kN % Alignment) == 0, ""); +template +struct DefaultScaleIteratorsPipelined< + MmaShape, + Element, + Layout, + QuantOp, + Alignment, + std::enable_if_t> { + static_assert((MmaShape::kN % Alignment) == 0, ""); -private: - // ThreadMap for scale iterator - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, - MmaShape::kN / Alignment, Alignment>; - using SmemScaleType = half_t; + private: + // ThreadMap for scale iterator + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaShape::kN / Alignment, + Alignment>; + using SmemScaleType = half_t; -public: - // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, - Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + IteratorScaleThreadMap, + Alignment>; - using SmemIteratorScale - = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, - Layout, 0, IteratorScaleThreadMap, Alignment>; + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + SmemScaleType, + Layout, + 0, + IteratorScaleThreadMap, + Alignment>; }; //////////////////////////////////////////////////////////////////////////////// @@ -116,57 +153,110 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator_> -struct DqMma::value)>::type> -{ +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Element A must be fp16 or bf16"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be uint8 or uint4"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == + WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, + ""); - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + static constexpr bool DqAfterLDG = + platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform:: + conditional::type; - static constexpr bool DqAfterLDG = platform::is_same::value; - using MmaCoreElementA = half_t; - using MmaCoreElementB = typename platform::conditional::type; + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA>; - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA>; + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + typename MmaCore::IteratorThreadMapB, + kAlignmentB>; - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, - typename MmaCore::IteratorThreadMapB, kAlignmentB>; + using ScaleIterators = DefaultScaleIteratorsPipelined; - using ScaleIterators = DefaultScaleIteratorsPipelined; + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + using Converters = + SetConverters; - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + IteratorB, + typename MmaCore::SmemIteratorB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, + typename Converters::TransformAfterLDS, + OperatorInfo::QuantOp>; }; // Specialization to handle column major interleave B @@ -203,82 +293,140 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator_> -struct DqMma::value)>::type> -{ +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Element A must be fp16 or bf16"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be uint8 or uint4"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; + static constexpr bool DqAfterLDG = + platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform:: + conditional::type; - static constexpr bool DqAfterLDG = platform::is_same::value; - using MmaCoreElementA = half_t; - using MmaCoreElementB = typename platform::conditional::type; + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA>; - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA>; + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); -private: - static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; - static constexpr int RowsPerTile = LayoutB::kRowsPerTile; - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = + typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape< + OriginalWarpArrangement::kContiguous * ColumnsInterleaved, + OriginalWarpArrangement::kStrided / ColumnsInterleaved>, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; - using GmemIteratorShape - = MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; + public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + GmemIteratorShape, + ElementB, + layout::ColumnMajor, + 0, + GmemThreadMapB, + kAlignmentB>; -public: - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; - // ThreadMap for scale iterator - static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap - = transform::PitchLinearStripminedThreadMap, - MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + using ScaleIterators = DefaultScaleIteratorsPipelined; - using ScaleIterators = DefaultScaleIteratorsPipelined; + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + using Converters = + SetConverters; - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + IteratorB, + typename MmaCore::SmemIteratorB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, + typename Converters::TransformAfterLDS, + OperatorInfo::QuantOp>; }; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h index b50d66380e..31915cf389 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int8 weight, mma pipelined (stage=2) template < /// Layout type for A matrix operand typename LayoutA, @@ -49,34 +50,61 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + using Mma = DqMma; - using Mma = DqMma; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int4 weight, mma pipelined (stage=2) template < /// Layout type for A matrix operand typename LayoutA, @@ -98,35 +126,61 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + using Mma = DqMma; - using Mma = DqMma; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage -/// (stage>=3) +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int8 weight, mma multistage (stage>=3) template < /// Layout type for A matrix operand typename LayoutA, @@ -152,36 +206,64 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + using Mma = DqMma; - using Mma = DqMma; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage -/// (stage>=3) +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int4 weight, mma multistage (stage>=3) template < /// Layout type for A matrix operand typename LayoutA, @@ -207,37 +289,65 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + using Mma = DqMma; - using Mma = DqMma; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; #ifdef ENABLE_FP8 //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage -/// (stage>=3) +/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation +/// & int4 weight, mma multistage (stage>=3) template < /// Layout type for A matrix operand typename LayoutA, @@ -263,36 +373,65 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + using Mma = DqMma; - using Mma = DqMma; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; #endif -// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps +// avoid reg spills on large tile when not enough shared mem is present to do 3+ +// stage template < /// Layout type for A matrix operand typename LayoutA, @@ -318,39 +457,86 @@ template < bool GatherA, /// Gather operand B by using an index array bool GatherB> -struct DefaultMma -{ +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + half_t, + LayoutA, + 1, + ThreadMapA, + AccessTypeA, + GatherA>; - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, - GatherA>; + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + half_t, + LayoutB, + 0, + ThreadMapB, + AccessTypeB, + GatherB>; - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, - GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fbf16 activation & int2 weight, mma multistage +/// Specialization for row-major output (OperatorClass TensorOp), fbf16 +/// activation & int2 weight, mma multistage template < /// Layout type for A matrix operand @@ -373,26 +559,50 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ -private: - using Mma = DefaultWint2xMma; +struct DefaultMma { + private: + using Mma = DefaultWint2xMma; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -420,29 +630,55 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ -private: - using Mma = DefaultWint2xMma; +struct DefaultMma { + private: + using Mma = DefaultWint2xMma; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 300261c3f0..1ff648e0bf 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +/// Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & bf16 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -55,40 +56,85 @@ template < bool GatherA, /// Gather operand B by using an index array bool GatherB> -struct DefaultMma -{ +struct DefaultMma { + private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we + // convert before STS. + static constexpr bool arch_has_bf16_mma = + ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform:: + conditional::type; + using MmaElementB = typename platform:: + conditional::type; -private: - // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. - static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; - using MmaElementA = typename platform::conditional::type; - using MmaElementB = typename platform::conditional::type; + public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; -public: - // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA, + GatherA>; - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutB, + 0, + typename MmaCore::IteratorThreadMapB, + kAlignmentB, + GatherB>; - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, - typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaPipelined; }; -// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps +// avoid reg spills on large tile when not enough shared mem is present to do 3+ +// stage template < /// Layout type for A matrix operand typename LayoutA, @@ -114,40 +160,86 @@ template < bool GatherA, /// Gather operand B by using an index array bool GatherB> -struct DefaultMma -{ +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutA, + 1, + ThreadMapA, + AccessTypeA, + GatherA>; - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA, GatherA>; + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutB, + 0, + ThreadMapB, + AccessTypeB, + GatherB>; - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB, GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +/// Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & int8 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -169,34 +261,61 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + using Mma = DqMma; - using Mma = DqMma; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight +/// Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & int4 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -218,34 +337,61 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + using Mma = DqMma; - using Mma = DqMma; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +/// Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & int8 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -271,35 +417,64 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + using Mma = DqMma; - using Mma = DqMma; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int4 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -325,35 +500,64 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + using Mma = DqMma; - using Mma = DqMma; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fbf16 activation & int2 weight, mma multistage +/// Specialization for row-major output (OperatorClass TensorOp), fbf16 +/// activation & int2 weight, mma multistage template < /// Layout type for A matrix operand @@ -376,26 +580,50 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ -private: - using Mma = DefaultWint2xMma; +struct DefaultMma { + private: + using Mma = DefaultWint2xMma; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -423,29 +651,55 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ -private: - using Mma = DefaultWint2xMma; +struct DefaultMma { + private: + using Mma = DefaultWint2xMma; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h index e2bc640bac..58fd564416 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -66,10 +67,21 @@ template < cutlass::arch::CacheOperation::Kind CacheOpA, /// Cache operation of operand B cutlass::arch::CacheOperation::Kind CacheOpB> -struct DefaultMmaCore { +struct DefaultMmaCore { using Shape = Shape_; using WarpShape = WarpShape_; using InstructionShape = InstructionShape_; @@ -104,7 +116,8 @@ struct DefaultMmaCore::value) / kAccessSizeInBits; + (Shape::kK * Shape::kN * sizeof_bits::value) / + kAccessSizeInBits; static constexpr int kThreadsForB = kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB; @@ -129,11 +142,13 @@ struct DefaultMmaCore::value, Shape::kK>; + sizeof_bits::value, + Shape::kK>; // Shared memory layout using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< - sizeof_bits::value, Shape::kK>; + sizeof_bits::value, + Shape::kK>; // // Iterators to write to shared memory @@ -141,26 +156,34 @@ struct DefaultMmaCore, kThreads, + layout::PitchLinearShape, + kThreads, layout::PitchLinearShape, kAccessSizeInBits / sizeof_bits::value>; /// Shared memory iterator to A operand using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementA, SmemLayoutA, 0, + MatrixShape, + ElementA, + SmemLayoutA, + 0, IteratorThreadMapA>; /// ThreadMap of iterator B using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, kThreadsForB, + layout::PitchLinearShape, + kThreadsForB, layout::PitchLinearShape, kAccessSizeInBits / sizeof_bits::value>; /// Shared memory iterator to B operand using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< - MatrixShape, ElementB, SmemLayoutB, 1, + MatrixShape, + ElementB, + SmemLayoutB, + 1, IteratorThreadMapB>; // @@ -168,13 +191,23 @@ struct DefaultMmaCore::Type; + using MmaTensorOp = + typename cutlass::gemm::warp::DefaultMmaTensorOp::Type; /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy, - MatrixShape<0, 0>, WarpCount::kK>; + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, + WarpCount::kK>; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h index 1782330de8..e4684b610b 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,47 +31,61 @@ namespace threadblock { template struct DefaultQuantParamsIterators { -private: - static constexpr int kAlignment = 128 / sizeof_bits::value; - static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + private: + static constexpr int kAlignment = 128 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); - static constexpr int kRows = - (GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize; - static constexpr int kColumns = ThreadblockShape::kN; + static constexpr int kRows = + (GroupSize == -1) ? 1 + : (ThreadblockShape::kK + GroupSize - 1) / GroupSize; + static constexpr int kColumns = ThreadblockShape::kN; - using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kColumns / kAlignment, kAlignment>; + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, + kAlignment>; -public: - using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< - MatrixShape, ElementT, layout::RowMajor, 0, - IteratorThreadMap, kAlignment>; - using SmemIterator = Iterator; + public: + using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< + MatrixShape, + ElementT, + layout::RowMajor, + 0, + IteratorThreadMap, + kAlignment>; + using SmemIterator = Iterator; }; template struct DefaultQuantParamsIterators { -private: - static constexpr int kAlignment = 32 / sizeof_bits::value; - static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + private: + static constexpr int kAlignment = 32 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); - static constexpr int kRows = - (GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize); - static constexpr int kColumns = - (GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2; + static constexpr int kRows = + (GroupSize == -1) + ? 1 + : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize); + static constexpr int kColumns = + (GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2; - using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kColumns / kAlignment, kAlignment>; + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, + kAlignment>; -public: - using AccessType = cutlass::Array; - using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator< - MatrixShape, uint4b_t, layout::RowMajor, - 0, IteratorThreadMap, AccessType>; + public: + using AccessType = cutlass::Array; + using Iterator = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + MatrixShape, + uint4b_t, + layout::RowMajor, + 0, + IteratorThreadMap, + AccessType>; - using SmemIterator = Iterator; + using SmemIterator = Iterator; }; template < @@ -142,105 +156,174 @@ template < typename Operator, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear> -struct DefaultWint2xMma -{ -public: - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); +struct DefaultWint2xMma { + public: + static_assert(platform::is_same::value || + platform::is_same::value, + "Element A must be fp16 or bf16"); - static_assert(platform::is_same::value, - "Element B must be uint2b_t"); + static_assert(platform::is_same::value, + "Element B must be uint2b_t"); - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); - using ElementSuperScale = ElementA; - using ElementLocalScale = uint4b_t; - using ElementCodeScaleZp = float; + using ElementSuperScale = ElementA; + using ElementLocalScale = uint4b_t; + using ElementCodeScaleZp = float; - static constexpr int kGroupSize = 64; + static constexpr int kGroupSize = 64; - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, - AccessTypeA>; + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; -private: - static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved; - static constexpr int kRowsPerTile = LayoutB::kRowsPerTile; - static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be disivle by kColumnsInterleaved"); - static_assert(kRowsPerTile == MmaCore::Shape::kK, ""); + private: + static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int kRowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), + "ThreadblockShape must be disivle by kColumnsInterleaved"); + static_assert(kRowsPerTile == MmaCore::Shape::kK, ""); - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement; - static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), ""); + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement; + static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), ""); - using IteratorShapeB = MatrixShape< - MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>; - using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - ThreadMapB::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; + using IteratorShapeB = MatrixShape; + using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + ThreadMapB::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; -public: - // Define iterators over tiles from the B operand - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB, - AccessTypeB>; + public: + // Define iterators over tiles from the B operand + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + IteratorShapeB, + ElementB, + layout::ColumnMajor, + 0, + InterleavedThreadMapB, + AccessTypeB>; -private: - // Define iterators over tiles from extra quant params for B operand - using IteratorSuperScale = typename DefaultQuantParamsIterators< - ThreadblockShape, ElementSuperScale, -1>::Iterator; - using SmemIteratorSuperScale = typename DefaultQuantParamsIterators< - ThreadblockShape, ElementSuperScale, -1>::SmemIterator; + private: + // Define iterators over tiles from extra quant params for B operand + using IteratorSuperScale = + typename DefaultQuantParamsIterators::Iterator; + using SmemIteratorSuperScale = + typename DefaultQuantParamsIterators::SmemIterator; - using IteratorLocalScale = typename DefaultQuantParamsIterators< - ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator; - using SmemIteratorLocalScale = typename DefaultQuantParamsIterators< - ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator; + using IteratorLocalScale = + typename DefaultQuantParamsIterators::Iterator; + using SmemIteratorLocalScale = + typename DefaultQuantParamsIterators::SmemIterator; - using IteratorCodeScaleZp = typename DefaultQuantParamsIterators< - ThreadblockShape, ElementCodeScaleZp, -1>::Iterator; - using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators< - ThreadblockShape, ElementCodeScaleZp, -1>::Iterator; + using IteratorCodeScaleZp = + typename DefaultQuantParamsIterators::Iterator; + using SmemIteratorCodeScaleZp = + typename DefaultQuantParamsIterators::Iterator; -public: - using QuantParamsAccessor = Wint2ParamsAccessor< - ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale, - IteratorLocalScale, SmemIteratorLocalScale, - IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>; + public: + using QuantParamsAccessor = Wint2ParamsAccessor; - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< - typename MmaCore::Shape, - IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, - IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, - ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, - kStages, QuantParamsAccessor, SharedMemoryClear>; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + QuantParamsAccessor, + SharedMemoryClear>; }; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_base.h index 1fb7f7eb28..51e410aad3 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -47,30 +48,33 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// // SFINAE trick so I can keep the same loop code for Volta and dispatch to the -// correct warp level mma. On volta, all data is stored to shared memory as FP16. +// correct warp level mma. On volta, all data is stored to shared memory as +// FP16. template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, - typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, - int const warp_tileB_k_offset) -{ - warp_mma(D, A, B, C); +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, + typename WarpMma::FragmentC& D, + typename WarpMma::FragmentA const& A, + typename WarpMma::FragmentB const& B, + typename WarpMma::FragmentC const& C, + int const warp_tileB_k_offset) { + warp_mma(D, A, B, C); } template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, - typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, - typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) -{ - warp_mma(D, A, B, C, warp_tileB_k_offset); +CUTLASS_DEVICE void run_warp_mma( + WarpMma& warp_mma, + typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, + typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, + int const warp_tileB_k_offset) { + warp_mma(D, A, B, C, warp_tileB_k_offset); } //////////////////////////////////////////////////////////////////////////////// @@ -90,168 +94,169 @@ template < WeightOnlyQuantOp DequantOp, /// Used for partial specialization, typename Enable = bool> -class DqMmaBase -{ -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; +class DqMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; - ///< Policy describing tuning details - using Policy = Policy_; + ///< Policy describing tuning details + using Policy = Policy_; - ///< Type of the scale to be loaded - using ElementScale = ElementScale_; + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; - static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); - // Finegrained scales get streamed in via cp.async - static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; - // We always have scales. - static constexpr int ScaleElementsPerStage = Shape::kN; - // We sometimes have a bias - static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = + hasZero(DequantOp) ? Shape::kN : 0; + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / + Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = + kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: // - // Dependent types + // Type definitions // - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. + using ShapeZero = MatrixShape; - /// Number of warp-level GEMM operations - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - static constexpr int kNumKIterationsPerWarpBLoad - = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; - - static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); - static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage - { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA - = MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB - = MatrixShape; - - /// Shape of the shared memory buffer for the scales for the B matrix. - using ShapeScale = MatrixShape; - /// Shape of the shared memory buffer for the biases of the B matrix. - using ShapeZero = MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_scale; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_zero; - - public: - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() - { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() - { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() - { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() - { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - -protected: + public: // // Data members // - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; + /// Buffer for A operand + AlignedBuffer operand_A; - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; + /// Buffer for B operand + AlignedBuffer operand_B; -public: - /// Construct from tensor references + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix CUTLASS_DEVICE - DqMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx) - , warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) - { + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h index 3c4036dd8c..4a4a3137be 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -48,12 +49,9 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -102,9 +100,9 @@ template < typename Enable = void> class DqMmaMultistage; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass #include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" #include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h" diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h index e87a51b22c..02b3b11840 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -48,12 +49,9 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -98,460 +96,587 @@ template < WeightOnlyQuantOp QuantOp_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear> -class DqMmaMultistage> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - using TransformBAfterLDS = TransformBAfterLDS_; + using TransformBAfterLDS = TransformBAfterLDS_; - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - // - // Dependent types - // + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Warp-level Mma + using Operator = typename Policy::Operator; - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; - using Dequantizer = warp::MmaTensorOpDequantizer; + using Dequantizer = warp::MmaTensorOpDequantizer; - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; - static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); - static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); - /// Internal structure exposed for introspection. - struct Detail - { + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + /// Number of stages + static int const kStages = Stages; - /// Number of stages - static int const kStages = Stages; + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA - = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB - = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); + private: + // + // Data members + // -private: - // - // Data members - // + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; + /// Iterator to write threadblock-scoped tile of scale and zero operand to + /// shared memory + SmemIteratorScale smem_iterator_scale_; - /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory - SmemIteratorScale smem_iterator_scale_; + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), + {Base::kStages, Shape::kN}, + thread_idx, + group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - /// The group size for quantization - int const group_size, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), - shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, + int stage = -1, + int k_iter = -1) { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = + iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = + iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr = + reinterpret_cast( + this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr = + reinterpret_cast( + this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * + IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async( + smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + cutlass::arch::cp_async( + smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); } - CUTLASS_DEVICE - void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) - { - static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); - - typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); - typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); - - typename IteratorScale::AccessType* smem_scale_ptr - = reinterpret_cast(this->smem_iterator_scale_.get_scale()); - typename IteratorScale::AccessType* smem_zero_ptr - = reinterpret_cast(this->smem_iterator_scale_.get_zero()); - - int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; - - cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); - - if (gmem_zero_ptr != nullptr) - { - cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); } - - if (iterator_scale.group_size_ == 64) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if (iterator_scale.group_size_ == 128) - { - if constexpr (Shape::kK == 128) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if constexpr (Shape::kK == 64) - { - if (iterator_scale.row_groupsize64_ & 0x1) - { - iterator_scale.add_tile_offset({1, 0}); - } - } - else - { - static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); - } - } - - iterator_scale.row_groupsize64_++; - - this->smem_iterator_scale_.add_tile_offset({1, 0}); + } else { + static_assert(Shape::kK == 0, + "Unsupported k tile shape, can only be 64 or 128"); + } } - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - // Async Copy for operand A CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) - { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } + ++iterator_A; } - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) - { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } + ++this->smem_iterator_A_; + } } - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) - { + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); - // - // Prologue - // + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - TransformBAfterLDS lds_converter; + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - // Issue several complete stages CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) - { + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + TransformBAfterLDS lds_converter; - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); - ++iterator_A; - } + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); - ++this->smem_iterator_A_; - } + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); + ++iterator_A; } - // Perform accumulation in the 'd' output operand - accum = src_accum; + ++this->smem_iterator_A_; + } - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) - { + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - typename IteratorA::AccessType zero_A; - zero_A.clear(); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } + ++iterator_B; } - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); + ++this->smem_iterator_B_; + } - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - typename Dequantizer::FragmentZero warp_frag_zeros; + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); - Operator warp_mma; + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } - warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[0]); printf("#### + // warp_frag_b_load [0] bid:%d-%d-%d," + // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // frag_b_reg_ptr[0], + // frag_b_reg_ptr[1], + // frag_b_reg_ptr[2], + // frag_b_reg_ptr[3], + // frag_b_reg_ptr[4], + // frag_b_reg_ptr[5], + // frag_b_reg_ptr[6], + // frag_b_reg_ptr[7] + // ); + // } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize( + converted_frag_B, warp_frag_scales, warp_frag_zeros); + + using FragmentOperandB = + cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = + TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = + Converter::convert(converted_frag_B); // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = reinterpret_cast(&warp_frag_B[0]); - // printf("#### warp_frag_b_load [0] bid:%d-%d-%d," - // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset) + // % 2]); uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&converted_frag_B); printf("#### + // after lds_converter bid:%d-%d-%d" + // " frag_b_reg_ptr[%d]:%x-%x-%x-%x-%x-%x-%x-%x" + // " converted_frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", // blockIdx.x,blockIdx.y,blockIdx.z, + // ((warp_tileB_k_load_offset) % 2), // frag_b_reg_ptr[0], // frag_b_reg_ptr[1], // frag_b_reg_ptr[2], @@ -559,195 +684,124 @@ public: // frag_b_reg_ptr[4], // frag_b_reg_ptr[5], // frag_b_reg_ptr[6], - // frag_b_reg_ptr[7] - // ); + // frag_b_reg_ptr[7], + // converted_frag_B_reg_ptr[0], + // converted_frag_B_reg_ptr[1], + // converted_frag_B_reg_ptr[2], + // converted_frag_B_reg_ptr[3], + // converted_frag_B_reg_ptr[4], + // converted_frag_B_reg_ptr[5], + // converted_frag_B_reg_ptr[6], + // converted_frag_B_reg_ptr[7] + // ); // } - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - warp_dequantizer_.add_pointer_offset(Shape::kN); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B_operand, + accum, + warp_tileB_k_compute_offset); - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - // - // Mainloop - // + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) - { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; - - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); - - - // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset) % 2]); - // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&converted_frag_B); - // printf("#### after lds_converter bid:%d-%d-%d" - // " frag_b_reg_ptr[%d]:%x-%x-%x-%x-%x-%x-%x-%x" - // " converted_frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // ((warp_tileB_k_load_offset) % 2), - // frag_b_reg_ptr[0], - // frag_b_reg_ptr[1], - // frag_b_reg_ptr[2], - // frag_b_reg_ptr[3], - // frag_b_reg_ptr[4], - // frag_b_reg_ptr[5], - // frag_b_reg_ptr[6], - // frag_b_reg_ptr[7], - // converted_frag_B_reg_ptr[0], - // converted_frag_B_reg_ptr[1], - // converted_frag_B_reg_ptr[2], - // converted_frag_B_reg_ptr[3], - // converted_frag_B_reg_ptr[4], - // converted_frag_B_reg_ptr[5], - // converted_frag_B_reg_ptr[6], - // converted_frag_B_reg_ptr[7] - // ); - // } - - run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) - { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // This is the first group of a given stage, so we issue the loads for the B scales immediately. - if (group_start_iteration_B == 0) - { - copy_scales_and_advance(iterator_scale); - } - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) - { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - - // #committed) - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else - { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); - smem_read_stage_idx = 0; - } - else - { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); - } - } - - // Load the scale needed for the next tile iteration. - warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); - // Update internal pointer to set of scales in shared memory. - warp_dequantizer_.add_pointer_offset(Shape::kN); + // This is the first group of a given stage, so we issue the loads for + // the B scales immediately. + if (group_start_iteration_B == 0) { + copy_scales_and_advance(iterator_scale); + } } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h index 83efdc5cb0..6371da633c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -48,12 +49,9 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -98,550 +96,605 @@ template < WeightOnlyQuantOp QuantOp_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear> -class DqMmaMultistage> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - using TransformBAfterLDS = TransformBAfterLDS_; + using TransformBAfterLDS = TransformBAfterLDS_; - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - // - // Dependent types - // + // + // Dependent types + // - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Warp-level Mma + using Operator = typename Policy::Operator; - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; - using Dequantizer = warp::MmaTensorOpDequantizer; + using Dequantizer = warp::MmaTensorOpDequantizer; - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; - /// Internal structure exposed for introspection. - struct Detail - { + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + /// Number of stages + static int const kStages = Stages; - /// Number of stages - static int const kStages = Stages; + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA - = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB - = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); + private: + // + // Data members + // -private: - // - // Data members - // + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; + /// Iterator to write threadblock-scoped tile of scale operand to shared + /// memory + SmemIteratorScale smem_iterator_scale_; - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< Group size for quantization. Not used by this main loop since it + ///< assumes per-column + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {1, Shape::kN}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - ///< Group size for quantization. Not used by this main loop since it assumes per-column - int const group_size, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } } - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - // Async Copy for operand A CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) - { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_A.get(); + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // - ++iterator_A; - } + TransformBAfterLDS lds_converter; - ++this->smem_iterator_A_; - } + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group + // as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; } - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - // Async Copy for operand B CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) - { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } + ++iterator_B; } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); } - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) - { + // Perform accumulation in the 'd' output operand + accum = src_accum; - // - // Prologue - // + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // - TransformBAfterLDS lds_converter; + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group as - // the first load of A. - FragmentScale tb_frag_scales; - tb_frag_scales.clear(); - iterator_scale.load(tb_frag_scales); - this->smem_iterator_scale_.store(tb_frag_scales); + typename IteratorA::AccessType zero_A; + zero_A.clear(); - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) - { + last_smem_iterator_A.set_iteration_index(0); - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); + *dst_ptr = zero_A; - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); + ++last_smem_iterator_A; + } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); - ++iterator_A; - } + *dst_ptr = zero_B; - ++this->smem_iterator_A_; - } + ++last_smem_iterator_B; + } + } - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + Operator warp_mma; - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); - ++iterator_B; - } + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); - ++this->smem_iterator_B_; - } + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } + // + // Mainloop + // - // Perform accumulation in the 'd' output operand - accum = src_accum; + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) - { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - warp_dequantizer_.load(warp_frag_scales); + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) - { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; - - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); - run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) - { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) - { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - - // #committed) - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else - { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - smem_read_stage_idx = 0; - } - else - { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } + int const warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + + using FragmentOperandB = + cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = + TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = + Converter::convert(converted_frag_B); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B_operand, + accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h index bd3e38971b..dd7e8ae4b7 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -53,27 +54,27 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorA_, /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorB_, /// Iterates over tiles of B operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) @@ -98,9 +99,9 @@ template < typename Enable = void> class DqMmaPipelined; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass #include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" #include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h" diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h index 50bdd0d85b..01fe06329b 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -53,27 +54,27 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorA_, /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorB_, /// Iterates over tiles of B operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) @@ -94,393 +95,442 @@ template < typename TransformBAfterLDS_, /// The quantization operator being used WeightOnlyQuantOp QuantOp_> -class DqMmaPipelined> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; +class DqMmaPipelined> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - // - // Dependent types - // + // + // Dependent types + // - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Warp-level Mma + using Operator = typename Policy::Operator; - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; - using Dequantizer = warp::MmaTensorOpDequantizer; + using Dequantizer = + warp::MmaTensorOpDequantizer; - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; - // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered + // pipeline) + static_assert((Base::kStages == 2), + "DqMmaPipelined requires kStages set to value 2"); - static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); - static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; - using WarpFragmentScale = typename Dequantizer::FragmentScale; - using WarpFragmentZero = typename Dequantizer::FragmentZero; + using WarpFragmentScale = typename Dequantizer::FragmentScale; + using WarpFragmentZero = typename Dequantizer::FragmentZero; - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; - /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory - SmemIteratorScale smem_iterator_scale_; + /// Iterator to write threadblock-scoped tile of scale and zero operand to + /// shared memory + SmemIteratorScale smem_iterator_scale_; -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int const group_size, ///< The group size for quantization - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), - shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) - { + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal + ///< use by threadblock-scoped GEMM + int const group_size, ///< The group size for quantization + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), + {Base::kStages, Shape::kN}, + thread_idx, + group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale) { + using TransformScale = + NumericArrayConverter; + + FragmentScale tb_frag_scales; + FragmentScale tb_frag_zeros; + tb_frag_scales.clear(); + tb_frag_zeros.clear(); + + TransformScale transformScale; + + using FragmentElement = typename FragmentScale::Element; + + auto gmem_scale_ptr = iterator_scale.get_scale(); + auto gmem_zero_ptr = iterator_scale.get_zero(); + + arch::global_load( + tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + arch::global_load( + tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid()); } - CUTLASS_DEVICE - void copy_scales_and_advance(IteratorScale& iterator_scale) - { - using TransformScale = NumericArrayConverter; + typename TransformScale::result_type tb_frag_scales_fp16 = + transformScale(tb_frag_scales); + typename TransformScale::result_type tb_frag_zeros_fp16; + if (gmem_zero_ptr != nullptr) + tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); - FragmentScale tb_frag_scales; - FragmentScale tb_frag_zeros; - tb_frag_scales.clear(); - tb_frag_zeros.clear(); + auto frag_scale_ptr_fp16 = + reinterpret_cast( + &tb_frag_scales_fp16); + auto frag_zero_ptr_fp16 = + reinterpret_cast( + &tb_frag_zeros_fp16); + auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); + auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); - TransformScale transformScale; + if (iterator_scale.valid()) { + auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); + arch::shared_store(smem_offset, + frag_scale_ptr_fp16); - using FragmentElement = typename FragmentScale::Element; - - auto gmem_scale_ptr = iterator_scale.get_scale(); - auto gmem_zero_ptr = iterator_scale.get_zero(); - - arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); - - if (gmem_zero_ptr != nullptr) - { - arch::global_load( - tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid()); - } - - typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); - typename TransformScale::result_type tb_frag_zeros_fp16; - if (gmem_zero_ptr != nullptr) - tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); - - auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); - auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); - auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); - auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); - - if (iterator_scale.valid()) - { - auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); - arch::shared_store(smem_offset, frag_scale_ptr_fp16); - - if (gmem_zero_ptr != nullptr) - { - smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); - arch::shared_store(smem_offset, frag_zero_ptr_fp16); - } - } - - if (iterator_scale.group_size_ == 64) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if (iterator_scale.group_size_ == 128) - { - if constexpr (Shape::kK == 128) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if constexpr (Shape::kK == 64) - { - if (iterator_scale.row_groupsize64_ & 0x1) - { - iterator_scale.add_tile_offset({1, 0}); - } - } - else - { - static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); - } - } - - iterator_scale.row_groupsize64_++; - - this->smem_iterator_scale_.add_tile_offset({1, 0}); + if (gmem_zero_ptr != nullptr) { + smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); + arch::shared_store(smem_offset, + frag_zero_ptr_fp16); + } } - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - IteratorScale iterator_scale, ///< iterator over scale operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } else { + static_assert(Shape::kK == 0, + "Unsupported k tile shape, can only be 64 or 128"); + } + } - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; + iterator_scale.row_groupsize64_++; - using TransformA - = NumericArrayConverter; + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. - TransformA transformA; + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale + iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile - // Perform accumulation in the 'd' output operand - accum = src_accum; + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; - FragmentA tb_frag_A; - FragmentB tb_frag_B; + using TransformA = NumericArrayConverter; - tb_frag_A.clear(); - tb_frag_B.clear(); + // These transforms are mainly to handle when we have bfloat activations and + // weights in GMEM and want to issue HMMA on architectures older than + // Ampere. We will convert to FP16 before STS. + TransformA transformA; - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); + // Perform accumulation in the 'd' output operand + accum = src_accum; - ++iterator_A; - ++iterator_B; + FragmentA tb_frag_A; + FragmentB tb_frag_B; - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + tb_frag_A.clear(); + tb_frag_B.clear(); - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); - copy_scales_and_advance(iterator_scale); + ++iterator_A; + ++iterator_B; - __syncthreads(); + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - WarpFragmentScale warp_frag_scales; - WarpFragmentZero warp_frag_zero; + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); + copy_scales_and_advance(iterator_scale); - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); + __syncthreads(); - warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + WarpFragmentScale warp_frag_scales; + WarpFragmentZero warp_frag_zero; + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + iterator_scale.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - warp_dequantizer_.add_pointer_offset(Shape::kN); - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - iterator_scale.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) - { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) - { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) - { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - copy_scales_and_advance(iterator_scale); - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - iterator_scale.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } - - // Load the scales needed for the next tile iteration - warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); - // Update internal pointer to the set of scales in shared memory - warp_dequantizer_.add_pointer_offset(Shape::kN); + int const warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate + // the load for the next fragment. + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + copy_scales_and_advance(iterator_scale); + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + iterator_scale.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize( + converted_frag_B, warp_frag_scales, warp_frag_zero); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + } + + // Load the scales needed for the next tile iteration + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + // Update internal pointer to the set of scales in shared memory + warp_dequantizer_.add_pointer_offset(Shape::kN); } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h index 316ea9f80a..e6f512edfe 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -53,27 +54,27 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorA_, /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorB_, /// Iterates over tiles of B operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) @@ -94,306 +95,361 @@ template < typename TransformBAfterLDS_, /// The quantization operator being used WeightOnlyQuantOp QuantOp_> -class DqMmaPipelined> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; +class DqMmaPipelined> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer< + Operator, + typename Base::WarpGemm, + Operand::kB, + typename SmemIteratorScale::Fragment::Element, + LayoutScale, + 32, + QuantOp>; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered + // pipeline) + static_assert((Base::kStages == 2), + "DqMmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared + /// memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + int const + group_size, ///< Will not be used, just to adapt to finegrained + ///< modifications and make the compilation successful. + ///< Because DqMmaPipelined is only enabled for sm<80, so + ///< even if this argument is not added, it does not + ///< affect compilation for sm>=80. + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {1, Shape::kN}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale + iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile // - // Dependent types + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA = NumericArrayConverter; + + using TransformScale = + NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and + // weights in GMEM and want to issue HMMA on architectures older than + // Ampere. We will convert to FP16 before STS. + TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; + + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + warp_dequantizer_.load(warp_frag_scales); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop // - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - /// Warp-level Mma - using Operator = typename Policy::Operator; + __syncthreads(); - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; - using Dequantizer = warp::MmaTensorOpDequantizer; + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + } - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation - ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this - ///< argument is not added, it does not affect compilation for sm>=80. - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - IteratorScale iterator_scale, ///< iterator over scale operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile - - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; - - using TransformA - = NumericArrayConverter; - - using TransformScale = NumericArrayConverter; - - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. - TransformA transformA; - TransformScale transformScale; - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - FragmentScale tb_frag_scales; - - using WarpFragmentScale = typename Dequantizer::FragmentScale; - WarpFragmentScale warp_frag_scales; - - tb_frag_A.clear(); - tb_frag_B.clear(); - tb_frag_scales.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - iterator_scale.load(tb_frag_scales); - - ++iterator_A; - ++iterator_B; - - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - __syncthreads(); - - warp_dequantizer_.load(warp_frag_scales); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) - { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) - { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) - { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } + smem_write_stage_idx ^= 1; } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate + // the load for the next fragment. + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + } } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h index bc91e724bb..9b118add9a 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -66,7 +66,7 @@ template < /// Size of extra quantized params typename QuantParamsShape> class Wint2xMmaBase { -public: + public: ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; @@ -85,9 +85,9 @@ public: using WarpGemm = typename Policy::Operator::Shape; /// Shape describing the number of warps filling the CTA - using WarpCount = - GemmShape; + using WarpCount = GemmShape; /// Number of warp-level GEMM operations static int const kWarpGemmIterations = @@ -95,7 +95,8 @@ public: /// Number of warp-level GEMM operations per load for B static constexpr int kWarpGemmIterationsPerLoadForB = - Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + Operator::IteratorB::InstructionShape::kRow / + Operator::InstructionShape::kK; static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), ""); static constexpr int kWarpLoadIterationsForB = @@ -125,7 +126,7 @@ public: /// Shared storage object needed by threadblock-scoped GEMM class SharedStorage { - public: + public: // // Type definitions // @@ -142,7 +143,7 @@ public: /// Shape of all quant params in shared memory using QuantParamsShapeB = QuantParamsShape; - public: + public: // // Data members // @@ -156,7 +157,7 @@ public: /// Buffer for extra quant params of B operand AlignedBuffer operand_quant_params_B; - public: + public: // // Methods // @@ -186,7 +187,7 @@ public: } }; -protected: + protected: // // Data members // @@ -197,7 +198,7 @@ protected: /// Iterator to load a warp-scoped tile of B operand from shared memory typename Operator::IteratorB warp_tile_iterator_B_; -public: + public: /// Construct from tensor references CUTLASS_DEVICE Wint2xMmaBase( @@ -215,8 +216,8 @@ public: ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index dd26cf68ea..245a89c152 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -91,11 +92,17 @@ template < typename QuantParamsAccessor_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> -class Wint2xMmaMultistage : - public Wint2xMmaBase { -public: +class Wint2xMmaMultistage + : public Wint2xMmaBase { + public: ///< Base class - using Base = Wint2xMmaBase; + using Base = Wint2xMmaBase; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; ///< Iterates over tiles of A operand in global memory @@ -133,17 +140,19 @@ public: /// Minimum architecture is Sm80 to support cp.async using ArchTag = arch::Sm80; - //using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout; + // using LayoutScale = typename + // QuantParamsAccessor::IteratorSuperScale::Layout; using LayoutScale = layout::RowMajor; using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - using WarpDequantizer = - warp::MmaTensorOpWin2xDequantizer; - static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed"); + using WarpDequantizer = warp::MmaTensorOpWin2xDequantizer< + Operator, + typename Base::WarpGemm, + Operand::kB, + typename WarpTransformedFragmentB::Element, + LayoutScale, + QuantParamsAccessor::kGroupSize>; + static_assert(sizeof(WarpDequantizer) > 0, + "WarpDequantizer template instantiation failed"); /// Complex transform on A operand static ComplexTransform const kTransformA = Operator::kTransformA; @@ -153,7 +162,6 @@ public: /// Internal structure exposed for introspection. struct Detail { - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; @@ -167,24 +175,25 @@ public: /// Number of cp.async instructions to load on group of operand A static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; /// Number of cp.async instructions to load on group of operand B static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; - // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical - // accuracy, where each mainloop iteration first accumulates into a temporary - // set of freshly-cleared accumulators, which are subsequently added to the - // final accumulator set. - static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation::value; + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved + // numerical accuracy, where each mainloop iteration first accumulates into + // a temporary set of freshly-cleared accumulators, which are subsequently + // added to the final accumulator set. + static bool const kStagedAccumulation = + arch::detail::UseStagedAccumulation::value; }; private: - // Structure encapsulating pipeline state live from one iteration to the next struct PipeState { - using WarpLoadedFragmentA = typename Operator::FragmentA; using WarpLoadedFragmentB = typename Operator::FragmentB; using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; @@ -197,10 +206,12 @@ public: /// Temporary accumulator to facilitate staged-accumulation FragmentC tmp_accum_; - /// Pair of A fragments used to overlap shared memory loads and math instructions + /// Pair of A fragments used to overlap shared memory loads and math + /// instructions WarpTransformedFragmentA warp_frag_A_[2]; - /// Pair of B fragments used to overlap shared memory loads and math instructions + /// Pair of B fragments used to overlap shared memory loads and math + /// instructions WarpLoadedFragmentB warp_loaded_frag_B_; WarpTransformedFragmentB warp_frag_B_[2]; @@ -218,12 +229,14 @@ public: using LayoutDetailsForB = kernel::LayoutDetailsB; static constexpr bool IsTileInterleaveLayout = - layout::IsColumnMajorTileInterleave::value; - static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!IsTileInterleaveLayout || + (IsTileInterleaveLayout && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); private: - // // Data members // @@ -249,8 +262,7 @@ public: /// Shared memory read stage index int smem_read_stage_idx_; -public: - + public: /// Construct from tensor references CUTLASS_DEVICE Wint2xMmaMultistage( @@ -261,19 +273,24 @@ public: ///< ID of warp int warp_idx, ///< ID of each thread within a warp - int lane_idx - ) : Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx), - warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(), - quant_params_accessor_B_.local_scale_ref(), - quant_params_accessor_B_.code_scale_ref(), - quant_params_accessor_B_.code_zp_ref(), - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), - smem_write_stage_idx_(0), - smem_read_stage_idx_(0) - { + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), + thread_idx, + warp_idx, + lane_idx), + warp_dequantizer_( + quant_params_accessor_B_.super_scale_ref(), + quant_params_accessor_B_.local_scale_ref(), + quant_params_accessor_B_.code_scale_ref(), + quant_params_accessor_B_.code_zp_ref(), + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: // _m: the warp's position within the threadblock along the M dimension @@ -295,22 +312,26 @@ public: /// Advance shared memory read-iterators to the next stage CUTLASS_DEVICE - void advance_smem_read_stage() - { + void advance_smem_read_stage() { ++smem_read_stage_idx_; if (smem_read_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory - this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0}); + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpLoadIterationsForB, + 0}); smem_read_stage_idx_ = 0; } } - /// Advance global memory read-iterators and shared memory write-iterators to the stage + /// Advance global memory read-iterators and shared memory write-iterators to + /// the stage CUTLASS_DEVICE - void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B) - { + void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B) { // Advance global iterators iterator_A.add_tile_offset({0, 1}); iterator_B.add_tile_offset({1, 0}); @@ -395,7 +416,9 @@ public: CUTLASS_PRAGMA_UNROLL for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); - bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) ? iterator_B.valid() : false; + bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) + ? iterator_B.valid() + : false; if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { cutlass::arch::cp_async_zfill( @@ -429,10 +452,9 @@ public: for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_A.get(); - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_A.get(), iterator_A.valid()); @@ -464,10 +486,9 @@ public: for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_B.get(), iterator_B.valid()); @@ -480,18 +501,22 @@ public: } /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching - /// the global fragments needed by the first kStages-1 threadblock mainloop iterations + /// the global fragments needed by the first kStages-1 threadblock mainloop + /// iterations CUTLASS_DEVICE - void prologue( - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - QuantArguments &mma_quant_args, ///< iterators for extra quant params for B - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + void prologue(IteratorA &iterator_A, ///< [in|out] iterator over A operand in + ///< global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in + ///< global memory + QuantArguments & + mma_quant_args, ///< iterators for extra quant params for B + int &gemm_k_iterations) ///< [in|out] number of threadblock + ///< mainloop iterations remaining { // Issue several complete stages CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); @@ -502,11 +527,14 @@ public: // Async copy zipped B to shared memory. copy_tiles_and_advance_per_stage_B(iterator_B); - // Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. + // Async copy other quantized params to shared memory, local_scale, + // code_scale, code_zp, super_scale. if (stage == 0) { - quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + quant_params_accessor_B_.copy_tiles_and_advance_per_stage( + mma_quant_args, stage); } else { - quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + quant_params_accessor_B_.copy_tiles_and_advance_per_stage( + mma_quant_args, stage); } // Move to the next write stage @@ -517,11 +545,12 @@ public: cutlass::arch::cp_async_fence(); } - // Optionally clear the remaining stages of SMEM. This is a functional requirement for - // some kernels so that all accumulator elements outside the GEMM footprint are zero. + // Optionally clear the remaining stages of SMEM. This is a functional + // requirement for some kernels so that all accumulator elements outside the + // GEMM footprint are zero. if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); typename IteratorA::AccessType zero_A; @@ -531,7 +560,6 @@ public: // Async Copy for operand A CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = reinterpret_cast( last_smem_iterator_A.get()); @@ -545,7 +573,8 @@ public: return; } - /// Iterator to write threadblock-scoped tile of B operand to shared memory + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); typename IteratorB::AccessType zero_B; @@ -555,7 +584,6 @@ public: // Async Copy for operand B CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = reinterpret_cast( last_smem_iterator_B.get()); @@ -569,9 +597,9 @@ public: /// Wait until we have at least one completed global fetch stage CUTLASS_DEVICE - void gmem_wait() - { - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + void gmem_wait() { + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) cutlass::arch::cp_async_wait(); __syncthreads(); } @@ -579,25 +607,31 @@ public: /// Perform a threadblock mainloop iteration of matrix multiply-accumulate CUTLASS_DEVICE void mac_loop_iter( - PipeState &pipe_state, ///< [in|out] loop-carried pipeline state - FragmentC &accum, ///< [in|out] destination accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - QuantArguments &mma_quant_args, ///< iterators for extra quant params for B - int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining - int stage) - { + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA + &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB + &iterator_B, ///< [in|out] iterator over B operand in global memory + QuantArguments + &mma_quant_args, ///< iterators for extra quant params for B + int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop + ///< iterations remaining + int stage) { const int mma_stage = stage - Base::kStages + 1; // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB; + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + int warp_k_compute_offset_B = + warp_mma_k % Base::kWarpGemmIterationsPerLoadForB; if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) { // Load the next warp-tile's B fragment from shared memory - this->warp_tile_iterator_B_.set_kgroup_index(((warp_mma_k + 1) % Base::kWarpGemmIterations) / Base::kWarpLoadIterationsForB); + this->warp_tile_iterator_B_.set_kgroup_index( + ((warp_mma_k + 1) % Base::kWarpGemmIterations) / + Base::kWarpLoadIterationsForB); this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); ++this->warp_tile_iterator_B_; } @@ -608,28 +642,31 @@ public: } // Load the next warp-tile's A fragment from shared memory - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load( + pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; // dequantizes next warp-tile - warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, - pipe_state.warp_frag_code_scale_, - pipe_state.warp_frag_code_zp_, - pipe_state.warp_frag_super_scale_, - pipe_state.warp_loaded_frag_B_, - pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2], - ((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) : mma_stage) * Shape::kK, - (warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB); + warp_dequantizer_.dequantize( + pipe_state.warp_frag_local_scale_, + pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_, + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2], + ((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) + : mma_stage) * + Shape::kK, + (warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB); // Execute the current warp-tile of MMA operations if constexpr (Detail::kStagedAccumulation) { - warp_mma_( - pipe_state.tmp_accum_, - pipe_state.warp_frag_A_[warp_mma_k % 2], - pipe_state.warp_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ - ); + warp_mma_(pipe_state.tmp_accum_, + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_); if (warp_mma_k == 0) { plus plus_accum; @@ -637,11 +674,10 @@ public: pipe_state.tmp_accum_.clear(); } } else { - warp_mma_( - accum, - pipe_state.warp_frag_A_[warp_mma_k % 2], - pipe_state.warp_frag_B_[warp_mma_k % 2], - accum); + warp_mma_(accum, + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_[warp_mma_k % 2], + accum); } // Except for the last warp-tile, all warp-tiles issue their share of @@ -654,22 +690,28 @@ public: copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); if (warp_mma_k == 0) { - quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + quant_params_accessor_B_.copy_tiles_and_advance_per_stage( + mma_quant_args, stage); } } // The second-to-last warp-tile also: - // - performs the last warp-tile's share of global->shared fragment copies + // - performs the last warp-tile's share of global->shared fragment + // copies // - moves to the next global fetch stage if (warp_mma_k + 2 == Base::kWarpGemmIterations) { // Performs the last warp-tile's share of global->shared fragment copies - if constexpr (Detail::AsyncCopyIterationsPerStageA >= Base::kWarpGemmIterations) { - int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + if constexpr (Detail::AsyncCopyIterationsPerStageA >= + Base::kWarpGemmIterations) { + int group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); } - if constexpr (Detail::AsyncCopyIterationsPerStageB >= Base::kWarpGemmIterations) { - int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + if constexpr (Detail::AsyncCopyIterationsPerStageB >= + Base::kWarpGemmIterations) { + int group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); } @@ -691,7 +733,8 @@ public: --gemm_k_iterations; iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); - quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0); + quant_params_accessor_B_.clear_mask(mma_quant_args, + gemm_k_iterations == 0); } } } @@ -700,12 +743,13 @@ public: /// multiply-accumulate. Assumes prologue has been initiated. CUTLASS_DEVICE void gemm_iters( - int gemm_k_iterations, ///< number of threadblock mainloop iterations - FragmentC &accum, ///< [in|out] accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - QuantArguments &mma_quant_args) - { + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA + &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB + &iterator_B, ///< [in|out] iterator over B operand in global memory + QuantArguments &mma_quant_args) { PipeState pipe_state; // Disable global fetching if done with global fetch iterations @@ -748,14 +792,13 @@ public: // Mainloop CUTLASS_GEMM_LOOP for (; gemm_k_iterations > (-Base::kStages + 1);) { - mac_loop_iter( - pipe_state, - accum, - iterator_A, - iterator_B, - mma_quant_args, - gemm_k_iterations, - stage); + mac_loop_iter(pipe_state, + accum, + iterator_A, + iterator_B, + mma_quant_args, + gemm_k_iterations, + stage); stage += 1; } @@ -764,7 +807,8 @@ public: accum = plus_accum(accum, pipe_state.tmp_accum_); } - // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + // Commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop cutlass::arch::cp_async_fence(); cutlass::arch::cp_async_wait<0>(); __syncthreads(); @@ -772,15 +816,16 @@ public: /// Prepares the class for another prologue. CUTLASS_DEVICE - void wind_down() - { - // Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue) + void wind_down() { +// Catch-up the smem-read iterator to the smem-write iterator (so this class can +// be reused for another tile's prologue) - // First, increment remaining warp tiles to get to the next full stage. (Ideally we would - // just decrement one tile, but not all iterators implement --() decrement.) - #pragma unroll - for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { +// First, increment remaining warp tiles to get to the next full stage. (Ideally +// we would just decrement one tile, but not all iterators implement --() +// decrement.) +#pragma unroll + for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); @@ -789,22 +834,24 @@ public: } smem_read_stage_idx_++; - // Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators) - static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations; - if (smem_read_stage_idx_ > 1) - { + // Then wrap back two full stages (one for the tile advancing we just did, + // and one to catch the write iterators) + static const int kStageIters = + Policy::kPartitionsK * Base::kWarpGemmIterations; + if (smem_read_stage_idx_ > 1) { this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)}); this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); - this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, ((Base::kStages - 2) * kStageIters)}); + this->warp_tile_iterator_B_.add_tile_offset( + {((Base::kStages - 2) * kStageIters), 0}); } smem_read_stage_idx_ = smem_write_stage_idx_; } - /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory. + /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to + /// shared memory. CUTLASS_DEVICE void operator()( ///< problem size of GEMM @@ -819,8 +866,8 @@ public: QuantArguments mma_quant_args, ///< initial value of accumulator FragmentC const &src_accum) { - - // Prologue (start fetching iterations of global fragments into shared memory) + // Prologue (start fetching iterations of global fragments into shared + // memory) prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations); // Wait until we have at least one completed global fetch stage @@ -830,7 +877,8 @@ public: accum = src_accum; // Perform the MAC-iterations - gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args); + gemm_iters( + gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h index c6eb2750c8..2409087cbc 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h @@ -46,9 +46,10 @@ template < /// Group size for quantization int GroupSize_> class Wint2ParamsAccessor { -public: - static_assert(platform::is_same::value || platform::is_same::value, - "T must be fp16 or bf16"); + public: + static_assert(platform::is_same::value || + platform::is_same::value, + "T must be fp16 or bf16"); using ElementType = T; using Shape = Shape_; @@ -72,7 +73,7 @@ public: using ElementLocalScale = typename IteratorLocalScale::Element; using LayoutLocalScale = typename IteratorLocalScale::Layout; static_assert(platform::is_same::value, - "local_scale's type must be uint4b_t."); + "local_scale's type must be uint4b_t."); using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element; using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout; @@ -80,24 +81,32 @@ public: /// 2 uint4b_t values are stored in a single uint8_t constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK; constexpr static int kLocalScaleRows = - IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * sizeof_bits::value / 8 / Shape::kN; + IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * + sizeof_bits::value / 8 / Shape::kN; using SmemElement = uint8_t; - constexpr static int kSmemRows = - kLocalScaleRows * kStages + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2; + constexpr static int kSmemRows = kLocalScaleRows * kStages + + sizeof(ElementSuperScale) + + sizeof(ElementCodeScaleZp) * 2; constexpr static int kSmemColumns = Shape::kN; using QuantParamsShape = MatrixShape; constexpr static int kSuperScaleSmemOffset = 0; - constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale); - constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); - constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + constexpr static int kCodeScaleSmemOffset = + kSmemColumns * sizeof(ElementSuperScale); + constexpr static int kCodeZpSmemOffset = + kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + constexpr static int kLocalScaleSmemOffset = + kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); /// TensorRef type for loading element from a tensor - using SuperTensorRef = cutlass::TensorRef; - using LocalTensorRef = cutlass::TensorRef; - using CodeTensorRef = cutlass::TensorRef; + using SuperTensorRef = + cutlass::TensorRef; + using LocalTensorRef = + cutlass::TensorRef; + using CodeTensorRef = + cutlass::TensorRef; struct Arguments { IteratorSuperScale iterator_super_scale; @@ -113,14 +122,14 @@ public: IteratorCodeScaleZp iterator_code_scale, IteratorCodeScaleZp iterator_code_zp, int local_scale_pointer_offset) - : iterator_super_scale(iterator_super_scale), - iterator_local_scale(iterator_local_scale), - iterator_code_scale(iterator_code_scale), - iterator_code_zp(iterator_code_zp), - local_scale_pointer_offset(local_scale_pointer_offset) {} + : iterator_super_scale(iterator_super_scale), + iterator_local_scale(iterator_local_scale), + iterator_code_scale(iterator_code_scale), + iterator_code_zp(iterator_code_zp), + local_scale_pointer_offset(local_scale_pointer_offset) {} }; -private: + private: // // Data members // @@ -128,13 +137,17 @@ private: /// Begin address of shared memory uint8_t* smem_pointer_; - /// Iterator to write threadblock-scoped tile of super scale operand to shared memory + /// Iterator to write threadblock-scoped tile of super scale operand to shared + /// memory SmemIteratorSuperScale smem_iterator_super_scale_; - /// Iterator to write threadblock-scoped tile of local scale operand to shared memory + /// Iterator to write threadblock-scoped tile of local scale operand to shared + /// memory SmemIteratorLocalScale smem_iterator_local_scale_; - /// Iterator to write threadblock-scoped tile of code scale operand to shared memory + /// Iterator to write threadblock-scoped tile of code scale operand to shared + /// memory SmemIteratorCodeScaleZp smem_iterator_code_scale_; - /// Iterator to write threadblock-scoped tile of code zp operand to shared memory + /// Iterator to write threadblock-scoped tile of code zp operand to shared + /// memory SmemIteratorCodeScaleZp smem_iterator_code_zp_; /// Shared memory write stage index @@ -145,25 +158,29 @@ private: CUTLASS_DEVICE ElementSuperScale* get_super_scale_smem_ptr() { - return reinterpret_cast(smem_pointer_ + kSuperScaleSmemOffset); + return reinterpret_cast(smem_pointer_ + + kSuperScaleSmemOffset); } CUTLASS_DEVICE ElementLocalScale* get_local_scale_smem_ptr() { - return reinterpret_cast(smem_pointer_ + kLocalScaleSmemOffset); + return reinterpret_cast(smem_pointer_ + + kLocalScaleSmemOffset); } CUTLASS_DEVICE ElementCodeScaleZp* get_code_scale_smem_ptr() { - return reinterpret_cast(smem_pointer_ + kCodeScaleSmemOffset); + return reinterpret_cast(smem_pointer_ + + kCodeScaleSmemOffset); } CUTLASS_DEVICE ElementCodeScaleZp* get_code_zp_smem_ptr() { - return reinterpret_cast(smem_pointer_ + kCodeZpSmemOffset); + return reinterpret_cast(smem_pointer_ + + kCodeZpSmemOffset); } -public: + public: /// Construct from tensor references CUTLASS_DEVICE Wint2ParamsAccessor( @@ -175,55 +192,74 @@ public: int warp_idx, ///< ID of each thread within a warp int lane_idx) - : smem_pointer_(smem_pointer), - smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn), - get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx), - smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn), - get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx), - smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), - get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx), - smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), - get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx), - smem_write_stage_idx_(0), - smem_read_stage_idx_(0) {} + : smem_pointer_(smem_pointer), + smem_iterator_super_scale_( + LayoutSuperScale(IteratorSuperScale::Shape::kColumn), + get_super_scale_smem_ptr(), + {1, IteratorSuperScale::Shape::kColumn}, + thread_idx), + smem_iterator_local_scale_( + LayoutLocalScale(IteratorLocalScale::Shape::kColumn), + get_local_scale_smem_ptr(), + {1, IteratorLocalScale::Shape::kColumn}, + thread_idx), + smem_iterator_code_scale_( + LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_scale_smem_ptr(), + {1, IteratorCodeScaleZp::Shape::kColumn}, + thread_idx), + smem_iterator_code_zp_( + LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_zp_smem_ptr(), + {1, IteratorCodeScaleZp::Shape::kColumn}, + thread_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) {} CUTLASS_DEVICE SuperTensorRef super_scale_ref() { - return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)}; + return {get_super_scale_smem_ptr(), + LayoutSuperScale(IteratorSuperScale::Shape::kColumn)}; } CUTLASS_DEVICE LocalTensorRef local_scale_ref() { - return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)}; + return {get_local_scale_smem_ptr(), + LayoutLocalScale(IteratorLocalScale::Shape::kColumn)}; } CUTLASS_DEVICE CodeTensorRef code_scale_ref() { - return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + return {get_code_scale_smem_ptr(), + LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; } CUTLASS_DEVICE CodeTensorRef code_zp_ref() { - return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + return {get_code_zp_smem_ptr(), + LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; } template - CUTLASS_DEVICE - void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) { + CUTLASS_DEVICE void copy_tiles_and_advance_per_stage(Arguments& quant_args, + int stage) { if constexpr (IsFirstStage) { - // Load channel-wise super_scale to shared memory, which only needs to be done once. + // Load channel-wise super_scale to shared memory, which only needs to be + // done once. typename IteratorSuperScale::Fragment tb_frag_super_scale; tb_frag_super_scale.clear(); quant_args.iterator_super_scale.load(tb_frag_super_scale); this->smem_iterator_super_scale_.store(tb_frag_super_scale); - // Load channel-wise code_scale to shared memory, which only needs to be done once. + // Load channel-wise code_scale to shared memory, which only needs to be + // done once. typename IteratorCodeScaleZp::Fragment tb_frag_code_scale; tb_frag_code_scale.clear(); quant_args.iterator_code_scale.load(tb_frag_code_scale); this->smem_iterator_code_scale_.store(tb_frag_code_scale); - // Load channel-wise code_zp to shared memory, which only needs to be done once. + // Load channel-wise code_zp to shared memory, which only needs to be done + // once. typename IteratorCodeScaleZp::Fragment tb_frag_code_zp; tb_frag_code_zp.clear(); quant_args.iterator_code_zp.load(tb_frag_code_zp); @@ -231,20 +267,24 @@ public: } if ((stage % kStagesPerLocalScaleLoad) == 0) { - // Load group-wise local_scale to shared memory, which only needs to be done at each stage. - // Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages. + // Load group-wise local_scale to shared memory, which only needs to be + // done at each stage. Since 2 uint4b_t values of local_scale are saved in + // a single uint8_t, local_scale needs to be loaded once every two stages. using AccessType = typename IteratorLocalScale::AccessType; - cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits::value == 128) - ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; + cutlass::arch::CacheOperation::Kind const kCacheOp = + (sizeof_bits::value == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; quant_args.iterator_local_scale.set_iteration_index(0); this->smem_iterator_local_scale_.set_iteration_index(0); // Async Copy for local_scale CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) { - AccessType *dst_ptr = - reinterpret_cast(this->smem_iterator_local_scale_.get()); + for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; + ++j) { + AccessType* dst_ptr = reinterpret_cast( + this->smem_iterator_local_scale_.get()); CUTLASS_PRAGMA_UNROLL for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) { @@ -255,8 +295,8 @@ public: IteratorLocalScale::ThreadMap::kElementsPerAccess / IteratorLocalScale::kAccessesPerVector / 8; - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid()); + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid()); } ++quant_args.iterator_local_scale; } @@ -265,13 +305,15 @@ public: } CUTLASS_DEVICE - void advance_smem_write_stage(Arguments &quant_args) { + void advance_smem_write_stage(Arguments& quant_args) { if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) { // Advance global iterators - quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset); + quant_args.iterator_local_scale.add_pointer_offset( + quant_args.local_scale_pointer_offset); // Advance shared iterators - int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; + int smem_pointer_offset = + IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset); } @@ -280,7 +322,8 @@ public: if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory - int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; + int pointer_offset = -kStages * IteratorLocalScale::Shape::kRow * + IteratorLocalScale::Shape::kColumn; smem_iterator_local_scale_.add_pointer_offset(pointer_offset); smem_write_stage_idx_ = 0; } @@ -298,14 +341,14 @@ public: if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { smem_read_stage_idx_ = 0; - byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns; + byte_offset = -(kStages - 1) * kLocalScaleRows * kSmemColumns; } return byte_offset; } CUTLASS_DEVICE - int clear_mask(Arguments &quant_args, bool cond) { + int clear_mask(Arguments& quant_args, bool cond) { quant_args.iterator_local_scale.clear_mask(cond); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h index 9d49d5eb53..a1bc5a0ecf 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h @@ -29,18 +29,27 @@ namespace gemm { namespace threadblock { template -using UnzipArray = cutlass::AlignedArray::value / 8)>; +using UnzipArray = + cutlass::AlignedArray::value / 8)>; -template +template struct UnzipAndDequantFunctor { - __device__ void operator()(const T *in_ptr, const T *supper_scale_ptr, - T *out_ptr, const int64_t in_stride) {} + __device__ void operator()(const T *in_ptr, + const T *supper_scale_ptr, + T *out_ptr, + const int64_t in_stride) {} }; template -struct UnzipAndDequantFunctor { +struct UnzipAndDequantFunctor { using ZippedT = uint16_t; using ScaleComputeT = float; @@ -52,7 +61,8 @@ struct UnzipAndDequantFunctor> shift_bit) & kWeightMask; int32_t value = shifted_value - kBZP; @@ -61,8 +71,10 @@ struct UnzipAndDequantFunctor(scaled_value); } - __device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr, - T *out_ptr, const int64_t in_stride) { + __device__ void operator()(const uint16_t *in_ptr, + const T *super_scale_ptr, + T *out_ptr, + const int64_t in_stride) { int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0}; int tid = threadIdx.x; @@ -111,8 +123,11 @@ struct UnzipAndDequantFunctor -struct UnzipAndDequantFunctor { +struct UnzipAndDequantFunctor { using ZippedT = uint8_t; using ScaleComputeT = float; @@ -129,9 +144,11 @@ struct UnzipAndDequantFunctor(column_wise_smem_ptr); - code_zp_ptr = reinterpret_cast(column_wise_smem_ptr + sizeof(float) * TileColumns); - super_scale_ptr = reinterpret_cast(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns); + code_zp_ptr = reinterpret_cast(column_wise_smem_ptr + + sizeof(float) * TileColumns); + super_scale_ptr = reinterpret_cast(column_wise_smem_ptr + + 2 * sizeof(float) * TileColumns); } }; - __device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr, - const float *g_code_scale_ptr, const float *g_code_zp_ptr, + __device__ void Load(const uint8_t *g_weight_ptr, + const uint8_t *g_local_scale_ptr, + const float *g_code_scale_ptr, + const float *g_code_zp_ptr, const T *g_super_scale_ptr, - Arguments *args, const int64_t in_stride, bool need_preload) { + Arguments *args, + const int64_t in_stride, + bool need_preload) { int tid = threadIdx.x; #pragma unroll @@ -186,7 +215,8 @@ struct UnzipAndDequantFunctorlocal_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset]; + args->local_scale_ptr[ls_row_id * TileColumns + col] = + g_local_scale_ptr[local_scale_offset]; } #pragma unroll @@ -205,10 +235,12 @@ struct UnzipAndDequantFunctor( - args->weight_ptr + z_offset, g_weight_ptr + g_offset, true); + int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread); + int g_offset = + z_offset / TileColumns * in_stride + z_offset % TileColumns; + cutlass::arch::cp_async( + args->weight_ptr + z_offset, g_weight_ptr + g_offset, true); } } else if (tid < weight_threads + local_scale_threads) { constexpr int start_thread_id = weight_threads; - constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads; - constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + constexpr int local_scale_per_thread_size = + local_scale_size / local_scale_threads; + constexpr int kIterations = + (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kIterations; ++i) { - int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread; - int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns; - cutlass::arch::cp_async( - args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true); + int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + + i * kBytesPerThread; + int g_offset = + z_offset / TileColumns * in_stride + z_offset % TileColumns; + cutlass::arch::cp_async( + args->local_scale_ptr + z_offset, + g_local_scale_ptr + g_offset, + true); } } else if (need_preload) { if (tid < weight_threads + local_scale_threads + code_scale_threads) { constexpr int start_thread_id = weight_threads + local_scale_threads; - constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads; - constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + constexpr int code_scale_per_thread_size = + code_scale_size / code_scale_threads; + constexpr int kIterations = + (code_scale_per_thread_size + kBytesPerThread - 1) / + kBytesPerThread; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kIterations; ++i) { - int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float); - cutlass::arch::cp_async( + int offset = ((tid - start_thread_id) * code_scale_per_thread_size + + i * kBytesPerThread) / + sizeof(float); + cutlass::arch::cp_async( args->code_scale_ptr + offset, g_code_scale_ptr + offset, true); } - } else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) { - constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads; + } else if (tid < weight_threads + local_scale_threads + + code_scale_threads + code_zp_threads) { + constexpr int start_thread_id = + weight_threads + local_scale_threads + code_scale_threads; constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads; - constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + constexpr int kIterations = + (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kIterations; ++i) { - int offset = ((tid - start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float); - cutlass::arch::cp_async( + int offset = ((tid - start_thread_id) * code_zp_per_thread_size + + i * kBytesPerThread) / + sizeof(float); + cutlass::arch::cp_async( args->code_zp_ptr + offset, g_code_zp_ptr + offset, true); } - } else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) { + } else if (tid < weight_threads + local_scale_threads + + code_scale_threads + code_zp_threads + + super_scale_threads) { if (g_super_scale_ptr) { - constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads; - constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads; - constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + constexpr int start_thread_id = weight_threads + local_scale_threads + + code_scale_threads + code_zp_threads; + constexpr int super_scale_per_thread_size = + super_scale_size / super_scale_threads; + constexpr int kIterations = + (super_scale_per_thread_size + kBytesPerThread - 1) / + kBytesPerThread; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kIterations; ++i) { - int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T); - cutlass::arch::cp_async( - args->super_scale_ptr + offset, g_super_scale_ptr + offset, true); + int offset = + ((tid - start_thread_id) * super_scale_per_thread_size + + i * kBytesPerThread) / + sizeof(T); + cutlass::arch::cp_async( + args->super_scale_ptr + offset, + g_super_scale_ptr + offset, + true); } } } } } - __device__ void Compute(const Arguments &args, T *out_ptr, + __device__ void Compute(const Arguments &args, + T *out_ptr, const int64_t block_start_row) { int32_t shift_bits[4] = {9, 6, 3, 0}; @@ -333,9 +408,9 @@ struct UnzipAndDequantFunctor(floor(zipped_value[zipped_row] * code_scale + code_zp + - static_cast(0.5))); + int32_t decode_value = static_cast( + floor(zipped_value[zipped_row] * code_scale + code_zp + + static_cast(0.5))); int row = group_id * 64 + zipped_row * 4; @@ -355,14 +430,17 @@ struct UnzipAndDequantFunctor= 32) ? 4 : 2; constexpr int RowStride = NumThreads * N / TileColumns; constexpr int kNumIters = kNumWeightsPerThread / N; - static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns."); + static_assert(N * NumThreads >= TileColumns, + "N * NumThreads should be no less than TileColumns."); constexpr ScaleComputeT decode_value_zp = static_cast(0.5); @@ -373,19 +451,22 @@ struct UnzipAndDequantFunctor local_scales = - *reinterpret_cast *>(args.local_scale_ptr + begin_col_id); + *reinterpret_cast *>(args.local_scale_ptr + + begin_col_id); UnzipArray zipped_values[2]; int zipped_offset = begin_row_id * TileColumns + begin_col_id; - zipped_values[0] = - *reinterpret_cast *>(args.weight_ptr + zipped_offset); + zipped_values[0] = *reinterpret_cast *>( + args.weight_ptr + zipped_offset); - UnzipArray super_scales = - *reinterpret_cast *>(args.super_scale_ptr + begin_col_id); + UnzipArray super_scales = *reinterpret_cast *>( + args.super_scale_ptr + begin_col_id); UnzipArray code_scales = - *reinterpret_cast *>(args.code_scale_ptr + begin_col_id); + *reinterpret_cast *>(args.code_scale_ptr + + begin_col_id); UnzipArray code_zps = - *reinterpret_cast *>(args.code_zp_ptr + begin_col_id); + *reinterpret_cast *>(args.code_zp_ptr + + begin_col_id); // special for TileRows = 64 int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4; @@ -394,9 +475,10 @@ struct UnzipAndDequantFunctor(local_scales[i]) >> local_scale_shift) & kLocalScaleMask; - scales[i] = - static_cast(shifted_local_scale) * static_cast(super_scales[i]); + (static_cast(local_scales[i]) >> local_scale_shift) & + kLocalScaleMask; + scales[i] = static_cast(shifted_local_scale) * + static_cast(super_scales[i]); } #pragma unroll @@ -405,26 +487,33 @@ struct UnzipAndDequantFunctor *>(args.weight_ptr + zipped_offset); + *reinterpret_cast *>(args.weight_ptr + + zipped_offset); } UnzipArray outs[4]; #pragma unroll for (int i = 0; i < N; ++i) { - int32_t decode_value = - static_cast(floor(static_cast(zipped_values[iter_id & 1][i]) * code_scales[i] - + code_zps[i] + decode_value_zp)); + int32_t decode_value = static_cast( + floor(static_cast(zipped_values[iter_id & 1][i]) * + code_scales[i] + + code_zps[i] + decode_value_zp)); - ScaleComputeT value_3 = static_cast((decode_value & kWeightMask) - kBZP); + ScaleComputeT value_3 = + static_cast((decode_value & kWeightMask) - kBZP); decode_value >>= 3; - ScaleComputeT value_2 = static_cast((decode_value & kWeightMask) - kBZP); + ScaleComputeT value_2 = + static_cast((decode_value & kWeightMask) - kBZP); decode_value >>= 3; - ScaleComputeT value_1 = static_cast((decode_value & kWeightMask) - kBZP); + ScaleComputeT value_1 = + static_cast((decode_value & kWeightMask) - kBZP); decode_value >>= 3; - ScaleComputeT value_0 = static_cast((decode_value & kWeightMask) - kBZP); + ScaleComputeT value_0 = + static_cast((decode_value & kWeightMask) - kBZP); outs[0][i] = static_cast(scales[i] * value_0); outs[1][i] = static_cast(scales[i] * value_1); outs[2][i] = static_cast(scales[i] * value_2); diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h index af4298df5e..8245aff71c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. + \brief Default warp-level GEMM operators selected by data type, size, and + layouts of operands. */ #pragma once @@ -70,35 +72,60 @@ template < /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp -{ +struct DefaultMmaTensorOp { + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; -private: - // Shape for computing the FP16s - using ComputeInstructionShape = InstructionShape_; + // Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2. + static constexpr int LoadInstructionK = 128 / sizeof_bits::value; - // Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2. - static constexpr int LoadInstructionK = 128 / sizeof_bits::value; + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = + GemmShape; - // Shape for loading the narrow data type from shared memory - using LoadInstructionShape = GemmShape; + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; -public: - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma, - cutlass::MatrixShape<1, 1>>; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; + // Define the warp-level tensor op + using Type = + cutlass::gemm::warp::MmaTensorOpComputeBWithF16; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index 64136a9758..edc37d72a1 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. + \brief Templates implementing warp-level matrix multiply-accumulate + operations targeting Tensor Cores. */ #pragma once @@ -63,7 +64,8 @@ namespace gemm { namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer. +/// Structure to compute the matrix product targeting Tensor Cores, for the case +/// when A is floating point and B is quantized integer. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, @@ -90,213 +92,241 @@ template < bool AccumulatorsInRowMajor = false, /// Used for partial specialization typename Enable = bool> -class MmaTensorOpComputeBWithF16 -{ -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; - /// Data type of multiplicand A - using ElementA = ElementA_; + /// Data type of multiplicand A + using ElementA = ElementA_; - /// Layout of multiplicand A - using LayoutA = LayoutA_; + /// Layout of multiplicand A + using LayoutA = LayoutA_; - /// Data type of multiplicand B - using ElementB = ElementB_; + /// Data type of multiplicand B + using ElementB = ElementB_; - /// Layout of multiplicand B - using LayoutB = LayoutB_; + /// Layout of multiplicand B + using LayoutB = LayoutB_; - /// Data type of accumulator matrix C - using ElementC = ElementC_; + /// Data type of accumulator matrix C + using ElementC = ElementC_; - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 80) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 89), - "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value && + ArchTag::kMinComputeCapability >= 80) || + (platform::is_same::value && + platform::is_same::value && + ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); - static_assert(platform::is_same::value - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), - "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); + static_assert(platform::is_same::value || + (platform::is_same::value && + ArchTag::kMinComputeCapability >= 80) || + (platform::is_same::value && + ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, " + "or FP8 on Ada"); - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + int const warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; static_assert( - SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); - static_assert( - SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + TransformedFragmentB::kElements == + MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column " + "iteration AND for the expanded K dim of " + "B"); - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + D = C; - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - -public: - /// Iterates over the A operand in memory - using IteratorA - = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, - MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = Array; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, - LayoutB, MatrixShape, Policy::OpDelta::kRow, - kThreadCount, kPartitionsK>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; - -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, - int const warp_tileB_k_offset) const - { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " - "B"); - - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } -#else - assert(0); -#endif + } } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = + warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer. -/// Specialization for B of uint2b_t. +/// Structure to compute the matrix product targeting Tensor Cores, for the case +/// when A is floating point and B is quantized integer. Specialization for B of +/// uint2b_t. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, @@ -319,214 +349,232 @@ template < /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. bool AccumulatorsInRowMajor> -class MmaTensorOpComputeBWithF16< - Shape_, - ElementA_, - LayoutA_, - uint2b_t, - LayoutB_, - ElementC_, - LayoutC_, - Policy_, - SharedMemoryInstructionShape_, - PartitionsK_, - AccumulatorsInRowMajor> -{ -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; - /// Data type of multiplicand A - using ElementA = ElementA_; + /// Data type of multiplicand A + using ElementA = ElementA_; - /// Layout of multiplicand A - using LayoutA = LayoutA_; + /// Layout of multiplicand A + using LayoutA = LayoutA_; - /// Data type of multiplicand B - using ElementB = uint2b_t; + /// Data type of multiplicand B + using ElementB = uint2b_t; - /// Layout of multiplicand B - using LayoutB = LayoutB_; + /// Layout of multiplicand B + using LayoutB = LayoutB_; - /// Data type of accumulator matrix C - using ElementC = ElementC_; + /// Data type of accumulator matrix C + using ElementC = ElementC_; - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); - static_assert(platform::is_same::value - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + static_assert(platform::is_same::value || + (platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - static_assert( - SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); - static_assert( - SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; -public: - /// Iterates over the A operand in memory - using IteratorA - = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, - MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; - /// Storage for transformed A tile - using TransformedFragmentA = Array; + /// Storage for transformed A tile + using TransformedFragmentA = + Array; - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, - LayoutB, MatrixShape, Policy::OpDelta::kRow, - kThreadCount, kPartitionsK>; + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; - /// Storage for transformed B tile - using TransformedFragmentB = - Array; + /// Storage for transformed B tile + using TransformedFragmentB = Array; - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; -public: - // - // Methods - // + public: + // + // Methods + // - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C) const - { + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; + D = C; - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } -#else - assert(0); -#endif + } } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h index 24e844abca..5bb06a5414 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. */ #pragma once @@ -57,12 +59,9 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { //////////////////////////////////////////////////////////////////////////////// @@ -94,193 +93,216 @@ template < typename Shape_, /// WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer= 80 - && platform::is_same::value>::type> -{ +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + bfloat16_t, + layout::RowMajor, + 32, + QuantOp_, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 80 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; -public: - /// Mma Operator - using MmaOperator = MmaOperator_; + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + /// Type of the scales + using ElementScale = bfloat16_t; - /// Type of the scales - using ElementScale = bfloat16_t; + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = + Array; + using FragmentZero = + Array; - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - using FragmentZero = Array; + /// Warp mma shape + using Shape = Shape_; - /// Warp mma shape - using Shape = Shape_; + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_scale_ = smem_scales.data() + thread_offset; - if constexpr (hasZero(QuantOp)) - { - pointer_zero_ = smem_zeros.data() + thread_offset; - } + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + TensorRef smem_zeros, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; } + } - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) - { + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + int const warp_idx_n, + int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; } + } - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); - __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + __nv_bfloat16 const* scale_ptr = + reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); - __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = + reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); - } - } + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } #else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. - arch::device_breakpoint(); + // Slow path not implemented here on purpose. If we need to do HMMA on older + // arch, scale conversion should happen before scales are stored to shared + // memory and we should use the fp16 dequantizer. This will avoid numerous + // conversion instructions in GEMM main loop. + arch::device_breakpoint(); #endif - } + } - CUTLASS_DEVICE - void load(FragmentScale& scale_frag, FragmentScale& zero_frag) - { - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = + pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; + } } + } - CUTLASS_DEVICE - void dequantize( - FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) - { + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag, + FragmentScale const& zero_frag) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); - __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); - __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); + __nv_bfloat16 const* scale_ptr = + reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + __nv_bfloat16 const* zero_ptr = + reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = + reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + if constexpr (hasZero(QuantOp)) { CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); - __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); - __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); - } - } + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = + __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); } + } else { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } + } #else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. - arch::device_breakpoint(); + // Slow path not implemented here on purpose. If we need to do HMMA on older + // arch, scale conversion should happen before scales are stored to shared + // memory and we should use the fp16 dequantizer. This will avoid numerous + // conversion instructions in GEMM main loop. + arch::device_breakpoint(); #endif - } + } - // Adds a pointer offset in units of elements. - CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_scale_ += offset; - pointer_zero_ += offset; - } + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } -private: - ElementScale const* pointer_scale_; - ElementScale const* pointer_zero_; + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; }; //////////////////////////////////////////////////////////////////////////////// @@ -293,170 +315,190 @@ template < typename Shape_, /// WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer= 75 - && platform::is_same::value>::type> -{ +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + QuantOp_, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 75 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; -public: - /// Mma Operator - using MmaOperator = MmaOperator_; + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + /// Type of the scales + using ElementScale = half_t; - /// Type of the scales - using ElementScale = half_t; + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = + Array; + using FragmentZero = + Array; - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - using FragmentZero = Array; + /// Warp mma shape + using Shape = Shape_; - /// Warp mma shape - using Shape = Shape_; + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_scale_ = smem_scales.data() + thread_offset; - if constexpr (hasZero(QuantOp)) - { - pointer_zero_ = smem_zeros.data() + thread_offset; - } + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + TensorRef smem_zeros, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; } + } - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) - { + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + int const warp_idx_n, + int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; } + } - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); } + } - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB - = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - multiplies mul_op; - - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = + pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; + } } + } - CUTLASS_DEVICE - void load(FragmentScale& scale_frag, FragmentScale& zero_frag) - { - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag, + FragmentScale const& zero_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = plus_op( + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), + zero_frag[mma_n_iter]); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } } + } - CUTLASS_DEVICE - void dequantize( - FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) - { - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB - = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } - multiplies mul_op; - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - - if constexpr (hasZero(QuantOp)) - { - plus plus_op; - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] - = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } - } - } - - // Adds a pointer offset in units of elements. - CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_scale_ += offset; - pointer_zero_ += offset; - } - -private: - ElementScale const* pointer_scale_; - ElementScale const* pointer_zero_; + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; }; //////////////////////////////////////////////////////////////////////////////// -// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm +// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved +// gemm template < /// Underlying matrix multiply operator (concept: MmaTensorOp) typename MmaOperator_, @@ -464,86 +506,98 @@ template < typename Shape_, /// WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer::value - && platform::is_same::value>::type> -{ +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + QuantOp_, + typename platform::enable_if< + platform::is_same::value && + platform::is_same::value>::type> { + public: + static_assert(platform::is_same>::value, + ""); -public: - static_assert(platform::is_same>::value, ""); + /// Mma Operator + using MmaOperator = MmaOperator_; - /// Mma Operator - using MmaOperator = MmaOperator_; + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; + /// Type of the scales + using ElementScale = half_t; - /// Type of the scales - using ElementScale = half_t; + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; + /// Warp mma shape + using Shape = Shape_; - /// Warp mma shape - using Shape = Shape_; + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. + static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + using AccessType = Array; - // Fragment to hold scale data to apply to B before mma - // Each 32x32x4 matmul uses 8 elements from B. - static constexpr int ColsPerMmaTile = 32; - static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; - using FragmentScale = Array; - using AccessType = Array; + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const base_col = lane_idx & 0xF8; + int const thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const base_col = lane_idx & 0xF8; - int const thread_offset = warp_offset + base_col; - pointer_ = smem_scales.data() + thread_offset; + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + AccessType* scale_frag_ptr = reinterpret_cast(&scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + scale_frag_ptr[tile_iter] = *reinterpret_cast( + pointer_ + ColsPerMmaTile * tile_iter); } + } - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - AccessType* scale_frag_ptr = reinterpret_cast(&scale_frag); + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag) { + static_assert( + FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); - CUTLASS_PRAGMA_UNROLL - for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) - { - // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. - scale_frag_ptr[tile_iter] = *reinterpret_cast(pointer_ + ColsPerMmaTile * tile_iter); - } - } + multiplies mul_op; + operand_frag = mul_op(operand_frag, scale_frag); + } - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { - static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); - - multiplies mul_op; - operand_frag = mul_op(operand_frag, scale_frag); - } - -private: - ElementScale const* pointer_; + private: + ElementScale const* pointer_; }; //////////////////////////////////////////////////////////////////////////////// -// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm +// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved +// gemm template < /// Underlying matrix multiply operator (concept: MmaTensorOp) typename MmaOperator_, @@ -551,98 +605,110 @@ template < typename Shape_, /// WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer::value - && platform::is_same::value>::type> -{ +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + QuantOp_, + typename platform::enable_if< + platform::is_same::value && + platform::is_same::value>::type> { + public: + static_assert(platform::is_same>::value, + ""); -public: - static_assert(platform::is_same>::value, ""); + /// Mma Operator + using MmaOperator = MmaOperator_; - /// Mma Operator - using MmaOperator = MmaOperator_; + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; + /// Type of the scales + using ElementScale = half_t; - /// Type of the scales - using ElementScale = half_t; + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; + /// Warp mma shape + using Shape = Shape_; - /// Warp mma shape - using Shape = Shape_; + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. + static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; - // Fragment to hold scale data to apply to B before mma - // Each 32x32x4 matmul uses 8 elements from B. - static constexpr int ColsPerMmaTile = 32; - static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; - using FragmentScale = Array; + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const base_col = lane_idx & 0xF8 + lane_idx % 4; + int const thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const base_col = lane_idx & 0xF8 + lane_idx % 4; - int const thread_offset = warp_offset + base_col; - pointer_ = smem_scales.data() + thread_offset; + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + // For col major B, each thread will jump 4 cols to get its next value + // inside of the super mma. + CUTLASS_PRAGMA_UNROLL + for (int mma_iter = 0; mma_iter < 2; ++mma_iter) { + scale_frag[tile_iter * 2 + mma_iter] = + pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; + } } + } - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) - { - // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. - // For col major B, each thread will jump 4 cols to get its next value inside - // of the super mma. - CUTLASS_PRAGMA_UNROLL - for (int mma_iter = 0; mma_iter < 2; ++mma_iter) - { - scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; - } - } + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag) { + using MmaOperandB = typename ArchMmaOperator::FragmentB; + static constexpr int total_n_mmas = 2 * TileNIterations; + static_assert(MmaOperandB::kElements * total_n_mmas == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + MmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); } + } - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { - using MmaOperandB = typename ArchMmaOperator::FragmentB; - static constexpr int total_n_mmas = 2 * TileNIterations; - static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, ""); - - multiplies mul_op; - - MmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } - } - -private: - ElementScale const* pointer_; + private: + ElementScale const* pointer_; }; //////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h index 4678b58e48..ac17315008 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h @@ -64,139 +64,152 @@ struct DataTypeTraits; template <> struct DataTypeTraits { - using Type = __nv_bfloat16; - using DualType = __nv_bfloat162; + using Type = __nv_bfloat16; + using DualType = __nv_bfloat162; }; template <> struct DataTypeTraits { - using Type = __half; - using DualType = __half2; + using Type = __half; + using DualType = __half2; }; template struct LocalScaleConverter { - using FragmentSource = Array; - using FragmentResult = Array; + using FragmentSource = Array; + using FragmentResult = Array; - CUTLASS_DEVICE - static void Apply(FragmentSource const& local_scale_frag, - FragmentResult const& super_scale_frag, - FragmentResult& scale_frag, - int shift_bit) { - constexpr uint32_t kLocalScaleMask = 0xf; + CUTLASS_DEVICE + static void Apply(FragmentSource const& local_scale_frag, + FragmentResult const& super_scale_frag, + FragmentResult& scale_frag, + int shift_bit) { + constexpr uint32_t kLocalScaleMask = 0xf; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int32_t shifted_value = (static_cast(local_scale_frag[i]) >> shift_bit) & kLocalScaleMask; - scale_frag[i] = static_cast(shifted_value) * super_scale_frag[i]; - } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + int32_t shifted_value = + (static_cast(local_scale_frag[i]) >> shift_bit) & + kLocalScaleMask; + scale_frag[i] = static_cast(shifted_value) * super_scale_frag[i]; } + } }; template -struct LocalScaleConverter::type> { - using FragmentSource = Array; - using FragmentResult = Array; +struct LocalScaleConverter::type> { + using FragmentSource = Array; + using FragmentResult = Array; - CUTLASS_DEVICE - static void Apply(FragmentSource const& local_scale_frag, - FragmentResult const& super_scale_frag, - FragmentResult& scale_frag, - int shift_bit) { - constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - constexpr uint32_t MASK = 0x000f000f; - // 2^10 = 1024 - constexpr uint32_t I4s_TO_FP16s_MAGIC_NUM = 0x64006400; + CUTLASS_DEVICE + static void Apply(FragmentSource const& local_scale_frag, + FragmentResult const& super_scale_frag, + FragmentResult& scale_frag, + int shift_bit) { + constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + constexpr uint32_t MASK = 0x000f000f; + // 2^10 = 1024 + constexpr uint32_t I4s_TO_FP16s_MAGIC_NUM = 0x64006400; - // -2^10 = -1024 - constexpr uint32_t FP16_BIAS = 0xE400E400; - // 1.0 - constexpr uint32_t FP16_ONE = 0x3C003C00; + // -2^10 = -1024 + constexpr uint32_t FP16_BIAS = 0xE400E400; + // 1.0 + constexpr uint32_t FP16_ONE = 0x3C003C00; - __half2* scale_ptr = reinterpret_cast<__half2 *>(&scale_frag); - __half2 const* super_scale_ptr = reinterpret_cast<__half2 const*>(&super_scale_frag); + __half2* scale_ptr = reinterpret_cast<__half2*>(&scale_frag); + __half2 const* super_scale_ptr = + reinterpret_cast<__half2 const*>(&super_scale_frag); - uint32_t const* local_scale_ptr = reinterpret_cast(&local_scale_frag); + uint32_t const* local_scale_ptr = + reinterpret_cast(&local_scale_frag); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i) { - int i4s = local_scale_ptr[i] >> shift_bit; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + int i4s = local_scale_ptr[i] >> shift_bit; - // unpack: 0, 1 - int32_t low = __byte_perm(i4s, i4s, 0xF1F0); - int32_t unpack0 = lop3(low, MASK, I4s_TO_FP16s_MAGIC_NUM); - // unpack: 2, 3 - int32_t high = __byte_perm(i4s, i4s, 0xF3F2); - int32_t unpack1 = lop3(high, MASK, I4s_TO_FP16s_MAGIC_NUM); + // unpack: 0, 1 + int32_t low = __byte_perm(i4s, i4s, 0xF1F0); + int32_t unpack0 = lop3(low, MASK, I4s_TO_FP16s_MAGIC_NUM); + // unpack: 2, 3 + int32_t high = __byte_perm(i4s, i4s, 0xF3F2); + int32_t unpack1 = lop3(high, MASK, I4s_TO_FP16s_MAGIC_NUM); - __half2 scale0 = __hfma2(*reinterpret_cast<__half2*>(&unpack0), - *reinterpret_cast(&FP16_ONE), - *reinterpret_cast(&FP16_BIAS)); - __half2 scale1 = __hfma2(*reinterpret_cast<__half2*>(&unpack1), - *reinterpret_cast(&FP16_ONE), - *reinterpret_cast(&FP16_BIAS)); + __half2 scale0 = __hfma2(*reinterpret_cast<__half2*>(&unpack0), + *reinterpret_cast(&FP16_ONE), + *reinterpret_cast(&FP16_BIAS)); + __half2 scale1 = __hfma2(*reinterpret_cast<__half2*>(&unpack1), + *reinterpret_cast(&FP16_ONE), + *reinterpret_cast(&FP16_BIAS)); - scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]); - scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]); - } + scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]); + scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]); } + } }; template -struct LocalScaleConverter::type> { - using FragmentSource = Array; - using FragmentResult = Array; +struct LocalScaleConverter::type> { + using FragmentSource = Array; + using FragmentResult = Array; - CUTLASS_DEVICE - static void Apply(FragmentSource const& local_scale_frag, - FragmentResult const& super_scale_frag, - FragmentResult& scale_frag, - int shift_bit) { + CUTLASS_DEVICE + static void Apply(FragmentSource const& local_scale_frag, + FragmentResult const& super_scale_frag, + FragmentResult& scale_frag, + int shift_bit) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; - constexpr uint32_t MASK = 0x000F000F; - constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; + constexpr uint32_t MASK = 0x000F000F; + constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; - constexpr uint32_t BF16_BIAS = 0xC300C300; - constexpr uint32_t BF16_ONE = 0x3F803F80; + constexpr uint32_t BF16_BIAS = 0xC300C300; + constexpr uint32_t BF16_ONE = 0x3F803F80; - __nv_bfloat162* scale_ptr = reinterpret_cast<__nv_bfloat162 *>(&scale_frag); - __nv_bfloat162 const* super_scale_ptr = reinterpret_cast<__nv_bfloat162 const*>(&super_scale_frag); + __nv_bfloat162* scale_ptr = reinterpret_cast<__nv_bfloat162*>(&scale_frag); + __nv_bfloat162 const* super_scale_ptr = + reinterpret_cast<__nv_bfloat162 const*>(&super_scale_frag); - uint32_t const* local_scale_ptr = reinterpret_cast(&local_scale_frag); + uint32_t const* local_scale_ptr = + reinterpret_cast(&local_scale_frag); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i) { - int i4s = local_scale_ptr[i] >> shift_bit; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + int i4s = local_scale_ptr[i] >> shift_bit; - // unpack: 0, 1 - int32_t low = __byte_perm(i4s, i4s, 0xF1F0); - int32_t unpack0 = lop3(low, MASK, I4s_TO_BF16s_MAGIC_NUM); - // unpack: 2, 3 - int32_t high = __byte_perm(i4s, i4s, 0xF3F2); - int32_t unpack1 = lop3(high, MASK, I4s_TO_BF16s_MAGIC_NUM); + // unpack: 0, 1 + int32_t low = __byte_perm(i4s, i4s, 0xF1F0); + int32_t unpack0 = lop3(low, MASK, I4s_TO_BF16s_MAGIC_NUM); + // unpack: 2, 3 + int32_t high = __byte_perm(i4s, i4s, 0xF3F2); + int32_t unpack1 = lop3(high, MASK, I4s_TO_BF16s_MAGIC_NUM); - nv_bfloat162 scale0 = __hfma2(*reinterpret_cast(&unpack0), - *reinterpret_cast(&BF16_ONE), - *reinterpret_cast(&BF16_BIAS)); - nv_bfloat162 scale1 = __hfma2(*reinterpret_cast(&unpack1), - *reinterpret_cast(&BF16_ONE), - *reinterpret_cast(&BF16_BIAS)); + nv_bfloat162 scale0 = + __hfma2(*reinterpret_cast(&unpack0), + *reinterpret_cast(&BF16_ONE), + *reinterpret_cast(&BF16_BIAS)); + nv_bfloat162 scale1 = + __hfma2(*reinterpret_cast(&unpack1), + *reinterpret_cast(&BF16_ONE), + *reinterpret_cast(&BF16_BIAS)); - scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]); - scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]); - } -#else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. - arch::device_breakpoint(); -#endif + scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]); + scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]); } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older + // arch, scale conversion should happen before scales are stored to shared + // memory and we should use the fp16 dequantizer. This will avoid numerous + // conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } }; -} // namespace detail +} // namespace detail //////////////////////////////////////////////////////////////////////////////// @@ -216,7 +229,7 @@ template < /// typename Enable = void> class MmaTensorOpWin2xDequantizer { - //static_assert(false, "Not Supported!"); + // static_assert(false, "Not Supported!"); }; //////////////////////////////////////////////////////////////////////////////// @@ -230,207 +243,237 @@ template < typename ElementOperand_, /// Group size for quantization int GroupSize_> -class MmaTensorOpWin2xDequantizer< - MmaOperator_, - Shape_, - Operand::kB, - ElementOperand_, - layout::RowMajor, - GroupSize_> - //typename platform::enable_if= 80 - // && platform::is_same::value>::type> +class MmaTensorOpWin2xDequantizer +// typename platform::enable_if= +// 80 +// && platform::is_same::value>::type> { -public: - static_assert(platform::is_same::value || platform::is_same::value, - "T must be fp16 or bf16"); + public: + static_assert(platform::is_same::value || + platform::is_same::value, + "T must be fp16 or bf16"); - /// Mma Operator - using MmaOperator = MmaOperator_; + /// Mma Operator + using MmaOperator = MmaOperator_; - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; - /// Warp mma shape - using Shape = Shape_; + /// Warp mma shape + using Shape = Shape_; - /// Type of mma operand - using ElementOperand = ElementOperand_; + /// Type of mma operand + using ElementOperand = ElementOperand_; - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; - /// Group size for quantization - static constexpr int kGroupSize = GroupSize_; + /// Group size for quantization + static constexpr int kGroupSize = GroupSize_; - /// Type of input - using ElementB = typename MmaOperator::FragmentB::Element; - static_assert(platform::is_same::value, "ElementB must be uint2b_t"); + /// Type of input + using ElementB = typename MmaOperator::FragmentB::Element; + static_assert(platform::is_same::value, + "ElementB must be uint2b_t"); - /// Type of the scales - using ElementLocalScale = uint4b_t; - using ElementSuperScale = ElementOperand; - using ElementCodeScaleZp = float; + /// Type of the scales + using ElementLocalScale = uint4b_t; + using ElementSuperScale = ElementOperand; + using ElementCodeScaleZp = float; - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kWarpIterationsAlongN = MmaOperator::MmaIterations::kColumn; + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kWarpIterationsAlongN = + MmaOperator::MmaIterations::kColumn; - // use uint8_t to save 2 4-bits local scales - using FragmentLocalScale = Array; - using FragmentSuperScale = Array; - using FragmentCodeScaleZp = Array; + // use uint8_t to save 2 4-bits local scales + using FragmentLocalScale = Array; + using FragmentSuperScale = Array; + using FragmentCodeScaleZp = Array; - /// Fragment to hold B data before Mma - using FragmentInput = Array; + /// Fragment to hold B data before Mma + using FragmentInput = Array; - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - static constexpr int kNumPacks = sizeof_bits::value / sizeof_bits::value; - static constexpr int kUnpackFactor = MmaOperator::FragmentB::kElements / (kWarpIterationsAlongN * kNumPacks); - static constexpr int kUnpackInterval = kExpansionFactor / kUnpackFactor; + static constexpr int kNumPacks = + sizeof_bits::value / sizeof_bits::value; + static constexpr int kUnpackFactor = + MmaOperator::FragmentB::kElements / (kWarpIterationsAlongN * kNumPacks); + static constexpr int kUnpackInterval = kExpansionFactor / kUnpackFactor; - /// Unpack 4 uint2b_t values compreseed in a uint8_t to floating points. - using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter< - ElementOperand, ElementB, MmaOperator::FragmentB::kElements / kUnpackFactor>; - using FragmentInputUnpack = typename Uint2Converter::result_type; + /// Unpack 4 uint2b_t values compreseed in a uint8_t to floating points. + using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementOperand, + ElementB, + MmaOperator::FragmentB::kElements / kUnpackFactor>; + using FragmentInputUnpack = typename Uint2Converter::result_type; - /// Fragment to hold internal scales before Mma - using FragmentScale = Array; + /// Fragment to hold internal scales before Mma + using FragmentScale = Array; - /// Fragment of dequantized B - using FragmentOutput = Array; + /// Fragment of dequantized B + using FragmentOutput = + Array; - /// TensorRef type for loading element from a tensor - using SuperTensorRef = cutlass::TensorRef; - using LocalTensorRef = cutlass::TensorRef; - using CodeTensorRef = cutlass::TensorRef; + /// TensorRef type for loading element from a tensor + using SuperTensorRef = cutlass::TensorRef; + using LocalTensorRef = cutlass::TensorRef; + using CodeTensorRef = cutlass::TensorRef; -private: - // - // Data members - // + private: + // + // Data members + // - uint8_t* pointer_local_scale_; - ElementCodeScaleZp* pointer_code_scale_; - ElementCodeScaleZp* pointer_code_zp_; - ElementSuperScale* pointer_super_scale_; + uint8_t* pointer_local_scale_; + ElementCodeScaleZp* pointer_code_scale_; + ElementCodeScaleZp* pointer_code_zp_; + ElementSuperScale* pointer_super_scale_; - //FragmentInputUnpack unpacked_frag_; - FragmentScale scale_frag_; + // FragmentInputUnpack unpacked_frag_; + FragmentScale scale_frag_; -public: - CUTLASS_DEVICE - MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale, - LocalTensorRef smem_local_scale, - CodeTensorRef smem_code_scale, - CodeTensorRef smem_code_zp, - int warp_idx_n, - int lane_idx) { - int warp_offset = warp_idx_n * Shape::kN; - int quad = lane_idx / 4; - int thread_offset = warp_offset + quad; - pointer_super_scale_ = smem_super_scale.data() + thread_offset; - pointer_code_scale_ = smem_code_scale.data() + thread_offset; - pointer_code_zp_ = smem_code_zp.data() + thread_offset; - pointer_local_scale_ = reinterpret_cast(smem_local_scale.data()) + thread_offset; + public: + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale, + LocalTensorRef smem_local_scale, + CodeTensorRef smem_code_scale, + CodeTensorRef smem_code_zp, + int warp_idx_n, + int lane_idx) { + int warp_offset = warp_idx_n * Shape::kN; + int quad = lane_idx / 4; + int thread_offset = warp_offset + quad; + pointer_super_scale_ = smem_super_scale.data() + thread_offset; + pointer_code_scale_ = smem_code_scale.data() + thread_offset; + pointer_code_zp_ = smem_code_zp.data() + thread_offset; + pointer_local_scale_ = + reinterpret_cast(smem_local_scale.data()) + thread_offset; + } + + /// Channel-wise params, need to load just once + CUTLASS_DEVICE + void load(FragmentCodeScaleZp& code_scale_frag, + FragmentCodeScaleZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { + super_scale_frag[mma_n_iter] = + pointer_super_scale_[mma_n_iter * + InstructionShape::kN]; // bank conflict + code_scale_frag[mma_n_iter] = + pointer_code_scale_[mma_n_iter * InstructionShape::kN]; + code_zp_frag[mma_n_iter] = + pointer_code_zp_[mma_n_iter * InstructionShape::kN]; + } + } + + /// Group-wise params, need to load multiple times + CUTLASS_DEVICE + void load(FragmentLocalScale& local_scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { + local_scale_frag[mma_n_iter] = + pointer_local_scale_[mma_n_iter * + InstructionShape::kN]; // bank conflict + } + } + + CUTLASS_DEVICE + void dequantize(const FragmentLocalScale& local_scale_frag, + const FragmentCodeScaleZp& code_scale_frag, + const FragmentCodeScaleZp& code_zp_frag, + const FragmentSuperScale& super_scale_frag, + const FragmentInput& input_frag, + FragmentOutput& output_frag, + int tb_offset_k, + int warp_k_compute_offset) { + if constexpr (kUnpackInterval != 1) { + // unsupport now + arch::device_breakpoint(); } - /// Channel-wise params, need to load just once - CUTLASS_DEVICE - void load(FragmentCodeScaleZp& code_scale_frag, - FragmentCodeScaleZp& code_zp_frag, - FragmentSuperScale& super_scale_frag) { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { - super_scale_frag[mma_n_iter] = pointer_super_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict - code_scale_frag[mma_n_iter] = pointer_code_scale_[mma_n_iter * InstructionShape::kN]; - code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * InstructionShape::kN]; - } + typename Uint2Converter::source_type source_frag; + + int in_offset = warp_k_compute_offset * kUnpackInterval; + + uint8_t const* ptr_input = reinterpret_cast(&input_frag); + uint8_t* ptr_source = reinterpret_cast(&source_frag); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { + ptr_source[mma_n_iter] = + ptr_input[mma_n_iter * kUnpackFactor + in_offset]; + } + FragmentInputUnpack unpacked_frag = + Uint2Converter::convert(source_frag, code_scale_frag, code_zp_frag); + + // dequantize local_scale + if (warp_k_compute_offset == 0) { + using LocalScaleConverter = + detail::LocalScaleConverter; + + // special for TileRows = 64 + int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4; + LocalScaleConverter::Apply( + local_scale_frag, super_scale_frag, scale_frag_, local_scale_shift); } - /// Group-wise params, need to load multiple times - CUTLASS_DEVICE - void load(FragmentLocalScale& local_scale_frag) { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { - local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict - } + // unscale + // After applying LOP3 optimizations for performance, the B operand requires + // data rearrangement. reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, + // 14, 11, 15] + const int kWarpIterationsAlongK = + FragmentOutput::kElements / kWarpIterationsAlongN; + + using Type = typename detail::DataTypeTraits::Type; + using DualType = typename detail::DataTypeTraits::DualType; + + Type* output_ptr = reinterpret_cast(&output_frag); + DualType const* unpacked_ptr = + reinterpret_cast(&unpacked_frag); + DualType const* scale_ptr = reinterpret_cast(&scale_frag_); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; + mma_n_iter += 2) { + int mapped_idx_base = (mma_n_iter / 2) * kWarpIterationsAlongK; + + DualType scalex2 = scale_ptr[mma_n_iter / 2]; + + CUTLASS_PRAGMA_UNROLL + for (int mma_k_iter = 0; mma_k_iter < kWarpIterationsAlongK; + ++mma_k_iter) { + DualType unpacked_valuex2 = unpacked_ptr[mapped_idx_base + mma_k_iter]; + DualType scaled_value = __hmul2(unpacked_valuex2, scalex2); + output_ptr[mma_n_iter * kWarpIterationsAlongK + mma_k_iter] = + scaled_value.x; + output_ptr[(mma_n_iter + 1) * kWarpIterationsAlongK + mma_k_iter] = + scaled_value.y; + } } + } - CUTLASS_DEVICE - void dequantize(const FragmentLocalScale& local_scale_frag, - const FragmentCodeScaleZp& code_scale_frag, - const FragmentCodeScaleZp& code_zp_frag, - const FragmentSuperScale& super_scale_frag, - const FragmentInput& input_frag, - FragmentOutput& output_frag, - int tb_offset_k, - int warp_k_compute_offset) { - if constexpr (kUnpackInterval != 1) { - // unsupport now - arch::device_breakpoint(); - } - - typename Uint2Converter::source_type source_frag; - - int in_offset = warp_k_compute_offset * kUnpackInterval; - - uint8_t const* ptr_input = reinterpret_cast(&input_frag); - uint8_t* ptr_source = reinterpret_cast(&source_frag); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { - ptr_source[mma_n_iter] = ptr_input[mma_n_iter * kUnpackFactor + in_offset]; - } - FragmentInputUnpack unpacked_frag = Uint2Converter::convert(source_frag, code_scale_frag, code_zp_frag); - - // dequantize local_scale - if (warp_k_compute_offset == 0) { - using LocalScaleConverter = detail::LocalScaleConverter; - - // special for TileRows = 64 - int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4; - LocalScaleConverter::Apply(local_scale_frag, super_scale_frag, scale_frag_, local_scale_shift); - } - - // unscale - // After applying LOP3 optimizations for performance, the B operand requires data rearrangement. - // reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15] - const int kWarpIterationsAlongK = FragmentOutput::kElements / kWarpIterationsAlongN; - - using Type = typename detail::DataTypeTraits::Type; - using DualType = typename detail::DataTypeTraits::DualType; - - Type* output_ptr = reinterpret_cast(&output_frag); - DualType const* unpacked_ptr = reinterpret_cast(&unpacked_frag); - DualType const* scale_ptr = reinterpret_cast(&scale_frag_); - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; mma_n_iter += 2) { - int mapped_idx_base = (mma_n_iter / 2) * kWarpIterationsAlongK; - - DualType scalex2 = scale_ptr[mma_n_iter / 2]; - - CUTLASS_PRAGMA_UNROLL - for (int mma_k_iter = 0; mma_k_iter < kWarpIterationsAlongK; ++mma_k_iter) { - DualType unpacked_valuex2 = unpacked_ptr[mapped_idx_base + mma_k_iter]; - DualType scaled_value = __hmul2(unpacked_valuex2, scalex2); - output_ptr[mma_n_iter * kWarpIterationsAlongK + mma_k_iter] = scaled_value.x; - output_ptr[(mma_n_iter + 1) * kWarpIterationsAlongK + mma_k_iter] = scaled_value.y; - } - } - } - - /// Add an offset to pointer in units of elements. - /// Only group-wise params needs. - CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) { - pointer_local_scale_ += offset; - } + /// Add an offset to pointer in units of elements. + /// Only group-wise params needs. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + pointer_local_scale_ += offset; + } }; //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h b/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h index 81e58f20ef..02becf2388 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h @@ -21,301 +21,322 @@ #include #include -namespace cutlass_extensions -{ -// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +namespace cutlass_extensions { +// Note: The shapes are in the format MxNxK. The K shape of the runtime config +// MUST match the K shape // in the kernel layout details when doing weight only quantization. -enum class CutlassTileConfig -{ - // Signals that we should run heuristics do choose a config - Undefined, +enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, - // SiMT config - CtaShape128x128x8_WarpShape64x64x8, + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, - // TensorCore configs CTA_N = 128, CTA_K = 64 - // Warp configs for M=16 - CtaShape16x128x64_WarpShape16x32x64, - // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, - // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, - CtaShape64x64x128_WarpShape32x64x64, - CtaShape64x128x64_WarpShape64x32x64, + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, - // Warp configs for M=128 - CtaShape128x64x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x64x64, - CtaShape128x128x64_WarpShape128x32x64, - CtaShape128x256x64_WarpShape64x64x64, + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, - // Warp configs for M=256 - CtaShape256x128x64_WarpShape64x64x64, + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, - // TensorCore config CTA_N = 64, CTA_K = 128 - CtaShape128x64x128_WarpShape64x32x128, + // TensorCore config CTA_N = 64, CTA_K = 128 + CtaShape128x64x128_WarpShape64x32x128, - // TensorCore config CTA_N = 256, CTA_K = 64 - CtaShape16x256x64_WarpShape16x64x64, + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64, - // TensorCore config CTA_N = 256, CTA_K = 128 - CtaShape16x256x128_WarpShape16x64x128 + // TensorCore config CTA_N = 256, CTA_K = 128 + CtaShape16x256x128_WarpShape16x64x128 }; -enum class SplitKStyle -{ - NO_SPLIT_K, - SPLIT_K_SERIAL, - STREAM_K, // Sm80+ - // SPLIT_K_PARALLEL // Not supported yet +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + STREAM_K, // Sm80+ + // SPLIT_K_PARALLEL // Not supported yet }; // New enum for SM100 (Blackwell) Tile Configs // Placeholder values - actual optimal values need research -enum class CutlassTileConfigSM100 -{ - // Signals that we should run heuristics do choose a config - Undefined, +enum class CutlassTileConfigSM100 { + // Signals that we should run heuristics do choose a config + Undefined, - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, - // Actual SM100 tile configs based on user input (K-tile is 128B) - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, - CtaShape256x64x128B, - CtaShape256x128x128B, - CtaShape256x256x128B - // Note: The user-provided list for get_candidate_tiles_sm100 also includes - // CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm cases. - // These are already covered by the list above if general suffices. - // If they need distinct enum values, they should be added. - // For now, keeping the enum concise with unique shapes mentioned for general use. + // Actual SM100 tile configs based on user input (K-tile is 128B) + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + CtaShape256x64x128B, + CtaShape256x128x128B, + CtaShape256x256x128B + // Note: The user-provided list for get_candidate_tiles_sm100 also includes + // CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm + // cases. These are already covered by the list above if general suffices. If + // they need distinct enum values, they should be added. For now, keeping the + // enum concise with unique shapes mentioned for general use. }; +enum class CutlassTileConfigSM90 { + // Signals that we should run heuristics do choose a config + Undefined, -enum class CutlassTileConfigSM90 -{ - // Signals that we should run heuristics do choose a config - Undefined, + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, - // CTA configs for M=64 - CtaShape64x16x128B, - CtaShape64x32x128B, - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, - // CTA configs for M=128 - CtaShape128x16x128B, - CtaShape128x32x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, - - // CTA configs for M=128 - CtaShape256x128x128B, + // CTA configs for M=128 + CtaShape256x128x128B, }; -enum class MainloopScheduleType -{ - AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this - // defaults to the "legacy" main loop schedule. +enum class MainloopScheduleType { + AUTO // Automatically selects between pingpong and cooperative schedules on + // Hopper. On older architectures, this defaults to the "legacy" main + // loop schedule. }; -enum class EpilogueScheduleType -{ - AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For - // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop. +enum class EpilogueScheduleType { + AUTO // Automatically chooses an epilogue schedule compatible with the + // selected main loop schedule for Hopper. For architectures older than + // hopper, the epilogue is always performed by the same thread block as + // the main loop. }; -enum class ClusterShape -{ - ClusterShape_1x1x1, - ClusterShape_2x1x1, - ClusterShape_1x2x1, - ClusterShape_2x2x1, - ClusterShape_1x8x1, - ClusterShape_8x1x1 +enum class ClusterShape { + ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1, + ClusterShape_1x8x1, + ClusterShape_8x1x1 }; -struct CutlassGemmConfig -{ - enum CandidateConfigTypeParam : int - { - NONE = 0, - WEIGHT_ONLY = 1u << 0, - SIMT_ONLY = 1u << 1, - INT8_ONLY = 1u << 2, - HOPPER = 1u << 3, // SM90 - GROUPED_GEMM = 1u << 4, - FP8_ONLY = 1u << 5, - BLACKWELL = 1u << 6, // SM100 - FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths - }; +struct CutlassGemmConfig { + enum CandidateConfigTypeParam : int { + NONE = 0, + WEIGHT_ONLY = 1u << 0, + SIMT_ONLY = 1u << 1, + INT8_ONLY = 1u << 2, + HOPPER = 1u << 3, // SM90 + GROUPED_GEMM = 1u << 4, + FP8_ONLY = 1u << 5, + BLACKWELL = 1u << 6, // SM100 + FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths + }; - CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; - SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; - int split_k_factor = -1; - int stages = -1; + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; - // config options for sm90 - CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; - MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; - EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; - ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; - bool is_sm90 = false; + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = + CutlassTileConfigSM90::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + bool is_sm90 = false; - // config options for sm100 (Blackwell) - // Assuming SM100 might use similar schedule/cluster types as SM90 for now. - // These might need to become SM100-specific if Blackwell introduces new concepts. - CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic; - // MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; // Example if SM100 has different types - // EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example - // ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // Example - bool is_sm100 = false; + // config options for sm100 (Blackwell) + // Assuming SM100 might use similar schedule/cluster types as SM90 for now. + // These might need to become SM100-specific if Blackwell introduces new + // concepts. + CutlassTileConfigSM100 tile_config_sm100 = + CutlassTileConfigSM100::ChooseWithHeuristic; + // MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; + // // Example if SM100 has different types EpilogueScheduleType + // epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example + // ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // + // Example + bool is_sm100 = false; + CutlassGemmConfig() : is_sm90(false), is_sm100(false) {} - CutlassGemmConfig() : is_sm90(false), is_sm100(false) {} + CutlassGemmConfig(CutlassTileConfig tile_config, + SplitKStyle split_k_style, + int split_k_factor, + int stages) + : tile_config(tile_config), + split_k_style(split_k_style), + split_k_factor(split_k_factor), + stages(stages), + is_sm90(false), + is_sm100(false) {} - CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) - : tile_config(tile_config) - , split_k_style(split_k_style) - , split_k_factor(split_k_factor) - , stages(stages) - , is_sm90(false) - , is_sm100(false) - { + // Constructor for SM90 + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, + MainloopScheduleType mainloop_schedule_in, + EpilogueScheduleType epilogue_schedule_in, + ClusterShape cluster_shape_in) + : tile_config_sm90(tile_config_sm90_in), + mainloop_schedule(mainloop_schedule_in), + epilogue_schedule(epilogue_schedule_in), + cluster_shape(cluster_shape_in), + is_sm90(true), + is_sm100(false) {} + + // Constructor for SM100 (Blackwell) + // Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for + // now. These might need to be new SM100-specific types if Blackwell's TMA + // differs significantly. + CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, + MainloopScheduleType mainloop_schedule_in, + EpilogueScheduleType epilogue_schedule_in, + ClusterShape cluster_shape_in) + : tile_config_sm100(tile_config_sm100_in), + mainloop_schedule( + mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if + // types diverge + , + epilogue_schedule( + epilogue_schedule_in) // Potentially use epilogue_schedule_sm100 + , + cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100 + , + is_sm90(false) // Explicitly false + , + is_sm100(true) {} + + std::string toString() const { + std::stringstream tactic; + tactic << "Cutlass GEMM Tactic"; + if (is_sm100 && + tile_config_sm100 != + cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic) { + assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100"); + tactic + << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable + << "\n\ttile shape ID: " << (int)tile_config_sm100 + << "\n\tcluster shape ID: " << (int)cluster_shape + << "\n\tmainloop sched: " << (int)mainloop_schedule + << "\n\tepi sched: " << (int)epilogue_schedule; + } else if (is_sm90 && tile_config_sm90 != + cutlass_extensions::CutlassTileConfigSM90:: + ChooseWithHeuristic) { + assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90"); + tactic << "\n\tstyle=TMA_SM90" + << "\n\ttile shape ID: " << (int)tile_config_sm90 + << "\n\tcluster shape ID: " << (int)cluster_shape + << "\n\tmainloop sched: " << (int)mainloop_schedule + << "\n\tepi sched: " << (int)epilogue_schedule; + } else if (tile_config != + cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { + assert(!is_sm90 && !is_sm100 && + "Invalid cutlass GEMM config: Compatible"); + tactic << "\n\tstyle=compatible" + << "\n\ttile shape ID: " << (int)tile_config + << "\n\tstages: " << (int)stages + << "\n\tsplit_k_style: " << (int)split_k_style + << "\n\tsplit k: " << (int)split_k_factor; + } else { + tactic << "\n\tundefined"; } + tactic << "\n"; + return tactic.str(); + } - // Constructor for SM90 - CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in, - EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in) - : tile_config_sm90(tile_config_sm90_in) - , mainloop_schedule(mainloop_schedule_in) - , epilogue_schedule(epilogue_schedule_in) - , cluster_shape(cluster_shape_in) - , is_sm90(true) - , is_sm100(false) - { - } + void fromString(const std::string& str) { + std::istringstream stream(str); + std::string line; - // Constructor for SM100 (Blackwell) - // Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now. - // These might need to be new SM100-specific types if Blackwell's TMA differs significantly. - CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in, - EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in) - : tile_config_sm100(tile_config_sm100_in) - , mainloop_schedule(mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if types diverge - , epilogue_schedule(epilogue_schedule_in) // Potentially use epilogue_schedule_sm100 - , cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100 - , is_sm90(false) // Explicitly false - , is_sm100(true) - { - } + is_sm90 = false; // Reset flags + is_sm100 = false; - - std::string toString() const - { - std::stringstream tactic; - tactic << "Cutlass GEMM Tactic"; - if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic) - { - assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100"); - tactic << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable - << "\n\ttile shape ID: " << (int) tile_config_sm100 - << "\n\tcluster shape ID: " << (int) cluster_shape - << "\n\tmainloop sched: " << (int) mainloop_schedule - << "\n\tepi sched: " << (int) epilogue_schedule; - } - else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic) - { - assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90"); - tactic << "\n\tstyle=TMA_SM90" - << "\n\ttile shape ID: " << (int) tile_config_sm90 - << "\n\tcluster shape ID: " << (int) cluster_shape - << "\n\tmainloop sched: " << (int) mainloop_schedule - << "\n\tepi sched: " << (int) epilogue_schedule; - } - else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) - { - assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible"); - tactic << "\n\tstyle=compatible" - << "\n\ttile shape ID: " << (int) tile_config - << "\n\tstages: " << (int) stages - << "\n\tsplit_k_style: " << (int) split_k_style - << "\n\tsplit k: " << (int) split_k_factor; - } - else - { - tactic << "\n\tundefined"; - } - tactic << "\n"; - return tactic.str(); - } - - void fromString(const std::string& str) { - std::istringstream stream(str); - std::string line; - - is_sm90 = false; // Reset flags + while (std::getline(stream, line)) { + if (line.find("style=TMA_SM100") != std::string::npos) { + is_sm100 = true; + is_sm90 = false; + std::getline(stream, line); + tile_config_sm100 = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + cluster_shape = static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + mainloop_schedule = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + epilogue_schedule = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + } else if (line.find("style=TMA_SM90") != + std::string::npos) { // Check for SM90 specific first + is_sm90 = true; is_sm100 = false; - - while (std::getline(stream, line)) { - if (line.find("style=TMA_SM100") != std::string::npos) { - is_sm100 = true; - is_sm90 = false; - std::getline(stream, line); - tile_config_sm100 = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - cluster_shape = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - mainloop_schedule = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - epilogue_schedule = static_cast(std::stoi(line.substr(line.find(':') + 1))); - } else if (line.find("style=TMA_SM90") != std::string::npos) { // Check for SM90 specific first - is_sm90 = true; - is_sm100 = false; - std::getline(stream, line); - tile_config_sm90 = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - cluster_shape = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - mainloop_schedule = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - epilogue_schedule = static_cast(std::stoi(line.substr(line.find(':') + 1))); - } else if (line.find("style=compatible") != std::string::npos) { - is_sm90 = false; - is_sm100 = false; - std::getline(stream, line); - tile_config = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - stages = std::stoi(line.substr(line.find(':') + 1)); - std::getline(stream, line); - split_k_style = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - split_k_factor = std::stoi(line.substr(line.find(':') + 1)); - } - } + std::getline(stream, line); + tile_config_sm90 = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + cluster_shape = static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + mainloop_schedule = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + epilogue_schedule = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + } else if (line.find("style=compatible") != std::string::npos) { + is_sm90 = false; + is_sm100 = false; + std::getline(stream, line); + tile_config = static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + stages = std::stoi(line.substr(line.find(':') + 1)); + std::getline(stream, line); + split_k_style = static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + split_k_factor = std::stoi(line.substr(line.find(':') + 1)); + } } + } }; -inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) -{ - // clang-format off +inline std::ostream& operator<<(std::ostream& out, + CutlassGemmConfig const& config) { + // clang-format off if (config.is_sm100) { out << "tile_config_sm100_enum: " << int(config.tile_config_sm100) @@ -337,8 +358,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf << ", split_k_factor: " << config.split_k_factor << ", stages: " << config.stages; } - // clang-format on - return out; + // clang-format on + return out; } -} // namespace cutlass_extensions +} // namespace cutlass_extensions diff --git a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h index e7e17657be..9a9b35a324 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,21 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t + interleaved in a register */ #pragma once @@ -52,726 +54,801 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low -// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally -// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. -// This converter will uninterleave the data and subtract the bias while converting to the result type. +// This converter is meant to be used with data interleaved in a 32-bit register +// where the even elements are in the low bits and the odd elemeents are in the +// high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the +// type) to make all numbers unsigned. This converter will uninterleave the data +// and subtract the bias while converting to the result type. template struct FastInterleavedAndBiasedNumericArrayConverter; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); - // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + // Lastly, we subtract 1152 from our constructed number using fp16 math to + // get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); - return result; - } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; - result_type result; - using vec_result = Array; - using vec_source = Array; + result_type result; + using vec_result = Array; + using vec_source = Array; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - uint32_t* bf16_result_ptr = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); - static constexpr uint32_t fp32_base = 0x4B000000; - float fp32_intermediates[4]; + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; - // Construct FP32s, bfloat does not have enough mantissa for IADD trick - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); - // Subtract out fp32_base + 128 to make the unsigned integer signed. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 4; ++ii) - { - fp32_intermediates[ii] -= 8388736.f; - } + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } - // Truncate the fp32 representation and pack up as bfloat16s. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 2; ++ii) - { - bf16_result_ptr[ii] - = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); - } + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } #else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - result.clear(); // Suppress compiler warning - arch::device_breakpoint(); + // Disable this on architectures older than Ampere since they lack hardware + // for bf16 mma. If one wishes to use HMMA on older hardware, they should + // Convert directly to FP16 using FP16 converters. + result.clear(); // Suppress compiler warning + arch::device_breakpoint(); #endif - return result; - } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; - result_type result; - using vec_result = Array; - using vec_source = Array; + result_type result; + using vec_result = Array; + using vec_source = Array; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput + // in order to convert elt_23 and elt_67 to fp16 without having to shift + // them to the bottom bits before hand. - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide + // RAW dependency if we issue immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), + "n"(BOTTOM_MASK), + "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. - // This is the half2 {1032, 1032} represented as an integer. - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - static constexpr uint32_t NEG_72 = 0xd480d480; + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - return result; - } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; - result_type result; - using vec_result = Array; - using vec_source = Array; + result_type result; + using vec_result = Array; + using vec_source = Array; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - uint32_t* h = reinterpret_cast(&result); - uint32_t const source_i4s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; - // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. - // No shift needed for first item. - uint32_t i4s = source_i4s; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - CUTLASS_PRAGMA_UNROLL - for (int ii = 1; ii < result_type::kElements / 2; ++ii) - { - i4s >>= sizeof_bits::value; - // (i4s & 0x000f000f) | 0x43004300 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[ii]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - } + // We don't have enough mantissa to remove as much shift overhead as FP16, + // so we must loop. No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } - // This is the BF16 {-136, -136} represented as an integer. - static constexpr uint32_t BF16_BIAS = 0xC308C308; - static constexpr uint32_t BF16_ONE = 0x3F803F80; + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; - // Finally, we construct the output numbers. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < result_type::kElements / 2; ++ii) - { - // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction - asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); - } + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } #else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - arch::device_breakpoint(); - result.clear(); // Suppress compiler warning. + // Disable this on architectures older than Ampere since they lack hardware + // for bf16 mma. If one wishes to use HMMA on older hardware, they should + // Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. #endif - return result; - } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; - result_type result; - using vec_result = Array; - using vec_source = Array; + result_type result; + using vec_result = Array; + using vec_source = Array; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; - using ScaleComputeT = float; - using code_type = Array; + using ScaleComputeT = float; + using code_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp) - { - uint32_t const i8s = reinterpret_cast(source); + CUTLASS_DEVICE + static result_type convert(source_type const& source, + ScaleComputeT code_scale, + ScaleComputeT code_zp) { + uint32_t const i8s = reinterpret_cast(source); - // 2^23 = 8388608 - static constexpr uint32_t FP32_BASE = 0x4B000000; + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); - fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); - fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[0]) + : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[1]) + : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[2]) + : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[3]) + : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); - int32_t decode_value[4]; - ScaleComputeT new_code_zp = code_zp + 0.5f; + int32_t decode_value[4]; + ScaleComputeT new_code_zp = code_zp + 0.5f; - decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); - decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); - decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); - decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp)); + decode_value[0] = + __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); + decode_value[1] = + __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); + decode_value[2] = + __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); + decode_value[3] = + __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp)); - return convert_impl(decode_value); - } + return convert_impl(decode_value); + } - CUTLASS_DEVICE - static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp) - { - uint32_t const i8s = reinterpret_cast(source); + CUTLASS_DEVICE + static result_type convert(source_type const& source, + code_type const& code_scale, + code_type const& code_zp) { + uint32_t const i8s = reinterpret_cast(source); - // 2^23 = 8388608 - static constexpr uint32_t FP32_BASE = 0x4B000000; + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); - fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); - fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[0]) + : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[1]) + : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[2]) + : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[3]) + : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); - int32_t decode_value[4]; + int32_t decode_value[4]; - decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); - decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); - decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); - decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); + decode_value[0] = __float2int_rd( + fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); + decode_value[1] = __float2int_rd( + fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); + decode_value[2] = __float2int_rd( + fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); + decode_value[3] = __float2int_rd( + fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); - return convert_impl(decode_value); - } + return convert_impl(decode_value); + } - CUTLASS_DEVICE - static result_type convert_impl(int32_t* decode_value) - { - result_type result; - static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; + CUTLASS_DEVICE + static result_type convert_impl(int32_t* decode_value) { + result_type result; + static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; - static constexpr uint32_t MASK = 0x003F003F; - // 2^10 = 1024 - static constexpr uint32_t EX = 0x64006400; + static constexpr uint32_t MASK = 0x003F003F; + // 2^10 = 1024 + static constexpr uint32_t EX = 0x64006400; - uint32_t* h = reinterpret_cast(&result); + uint32_t* h = reinterpret_cast(&result); - int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); - int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); + int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); + int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); - h[0] = lop3(q0 >> 9, MASK, EX); - h[1] = lop3(q0 >> 6, MASK, EX); - h[2] = lop3(q0 >> 3, MASK, EX); - h[3] = lop3(q0, MASK, EX); + h[0] = lop3(q0 >> 9, MASK, EX); + h[1] = lop3(q0 >> 6, MASK, EX); + h[2] = lop3(q0 >> 3, MASK, EX); + h[3] = lop3(q0, MASK, EX); - h[4] = lop3(q1 >> 9, MASK, EX); - h[5] = lop3(q1 >> 6, MASK, EX); - h[6] = lop3(q1 >> 3, MASK, EX); - h[7] = lop3(q1, MASK, EX); + h[4] = lop3(q1 >> 9, MASK, EX); + h[5] = lop3(q1 >> 6, MASK, EX); + h[6] = lop3(q1 >> 3, MASK, EX); + h[7] = lop3(q1, MASK, EX); - // 1024 + 32 = 1056 - static constexpr uint32_t SUB = 0x64206420; + // 1024 + 32 = 1056 + static constexpr uint32_t SUB = 0x64206420; - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); - return result; - } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp) - { - return convert(s, code_scale, code_zp); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s, + ScaleComputeT code_scale, + ScaleComputeT code_zp) { + return convert(s, code_scale, code_zp); + } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; - using ScaleComputeT = float; - using code_type = Array; + using ScaleComputeT = float; + using code_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp) - { - uint32_t const i8s = reinterpret_cast(source); + CUTLASS_DEVICE + static result_type convert(source_type const& source, + ScaleComputeT code_scale, + ScaleComputeT code_zp) { + uint32_t const i8s = reinterpret_cast(source); - // 2^23 = 8388608 - static constexpr uint32_t FP32_BASE = 0x4B000000; + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); - fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); - fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[0]) + : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[1]) + : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[2]) + : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[3]) + : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); - int32_t decode_value[4]; - ScaleComputeT new_code_zp = code_zp + 0.5f; + int32_t decode_value[4]; + ScaleComputeT new_code_zp = code_zp + 0.5f; - decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); - decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); - decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); - decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp)); + decode_value[0] = + __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); + decode_value[1] = + __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); + decode_value[2] = + __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); + decode_value[3] = + __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp)); - return convert_impl(decode_value); - } + return convert_impl(decode_value); + } - CUTLASS_DEVICE - static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp) - { - uint32_t const i8s = reinterpret_cast(source); + CUTLASS_DEVICE + static result_type convert(source_type const& source, + code_type const& code_scale, + code_type const& code_zp) { + uint32_t const i8s = reinterpret_cast(source); - // 2^23 = 8388608 - static constexpr uint32_t FP32_BASE = 0x4B000000; + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); - fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); - fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); - asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[0]) + : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[1]) + : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[2]) + : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[3]) + : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); - int32_t decode_value[4]; + int32_t decode_value[4]; - decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); - decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); - decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); - decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); + decode_value[0] = __float2int_rd( + fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); + decode_value[1] = __float2int_rd( + fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); + decode_value[2] = __float2int_rd( + fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); + decode_value[3] = __float2int_rd( + fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); - return convert_impl(decode_value); - } + return convert_impl(decode_value); + } - CUTLASS_DEVICE - static result_type convert_impl(int32_t* decode_value) - { - result_type result; + CUTLASS_DEVICE + static result_type convert_impl(int32_t* decode_value) { + result_type result; - static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; - static constexpr uint32_t MASK = 0x003F003F; - // 2^7 = 128 - static constexpr uint32_t EX = 0x43004300; + static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; + static constexpr uint32_t MASK = 0x003F003F; + // 2^7 = 128 + static constexpr uint32_t EX = 0x43004300; - uint32_t* h = reinterpret_cast(&result); + uint32_t* h = reinterpret_cast(&result); - int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); - int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); + int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); + int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); - h[0] = lop3(q0 >> 9, MASK, EX); - h[1] = lop3(q0 >> 6, MASK, EX); - h[2] = lop3(q0 >> 3, MASK, EX); - h[3] = lop3(q0, MASK, EX); + h[0] = lop3(q0 >> 9, MASK, EX); + h[1] = lop3(q0 >> 6, MASK, EX); + h[2] = lop3(q0 >> 3, MASK, EX); + h[3] = lop3(q0, MASK, EX); - h[4] = lop3(q1 >> 9, MASK, EX); - h[5] = lop3(q1 >> 6, MASK, EX); - h[6] = lop3(q1 >> 3, MASK, EX); - h[7] = lop3(q1, MASK, EX); + h[4] = lop3(q1 >> 9, MASK, EX); + h[5] = lop3(q1 >> 6, MASK, EX); + h[6] = lop3(q1 >> 3, MASK, EX); + h[7] = lop3(q1, MASK, EX); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16)) - // 128 + 32 = 160 - static constexpr uint32_t SUB = 0x43204320; + // 128 + 32 = 160 + static constexpr uint32_t SUB = 0x43204320; - asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); - asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); - asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); - asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); - asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); - asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); - asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); - asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); #else - // 1.0 - static constexpr uint32_t MUL = 0x3F803F80; - // -160 - static constexpr uint32_t ADD = 0xC320C320; + // 1.0 + static constexpr uint32_t MUL = 0x3F803F80; + // -160 + static constexpr uint32_t ADD = 0xC320C320; - asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MUL), "r"(ADD)); - asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MUL), "r"(ADD)); - asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MUL), "r"(ADD)); - asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(MUL), "r"(ADD)); - asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(h[4]), "r"(MUL), "r"(ADD)); - asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(MUL), "r"(ADD)); - asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(MUL), "r"(ADD)); - asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[4]) + : "r"(h[4]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[5]) + : "r"(h[5]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[6]) + : "r"(h[6]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[7]) + : "r"(h[7]), "r"(MUL), "r"(ADD)); #endif - return result; - } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp) - { - return convert(s, code_scale, code_zp); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s, + ScaleComputeT code_scale, + ScaleComputeT code_zp) { + return convert(s, code_scale, code_zp); + } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static_assert(platform::is_same::value || platform::is_same::value, - "T must be fp16 or bf16"); +struct FastInterleavedAndBiasedNumericArrayConverter { + static_assert(platform::is_same::value || + platform::is_same::value, + "T must be fp16 or bf16"); - static constexpr int kVecWidth = 16; - static_assert(!(N % kVecWidth), "N must be multiple of 16."); + static constexpr int kVecWidth = 16; + static_assert(!(N % kVecWidth), "N must be multiple of 16."); - using result_type = Array; - using source_type = Array; - using code_type = Array; + using result_type = Array; + using source_type = Array; + using code_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; + CUTLASS_DEVICE + static result_type convert(source_type const& source, + code_type const& code_scale, + code_type const& code_zp) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; - result_type result; - using vec_result = Array; - using vec_source = Array; + result_type result; + using vec_result = Array; + using vec_source = Array; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / kVecWidth; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]); - } - - return result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / kVecWidth; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]); } - CUTLASS_DEVICE - static result_type convert(source_type const& source, Array const& code_scale, Array const& code_zp) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - using Converter = FastInterleavedAndBiasedNumericArrayConverter; + return result; + } - result_type result; - using vec_result = typename Converter::result_type; - using vec_source = typename Converter::source_type; - using vec_code = typename Converter::code_type; + CUTLASS_DEVICE + static result_type convert(source_type const& source, + Array const& code_scale, + Array const& code_zp) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + using Converter = + FastInterleavedAndBiasedNumericArrayConverter; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - vec_code const* code_scale_ptr = reinterpret_cast(&code_scale); - vec_code const* code_zp_ptr = reinterpret_cast(&code_zp); + result_type result; + using vec_result = typename Converter::result_type; + using vec_source = typename Converter::source_type; + using vec_code = typename Converter::code_type; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / kVecWidth; ++i) - { - result_ptr[i] = Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]); - } + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + vec_code const* code_scale_ptr = + reinterpret_cast(&code_scale); + vec_code const* code_zp_ptr = reinterpret_cast(&code_zp); - return result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / kVecWidth; ++i) { + result_ptr[i] = + Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s, code_type const& code_scale, code_type const& code_zp) - { - return convert(s, code_scale, code_zp); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s, + code_type const& code_scale, + code_type const& code_zp) { + return convert(s, code_scale, code_zp); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/tile_interleaved_layout.h b/custom_ops/gpu_ops/cutlass_extensions/tile_interleaved_layout.h index 5a0cd29570..928f2645a5 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/tile_interleaved_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/tile_interleaved_layout.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -38,29 +39,24 @@ #include "cutlass/matrix_coord.h" #include "cutlass/pitch_linear_coord.h" -namespace cutlass -{ -namespace layout -{ +namespace cutlass { +namespace layout { template -struct ColumnMajorTileInterleave -{ - static constexpr int kRowsPerTile = RowsPerTile; - static constexpr int kColumnsInterleaved = ColumnsInterleaved; +struct ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; }; template -struct IsColumnMajorTileInterleave -{ - static constexpr bool value = false; +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; }; template -struct IsColumnMajorTileInterleave> -{ - static constexpr bool value = true; +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; }; -} // namespace layout -} // namespace cutlass +} // namespace layout +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/custom_ops/gpu_ops/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h index 6095925e37..6d45e5cb02 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h +++ b/custom_ops/gpu_ops/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM - quantization. + \brief Templates for visiting scales to be used when dequantizing the + weights for weight-only GEMM quantization. */ #pragma once @@ -50,201 +51,205 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace transform -{ -namespace threadblock -{ +namespace cutlass { +namespace transform { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -template +template class FineGrainedScaleZeroIterator; template -class FineGrainedScaleZeroIterator -{ -public: - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = 0; - static int const kAlignment = Alignment_; +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; - static int const kAccessesPerVector = 1; + static int const kAccessesPerVector = 1; - /// Row index of scales corresponding to the groupsize of 64 - int row_groupsize64_; - int group_size_; + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - using Pointer = Element*; - using NonConstPointer = typename platform::remove_const::type*; + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; - using AccessType = AlignedArray; + using AccessType = AlignedArray; - using Fragment = cutlass::Array; + using Fragment = cutlass::Array; - // For compatibility with existing iterator interface - struct Params - { - LongIndex stride_ = 0; + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_ = 0; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; - // Default ctor - CUTLASS_HOST_DEVICE - Params() {} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : stride_(layout.stride(0)) - { - inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; - } - }; - -private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char*; - -private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const params_; - - /// Internal pointer to first access of tile - BytePointer pointer_scale_; - BytePointer pointer_zero_; - - bool is_valid_ = false; - -public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_DEVICE - FineGrainedScaleZeroIterator( - ///< Precomputed parameters object - Params const& params, - ///< Pointer to start of scale tensor - Pointer pointer_scale, - ///< Pointer to start of zero tensor - Pointer pointer_zero, - ///< Extent of the scale and bias - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const& threadblock_offset, - ///< Group size - int group_size) - : params_(params) - , pointer_scale_(reinterpret_cast(const_cast(pointer_scale))) - , pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) - { - row_groupsize64_ = threadblock_offset.row(); - group_size_ = group_size; - - const LongIndex tb_row_byte_offset - = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; - const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; - pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); - - if (pointer_zero_ != nullptr) - { - pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); - } - - static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; - - int const thread_row = thread_id / THREADS_PER_ROW; - int const thread_col = thread_id % THREADS_PER_ROW; - - const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; - const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; - pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); - if (pointer_zero_ != nullptr) - { - pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); - } - - // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on - // a given iteration. The same threads will be responsible for issues reads since the number of scales - // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ - // outside of the constructor. - int const global_row = threadblock_offset.row() + thread_row; - int const global_col = threadblock_offset.column() + thread_col * kAlignment; - - bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; - bool const col_in_bounds = global_col < extent.column(); - - is_valid_ = row_in_bounds && col_in_bounds; - } - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object - Pointer pointer_scale, ///< Pointer to start of scale tensor - Pointer pointer_zero, ///< Pointer to start of zero tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - int group_size) - : FineGrainedScaleZeroIterator( - params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) - { - } - - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& tile_offset) - { - const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; - const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; - pointer_scale_ += row_byte_offset + col_byte_offset; - if (pointer_zero_ != nullptr) - { - pointer_zero_ += row_byte_offset + col_byte_offset; - } - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) - { - is_valid_ &= (!enable); - } - - /// Returns whether access is valid or not + // Default ctor CUTLASS_HOST_DEVICE - bool valid() const - { - return is_valid_; + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : stride_(layout.stride(0)) { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), + pointer_scale_(reinterpret_cast( + const_cast(pointer_scale))), + pointer_zero_(reinterpret_cast( + const_cast(pointer_zero))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = threadblock_offset.row() / + (group_size / 64) * params_.stride_ * + sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = + threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) { + pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); } - /// Returns a scale pointer - CUTLASS_HOST_DEVICE - AccessType* get_scale() const - { - return reinterpret_cast(pointer_scale_); + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = + thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = + thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); } - /// Returns a zero pointer - CUTLASS_HOST_DEVICE - AccessType* get_zero() const - { - return reinterpret_cast(pointer_zero_); + // For the rows, we must check that we are within the extent AND the tile to + // avoid extra reads on a given iteration. The same threads will be + // responsible for issues reads since the number of scales read in a given + // iteration is a constant. Therefore, we should never have to update + // is_valid_ outside of the constructor. + int const global_row = threadblock_offset.row() + thread_row; + int const global_col = + threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = + global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator( + Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator(params, + pointer_scale, + pointer_zero, + extent, + thread_id, + make_Coord(0, 0), + group_size) {} + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = + tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) { + pointer_zero_ += row_byte_offset + col_byte_offset; } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return is_valid_; } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const { + return reinterpret_cast(pointer_zero_); + } }; -} // namespace threadblock -} // namespace transform -} // namespace cutlass +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/util/gather_tensor.hpp b/custom_ops/gpu_ops/cutlass_extensions/util/gather_tensor.hpp index b430380b01..b29fc4db5f 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/util/gather_tensor.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/util/gather_tensor.hpp @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once @@ -38,144 +39,148 @@ using namespace cute; /// Function object that applies an index to its argument template -struct IndexedGather -{ - CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {}) - : indices_(indices) - { - } +struct IndexedGather { + CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {}) + : indices_(indices) {} - template - CUTE_HOST_DEVICE constexpr auto operator()(I i) const - { - return indices_[i]; - } + template + CUTE_HOST_DEVICE constexpr auto operator()(I i) const { + return indices_[i]; + } - CUTE_HOST_DEVICE friend void print(IndexedGather const& s) - { - cute::print("Indexed{"); - print(s.indices_); - print("}"); - } + CUTE_HOST_DEVICE friend void print(IndexedGather const& s) { + cute::print("Indexed{"); + print(s.indices_); + print("}"); + } - Iter indices_; + Iter indices_; }; /// Custom stride object that applies a function followed by a stride template -struct CustomStride -{ - CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride) - : func_(func) - , stride_(stride) - { - } +struct CustomStride { + CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, + Stride const& stride) + : func_(func), stride_(stride) {} - template - CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) - { - return s.func_(i) * s.stride_; - } + template + CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) { + return s.func_(i) * s.stride_; + } - template - CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) - { - return s.func_(i) * s.stride_; - } + template + CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) { + return s.func_(i) * s.stride_; + } - CUTE_HOST_DEVICE friend void print(CustomStride const& s) - { - cute::print("Custom{"); - print(s.func_); - cute::print(","); - print(s.stride_); - cute::print("}"); - } + CUTE_HOST_DEVICE friend void print(CustomStride const& s) { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } - template - CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) - { - return CustomStride(s.func_, safe_div(s.stride_, div)); - } + template + CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, + Div const& div) { + return CustomStride( + s.func_, safe_div(s.stride_, div)); + } - // Circumvent the requirement on make_layout that shape and stride are integral - template - CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride) - { - return Layout(shape, stride); - } + // Circumvent the requirement on make_layout that shape and stride are + // integral + template + CUTE_HOST_DEVICE constexpr friend auto make_layout( + Shape const& shape, CustomStride const& stride) { + return Layout(shape, stride); + } - Func func_; - Stride stride_; + Func func_; + Stride stride_; }; template -CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) -{ - // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride - auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; }); - constexpr int I = decltype(idx)::value; - return make_layout( - repeat_like(stride, _1{}), replace(stride, CustomStride{static_cast(func), get(stride)})); +CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, + Func&& func) { + // Use a dummy shape and replace the first non-unit and non-zero stride with a + // custom gather stride + auto idx = find_if(stride, [](auto x) { + return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; + }); + constexpr int I = decltype(idx)::value; + return make_layout( + repeat_like(stride, _1{}), + replace(stride, + CustomStride{static_cast(func), get(stride)})); } /// Helper function to optionally create a gather tensor template -CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func) -{ - Layout matrix_layout = make_identity_layout(shape); - auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); - Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); - return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); +CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, + Shape const& shape, + Stride const& stride, + Func&& func) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = + make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, + ComposedLayout{gather_layout, offset, matrix_layout}); } -namespace cute -{ +namespace cute { template -CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) -{ - if constexpr (is_tuple::value) - { - return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s, d); }); - } - else if constexpr (is_scaled_basis::value) - { - if constexpr (Stride::mode() == I) - { - return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); - } - else - { - return make_layout(shape, stride); - } - } - else - { - return upcast(shape, stride); +CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, + Stride const& stride) { + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { + return upcast(s, d); + }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(shape_div(shape, Int{}), + shape_div(stride, Int{})); + } else { + return make_layout(shape, stride); } + } else { + return upcast(shape, stride); + } - CUTE_GCC_UNREACHABLE; + CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr auto upcast( - ComposedLayout, Offset, Layout> const& layout) -{ - // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset - auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; }); - constexpr int I = decltype(idx)::value; + ComposedLayout, + Offset, + Layout> const& layout) { + // Find index of the stride-1 mode - that is the only one that requires + // updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), + [](auto x) { return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; - // Upcast the outer layout (works as expected) - auto outer = upcast(layout.layout_a()); + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); - // Upcast the accumulated offset along stride-1 mode - auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple( + replace(layout.offset(), upcast(get(layout.offset())))); - // Upcast the inner layout's shape along stride-1 mode - auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + // Upcast the inner layout's shape along stride-1 mode + auto inner = + upcast(layout.layout_b().shape(), layout.layout_b().stride()); - return composition(outer, offset, inner); + return composition(outer, offset, inner); } -} // namespace cute +} // namespace cute diff --git a/custom_ops/gpu_ops/cutlass_extensions/weight_only_quant_op.h b/custom_ops/gpu_ops/cutlass_extensions/weight_only_quant_op.h index 64774428e9..9f2d255234 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/weight_only_quant_op.h +++ b/custom_ops/gpu_ops/cutlass_extensions/weight_only_quant_op.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,41 +18,40 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. */ #pragma once -namespace cutlass -{ +namespace cutlass { -enum class WeightOnlyQuantOp -{ - UNDEFINED, - PER_COLUMN_SCALE_ONLY, - FINEGRAINED_SCALE_ONLY, - FINEGRAINED_SCALE_AND_ZEROS +enum class WeightOnlyQuantOp { + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS }; -constexpr bool isFinegrained(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +constexpr bool isFinegrained(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || + op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; } -constexpr bool hasZero(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +constexpr bool hasZero(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; } -} // namespace cutlass +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h index fa28810697..da6fcf41a2 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h +++ b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h @@ -33,19 +33,23 @@ enum WintQuantMethod { }; // Convert CUDA data type to cutlass data type -template struct CutlassDataType { +template +struct CutlassDataType { using Type = T; }; -template <> struct CutlassDataType { +template <> +struct CutlassDataType { using Type = cutlass::half_t; }; -template <> struct CutlassDataType<__nv_bfloat16> { +template <> +struct CutlassDataType<__nv_bfloat16> { using Type = cutlass::bfloat16_t; }; -template struct WintQuantTraits; +template +struct WintQuantTraits; template struct WintQuantTraits { @@ -129,7 +133,7 @@ struct WintQuantTraits { using CodeScaleZpType = float; struct Arguments { - uint8_t *local_scale_ptr; // quanted 4-bits + uint8_t *local_scale_ptr; // quanted 4-bits float *code_scale_ptr; float *code_zp_ptr; }; @@ -140,4 +144,4 @@ struct WintQuantTraits { } }; -} // namespace cutlass +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h b/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h index 3ac548c625..f9e4aad2ff 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h @@ -24,11 +24,11 @@ /** * Helper function for checking CUTLASS errors */ -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - PD_CHECK(error == cutlass::Status::kSuccess, \ - cutlassGetStatusString(error)); \ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + PD_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ } /** @@ -38,44 +38,50 @@ * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef * into code that will be executed on the device where it is defined. */ -template struct enable_sm90_or_later : Kernel { - template CUTLASS_DEVICE void operator()(Args &&...args) { +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args &&...args) { #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 Kernel::operator()(std::forward(args)...); #endif } }; -template class CutlassDtypeTraits; +template +class CutlassDtypeTraits; -template <> class CutlassDtypeTraits { -public: +template <> +class CutlassDtypeTraits { + public: typedef float DataType; typedef float data_t; }; -template <> class CutlassDtypeTraits { -public: +template <> +class CutlassDtypeTraits { + public: typedef cutlass::half_t DataType; typedef paddle::float16 data_t; }; -template <> class CutlassDtypeTraits { -public: +template <> +class CutlassDtypeTraits { + public: typedef cutlass::bfloat16_t DataType; typedef paddle::bfloat16 data_t; }; class CutlassGemmConfigMannager { -public: + public: static CutlassGemmConfigMannager &getInstance() { static CutlassGemmConfigMannager instance; return instance; } CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete; - CutlassGemmConfigMannager & - operator=(const CutlassGemmConfigMannager &) = delete; + CutlassGemmConfigMannager &operator=(const CutlassGemmConfigMannager &) = + delete; void up_date_configs(const nlohmann::json &j) { std::lock_guard lock(mutex_); @@ -102,7 +108,7 @@ public: return &json_; } -private: + private: void save_gemm_best_configs_(const std::string &config_file_path) { std::ifstream file(config_file_path); if (!file.good()) { diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu index 6db16981c6..6ea5a275ad 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu @@ -19,14 +19,14 @@ #ifndef _WIN32 #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // #ifndef _WIN32 +#endif // #ifndef _WIN32 #include "cutlass/gemm/gemm.h" #include "cutlass/numeric_types.h" #ifndef _WIN32 #pragma GCC diagnostic pop -#endif // #ifndef _WIN32 +#endif // #ifndef _WIN32 #include #include @@ -35,491 +35,509 @@ using namespace cutlass_extensions; -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { -struct TileShape -{ - int m; - int n; +struct TileShape { + int m; + int n; }; -TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) -{ - switch (tile_config) - { - case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: return TileShape{16, 128}; - case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: return TileShape{16, 256}; - case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128}; - case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64}; +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + return TileShape{16, 128}; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + return TileShape{16, 256}; + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + return TileShape{64, 64}; case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: - case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: return TileShape{64, 128}; - case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: return TileShape{128, 64}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + return TileShape{128, 64}; case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: - case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128}; - case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256}; - case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileShape{256, 128}; - case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: return TileShape{16, 256}; - default: throw("[get_grid_shape_for_config] Invalid config"); - } + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + return TileShape{256, 128}; + case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + return TileShape{16, 256}; + default: + throw("[get_grid_shape_for_config] Invalid config"); + } } -bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape, - int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only) -{ +bool is_valid_split_k_factor(int64_t const m, + int64_t const n, + int64_t const k, + TileShape const tile_shape, + int const split_k_factor, + size_t const workspace_bytes, + bool const is_weight_only) { + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 128; - // All tile sizes have a k_tile of 64. - static constexpr int k_tile = 128; - - // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k - if (is_weight_only) - { - if ((k % k_tile) != 0) - { - return false; - } - - if ((k % split_k_factor) != 0) - { - return false; - } - - int const k_elements_per_split = k / split_k_factor; - if ((k_elements_per_split % k_tile) != 0) - { - return false; - } + // For weight-only quant, we need k and k_elements_per_split to be a multiple + // of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; } - // Check that the workspace has sufficient space for this split-k factor - int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; - int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; - int const required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; - - if (required_ws_bytes > workspace_bytes) - { - return false; + if ((k % split_k_factor) != 0) { + return false; } - return true; + int const k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + int const required_ws_bytes = + split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; } std::vector get_candidate_tiles( - int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) -{ - enum class CutlassGemmType : char - { - Default, - WeightOnly, - Simt, - Int8, - Fp8 - }; + int const sm, + CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + enum class CutlassGemmType : char { Default, WeightOnly, Simt, Int8, Fp8 }; - CutlassGemmType gemm_type = CutlassGemmType::Default; - if (config_type_param & CutlassGemmConfig::SIMT_ONLY) - { - gemm_type = CutlassGemmType::Simt; - } - else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) - { - gemm_type = CutlassGemmType::WeightOnly; - } - else if (config_type_param & CutlassGemmConfig::INT8_ONLY) - { - gemm_type = CutlassGemmType::Int8; - } - else if (config_type_param & CutlassGemmConfig::FP8_ONLY) - { - gemm_type = CutlassGemmType::Fp8; - } + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (config_type_param & CutlassGemmConfig::SIMT_ONLY) { + gemm_type = CutlassGemmType::Simt; + } else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + gemm_type = CutlassGemmType::WeightOnly; + } else if (config_type_param & CutlassGemmConfig::INT8_ONLY) { + gemm_type = CutlassGemmType::Int8; + } else if (config_type_param & CutlassGemmConfig::FP8_ONLY) { + gemm_type = CutlassGemmType::Fp8; + } - std::vector base_configs{ - CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; - if (sm >= 75) - { - base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); - } + std::vector base_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; + if (sm >= 75) { + base_configs.push_back( + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); + } - switch (gemm_type) - { - case CutlassGemmType::Simt: return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + switch (gemm_type) { + case CutlassGemmType::Simt: + return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; case CutlassGemmType::WeightOnly: - if (sm >= 75) - { - return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + if (sm >= 75) { + return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; - } - else - { - return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; - } - case CutlassGemmType::Int8: + } else { return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, - CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, - CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; + } + case CutlassGemmType::Int8: + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; case CutlassGemmType::Fp8: - if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) - { - if (sm == 89) - { - return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, - CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, - CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, - CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; - } - else - { - // no valid ampere style fp8 configs for sm90 - return {}; - } + if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) { + if (sm == 89) { + return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + } else { + // no valid ampere style fp8 configs for sm90 + return {}; } - default: return base_configs; - } + } + default: + return base_configs; + } } std::vector get_candidate_tiles_sm90( - int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config) -{ + int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config) { #ifdef FAST_BUILD - // Fast build disables all configs except this one for SM90 - return {CutlassTileConfigSM90::CtaShape128x128x128B}; + // Fast build disables all configs except this one for SM90 + return {CutlassTileConfigSM90::CtaShape128x128x128B}; #else - if (config & CutlassGemmConfig::GROUPED_GEMM) - { - return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, - CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B, - CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; - } - else - { - return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B, - CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B, - CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B, - CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, - CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B}; - } + if (config & CutlassGemmConfig::GROUPED_GEMM) { + return {CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + } else { + return {CutlassTileConfigSM90::CtaShape64x16x128B, + CutlassTileConfigSM90::CtaShape64x32x128B, + CutlassTileConfigSM90::CtaShape64x64x128B, + CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, + CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B}; + } #endif } -// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve -// compilation speed. -bool supports_mcast_along_m(CutlassTileConfigSM90 const tile) -{ +// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= +// 128. This is purely to improve compilation speed. +bool supports_mcast_along_m(CutlassTileConfigSM90 const tile) { #ifdef FAST_BUILD - return false; + return false; #else - std::set valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B, - CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, - CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B, - CutlassTileConfigSM90::CtaShape256x128x128B}; - return valid_tiles.count(tile) == 1; + std::set valid_tiles{ + CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; #endif } -// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve -// compilation speed. -bool supports_mcast_along_n(CutlassTileConfigSM90 const tile) -{ +// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= +// 128. This is purely to improve compilation speed. +bool supports_mcast_along_n(CutlassTileConfigSM90 const tile) { #ifdef FAST_BUILD - return false; + return false; #else - std::set valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B, - CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B, - CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; - return valid_tiles.count(tile) == 1; + std::set valid_tiles{ + CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; #endif } // SM100 (Blackwell) candidate tile configurations std::vector get_candidate_tiles_sm100( - int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config) -{ + int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config) { #ifdef FAST_BUILD - return {CutlassTileConfigSM100::CtaShape128x128x128B}; + return {CutlassTileConfigSM100::CtaShape128x128x128B}; #else - /* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) */ - if (config & CutlassGemmConfig::GROUPED_GEMM) + /* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) + */ + if (config & CutlassGemmConfig::GROUPED_GEMM) { + if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4 { - if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4 - { - return { - /* 1 SM (M=128) */ - CutlassTileConfigSM100::CtaShape128x128x128B, - CutlassTileConfigSM100::CtaShape128x256x128B, - /* 2 SM (M=256) */ - CutlassTileConfigSM100::CtaShape256x128x128B, - CutlassTileConfigSM100::CtaShape256x256x128B, - /* slim tiles for very tall matrices */ - CutlassTileConfigSM100::CtaShape128x64x128B, - CutlassTileConfigSM100::CtaShape256x64x128B}; - } + return {/* 1 SM (M=128) */ + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B, + /* 2 SM (M=256) */ + CutlassTileConfigSM100::CtaShape256x128x128B, + CutlassTileConfigSM100::CtaShape256x256x128B, + /* slim tiles for very tall matrices */ + CutlassTileConfigSM100::CtaShape128x64x128B, + CutlassTileConfigSM100::CtaShape256x64x128B}; + } - /* Fp8 / Fp16 grouped-GEMM */ - return { - CutlassTileConfigSM100::CtaShape128x128x128B, + /* Fp8 / Fp16 grouped-GEMM */ + return {CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B, CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B}; - } + } - /* Non-grouped path (plain GEMM or weight-only) */ - return { - /* 1 SM tiles */ - CutlassTileConfigSM100::CtaShape64x64x128B, - CutlassTileConfigSM100::CtaShape64x128x128B, - CutlassTileConfigSM100::CtaShape64x256x128B, - CutlassTileConfigSM100::CtaShape128x64x128B, - CutlassTileConfigSM100::CtaShape128x128x128B, - CutlassTileConfigSM100::CtaShape128x256x128B, - /* 2 SM tiles */ - CutlassTileConfigSM100::CtaShape256x64x128B, - CutlassTileConfigSM100::CtaShape256x128x128B, - CutlassTileConfigSM100::CtaShape256x256x128B}; + /* Non-grouped path (plain GEMM or weight-only) */ + return {/* 1 SM tiles */ + CutlassTileConfigSM100::CtaShape64x64x128B, + CutlassTileConfigSM100::CtaShape64x128x128B, + CutlassTileConfigSM100::CtaShape64x256x128B, + CutlassTileConfigSM100::CtaShape128x64x128B, + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B, + /* 2 SM tiles */ + CutlassTileConfigSM100::CtaShape256x64x128B, + CutlassTileConfigSM100::CtaShape256x128x128B, + CutlassTileConfigSM100::CtaShape256x256x128B}; #endif } // M-multicast support for SM100. -bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile) -{ +bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile) { #ifdef FAST_BUILD - return false; + return false; #else - std::set m_tiles{ - CutlassTileConfigSM100::CtaShape128x64x128B, - CutlassTileConfigSM100::CtaShape128x128x128B, - CutlassTileConfigSM100::CtaShape128x256x128B, - CutlassTileConfigSM100::CtaShape256x64x128B, - CutlassTileConfigSM100::CtaShape256x128x128B, - CutlassTileConfigSM100::CtaShape256x256x128B}; - return m_tiles.count(tile) == 1; + std::set m_tiles{ + CutlassTileConfigSM100::CtaShape128x64x128B, + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B, + CutlassTileConfigSM100::CtaShape256x64x128B, + CutlassTileConfigSM100::CtaShape256x128x128B, + CutlassTileConfigSM100::CtaShape256x256x128B}; + return m_tiles.count(tile) == 1; #endif } // N-multicast support for SM100. -bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile) -{ +bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile) { #ifdef FAST_BUILD - return false; + return false; #else - std::set n_tiles{ - CutlassTileConfigSM100::CtaShape64x128x128B, - CutlassTileConfigSM100::CtaShape64x256x128B, - CutlassTileConfigSM100::CtaShape128x128x128B, - CutlassTileConfigSM100::CtaShape128x256x128B, - CutlassTileConfigSM100::CtaShape256x128x128B}; - return n_tiles.count(tile) == 1; + std::set n_tiles{ + CutlassTileConfigSM100::CtaShape64x128x128B, + CutlassTileConfigSM100::CtaShape64x256x128B, + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B, + CutlassTileConfigSM100::CtaShape256x128x128B}; + return n_tiles.count(tile) == 1; #endif } - std::vector get_candidate_configs( - int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) -{ - if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) - { - std::vector tiles = get_candidate_tiles_sm90(sm, config_type_param); + int sm, + int const max_split_k, + CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) { + std::vector tiles = + get_candidate_tiles_sm90(sm, config_type_param); - std::vector candidate_configs; - for (auto const& tile_config : tiles) - { - CutlassGemmConfig config( - tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); - candidate_configs.push_back(config); + std::vector candidate_configs; + for (auto const& tile_config : tiles) { + CutlassGemmConfig config(tile_config, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(config); - bool const has_m_mcast = supports_mcast_along_m(tile_config); - bool const has_n_mcast = supports_mcast_along_n(tile_config); - if (has_m_mcast) - { - CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_2x1x1); - candidate_configs.push_back(config); - } + bool const has_m_mcast = supports_mcast_along_m(tile_config); + bool const has_n_mcast = supports_mcast_along_n(tile_config); + if (has_m_mcast) { + CutlassGemmConfig config(tile_config, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(config); + } - if (has_n_mcast) - { - CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_1x2x1); - candidate_configs.push_back(config); - } + if (has_n_mcast) { + CutlassGemmConfig config(tile_config, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(config); + } - if (has_m_mcast && has_n_mcast) - { - CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_2x2x1); - candidate_configs.push_back(config); - } - } - return candidate_configs; + if (has_m_mcast && has_n_mcast) { + CutlassGemmConfig config(tile_config, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(config); + } } - else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell - { - std::vector tiles = get_candidate_tiles_sm100(sm, config_type_param); - std::vector candidate_configs; - - for (auto const& tile_config_sm100 : tiles) - { - // SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90. - // Cluster shapes are also handled similarly. - CutlassGemmConfig config( - tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); - candidate_configs.push_back(config); - - bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100); - bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100); - - if (has_m_mcast) - { - CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_2x1x1); - candidate_configs.push_back(mcast_m_config); - } - - if (has_n_mcast) - { - CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_1x2x1); - candidate_configs.push_back(mcast_n_config); - } - - if (has_m_mcast && has_n_mcast) - { - CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_2x2x1); - candidate_configs.push_back(mcast_mn_config); - } - } - return candidate_configs; - } - - // Fallback to older architecture configurations - std::vector tiles = get_candidate_tiles(sm, config_type_param); - std::vector candidate_configs; //Already declared above for SM90 path, ensure scope is correct or redeclare if necessary. - // It's fine here as it's within an else if / else block. - bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY; - int const min_stages = int8_configs_only ? 3 : 2; - int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); - for (auto const& tile_config : tiles) - { - for (int stages = min_stages; stages <= max_stages; ++stages) - { - CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); - candidate_configs.push_back(config); - if (sm >= 75) - { - for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) - { - auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; - candidate_configs.push_back(config); - } - } - } - } - return candidate_configs; + } else if (sm == 100 && + (config_type_param & + CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell + { + std::vector tiles = + get_candidate_tiles_sm100(sm, config_type_param); + std::vector candidate_configs; + + for (auto const& tile_config_sm100 : tiles) { + // SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO + // similar to SM90. Cluster shapes are also handled similarly. + CutlassGemmConfig config(tile_config_sm100, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(config); + + bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100); + bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100); + + if (has_m_mcast) { + CutlassGemmConfig mcast_m_config(tile_config_sm100, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(mcast_m_config); + } + + if (has_n_mcast) { + CutlassGemmConfig mcast_n_config(tile_config_sm100, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(mcast_n_config); + } + + if (has_m_mcast && has_n_mcast) { + CutlassGemmConfig mcast_mn_config(tile_config_sm100, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(mcast_mn_config); + } + } + return candidate_configs; + } + + // Fallback to older architecture configurations + std::vector tiles = + get_candidate_tiles(sm, config_type_param); + std::vector + candidate_configs; // Already declared above for SM90 path, ensure scope + // is correct or redeclare if necessary. + // It's fine here as it's within an else if / else + // block. + bool const int8_configs_only = + config_type_param & CutlassGemmConfig::INT8_ONLY; + int const min_stages = int8_configs_only ? 3 : 2; + int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); + for (auto const& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); + candidate_configs.push_back(config); + if (sm >= 75) { + for (int split_k_factor = 2; split_k_factor <= max_split_k; + ++split_k_factor) { + auto config = CutlassGemmConfig{ + tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; + candidate_configs.push_back(config); + } + } + } + } + + return candidate_configs; } -CutlassGemmConfig estimate_best_config_from_occupancies(std::vector const& candidate_configs, - std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts, - int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only) -{ +CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, + std::vector const& occupancies, + int64_t const m, + int64_t const n, + int64_t const k, + int64_t const num_experts, + int const split_k_limit, + size_t const workspace_bytes, + int const multi_processor_count, + int const is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + throw( + "[estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } - if (occupancies.size() != candidate_configs.size()) - { - throw( - "[estimate_best_config_from_occupancies] occpancies and " - "candidate configs vectors must have equal length."); + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = + get_cta_shape_for_config(candidate_config.tile_config); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; } - CutlassGemmConfig best_config; - // Score will be [0, 1]. The objective is to minimize this score. - // It represents the fraction of SM resources unused in the last wave. - float config_score = 1.0f; - int config_waves = INT_MAX; - int current_m_tile = 0; - - int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; - for (int ii = 0; ii < candidate_configs.size(); ++ii) - { - CutlassGemmConfig candidate_config = candidate_configs[ii]; - TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); - int occupancy = occupancies[ii]; - - if (occupancy == 0) - { - continue; - } - - // Keep small tile sizes when possible. - if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile - && current_m_tile < tile_shape.m) - { - continue; - } - - int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; - int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; - - for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) - { - if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) - { - int const ctas_per_wave = occupancy * multi_processor_count; - int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; - - int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; - float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave); - float const current_score = float(num_waves_total) - num_waves_fractional; - - float const score_slack = 0.1f; - if (current_score < config_score - || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) - { - config_score = current_score; - config_waves = num_waves_total; - SplitKStyle split_style - = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; - best_config = CutlassGemmConfig( - candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages); - current_m_tile = tile_shape.m; - } - else if (current_score == config_score - && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor - || current_m_tile < tile_shape.m)) - { - // Prefer deeper pipeline or smaller split-k - SplitKStyle split_style - = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; - best_config = CutlassGemmConfig( - candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages); - current_m_tile = tile_shape.m; - config_waves = num_waves_total; - } - } - } + // Keep small tile sizes when possible. + if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && + m < current_m_tile && current_m_tile < tile_shape.m) { + continue; } - if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) - { - throw("Heurisitc failed to find a valid config."); - } + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; - return best_config; + for (int split_k_factor = 1; split_k_factor <= max_split_k; + ++split_k_factor) { + if (is_valid_split_k_factor(m, + n, + k, + tile_shape, + split_k_factor, + workspace_bytes, + is_weight_only)) { + int const ctas_per_wave = occupancy * multi_processor_count; + int const ctas_for_problem = + ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + int const num_waves_total = + (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + float const num_waves_fractional = + ctas_for_problem / float(ctas_per_wave); + float const current_score = + float(num_waves_total) - num_waves_fractional; + + float const score_slack = 0.1f; + if (current_score < config_score || + ((config_waves > num_waves_total) && + (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = split_k_factor > 1 + ? SplitKStyle::SPLIT_K_SERIAL + : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig(candidate_config.tile_config, + split_style, + split_k_factor, + candidate_config.stages); + current_m_tile = tile_shape.m; + } else if (current_score == config_score && + (best_config.stages < candidate_config.stages || + split_k_factor < best_config.split_k_factor || + current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = split_k_factor > 1 + ? SplitKStyle::SPLIT_K_SERIAL + : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig(candidate_config.tile_config, + split_style, + split_k_factor, + candidate_config.stages); + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) { + throw("Heurisitc failed to find a valid config."); + } + + return best_config; } -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.h b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.h index 8165bc421c..b6839be7f3 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.h +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.h @@ -20,36 +20,47 @@ #include "cutlass_extensions/gemm_configs.h" #include "common/cudaUtils.h" -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { template -struct should_filter_sm90_gemm_problem_shape -{ +struct should_filter_sm90_gemm_problem_shape { #ifdef FAST_BUILD - constexpr static int TILE_K = 128 * 8 / cutlass::sizeof_bits::value; - using SupportedCtaShape = cute::Shape>; - using SupportedCgaShape = cute::Shape; + constexpr static int TILE_K = + 128 * 8 / cutlass::sizeof_bits::value; + using SupportedCtaShape = + cute::Shape>; + using SupportedCgaShape = cute::Shape; - constexpr static bool value - = !cute::is_same_v || !cute::is_same_v; + constexpr static bool value = + !cute::is_same_v || + !cute::is_same_v; #else - constexpr static bool value = false; + constexpr static bool value = false; #endif }; template -constexpr static bool should_filter_sm90_gemm_problem_shape_v - = should_filter_sm90_gemm_problem_shape::value; +constexpr static bool should_filter_sm90_gemm_problem_shape_v = + should_filter_sm90_gemm_problem_shape::value; std::vector get_candidate_configs( - int sm, int const max_split_k, cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const); + int sm, + int const max_split_k, + cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const); cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies( std::vector const& candidate_configs, - std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts, - int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only); + std::vector const& occupancies, + int64_t const m, + int64_t const n, + int64_t const k, + int64_t const num_experts, + int const split_k_limit, + size_t const workspace_bytes, + int const multi_processor_count, + int const is_weight_only); -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.cu b/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.cu index 5c62b58085..41c0412cd9 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.cu @@ -19,752 +19,781 @@ #include "cutlass_kernels/cutlass_preprocessors.h" #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { -struct LayoutDetails -{ - enum class Layout - { - UNKNOWN, - ROW_MAJOR, - COLUMN_MAJOR - }; +struct LayoutDetails { + enum class Layout { UNKNOWN, ROW_MAJOR, COLUMN_MAJOR }; - Layout layoutB = Layout::UNKNOWN; - int rows_per_column_tile = 1; - int columns_interleaved = 1; + Layout layoutB = Layout::UNKNOWN; + int rows_per_column_tile = 1; + int columns_interleaved = 1; - bool uses_imma_ldsm = false; + bool uses_imma_ldsm = false; }; template -struct getLayoutDetails -{ +struct getLayoutDetails {}; + +template <> +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; + return layout_details; + } }; template <> -struct getLayoutDetails -{ - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; - return layout_details; - } -}; - -template <> -struct getLayoutDetails -{ - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; - return layout_details; - } +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + return layout_details; + } }; template -struct getLayoutDetails> -{ - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; - layout_details.rows_per_column_tile = RowsPerTile; - layout_details.columns_interleaved = ColumnsInterleaved; - return layout_details; - } +struct getLayoutDetails< + cutlass::layout::ColumnMajorTileInterleave> { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + layout_details.rows_per_column_tile = RowsPerTile; + layout_details.columns_interleaved = ColumnsInterleaved; + return layout_details; + } }; template -LayoutDetails getLayoutDetailsForArchAndQuantType() -{ - - using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; - using LayoutB = typename CompileTraits::Layout; - using MmaOperator = typename CompileTraits::Operator; - LayoutDetails details = getLayoutDetails()(); - details.uses_imma_ldsm = std::is_same::value; - return details; +LayoutDetails getLayoutDetailsForArchAndQuantType() { + using CompileTraits = + cutlass::gemm::kernel::LayoutDetailsB; + using LayoutB = typename CompileTraits::Layout; + using MmaOperator = typename CompileTraits::Operator; + LayoutDetails details = getLayoutDetails()(); + details.uses_imma_ldsm = std::is_same< + MmaOperator, + cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>::value; + return details; } template -LayoutDetails getLayoutDetailsForArch(QuantType quant_type) -{ - int const bits_per_weight_element = get_weight_quant_bits(quant_type); - LayoutDetails details; - switch (quant_type) - { +LayoutDetails getLayoutDetailsForArch(QuantType quant_type) { + int const bits_per_weight_element = get_weight_quant_bits(quant_type); + LayoutDetails details; + switch (quant_type) { case QuantType::W8_A16: - details = getLayoutDetailsForArchAndQuantType(); - break; + details = getLayoutDetailsForArchAndQuantType(); + break; case QuantType::W4_A16: - details = getLayoutDetailsForArchAndQuantType(); - break; + details = getLayoutDetailsForArchAndQuantType(); + break; case QuantType::W4_AFP8: - details = getLayoutDetailsForArchAndQuantType(); - break; - default: PADDLE_THROW("Unsupported quantization type"); - } - return details; + details = getLayoutDetailsForArchAndQuantType(); + break; + default: + PADDLE_THROW("Unsupported quantization type"); + } + return details; } -LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) -{ - if (arch >= 70 && arch < 75) - { - return getLayoutDetailsForArch(quant_type); - } - else if (arch >= 75 && arch < 80) - { - return getLayoutDetailsForArch(quant_type); - } - else if (arch >= 80 && arch < 90) - { - return getLayoutDetailsForArch(quant_type); - } - else if (arch == 90) - { - return getLayoutDetailsForArch(quant_type); - } - else - { - PADDLE_ENFORCE(false, "Unsupported Arch"); - return LayoutDetails(); - } +LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) { + if (arch >= 70 && arch < 75) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 75 && arch < 80) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 80 && arch < 90) { + return getLayoutDetailsForArch(quant_type); + } else if (arch == 90) { + return getLayoutDetailsForArch(quant_type); + } else { + PADDLE_ENFORCE(false, "Unsupported Arch"); + return LayoutDetails(); + } } -// Permutes the rows of B in a way that is compatible with Turing+ architectures. +// Permutes the rows of B in a way that is compatible with Turing+ +// architectures. // // Throws an error for other architectures. // The data is permuted such that: // For W8_A16, each group of 16 rows is permuted using the map below: // 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 // For W4_A16, each group of 32 rows is permuted using the map below: -// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 +// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 +// 23 30 31 // For W4_A8, see the map in the code. The idea is similar to above. -// The goal of this permutation is to ensure data ends up in the correct threads after -// we execute LDSM. It counteracts the effect of the data being of different widths. -// For more information about the expected layouts, see the MMA section in the PTX docs. -std::vector get_permutation_map(QuantType quant_type) -{ - - if (quant_type == QuantType::W8_A16) - { - return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; - } - else if (quant_type == QuantType::W4_A16) - { - return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, - 22, 23, 30, 31}; - } - else if (quant_type == QuantType::W4_AFP8) - { - return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, - 28, 29, 30, 31}; - } - else - { - PADDLE_THROW("Invalid quantization type for LDSM permutation"); - } +// The goal of this permutation is to ensure data ends up in the correct threads +// after we execute LDSM. It counteracts the effect of the data being of +// different widths. For more information about the expected layouts, see the +// MMA section in the PTX docs. +std::vector get_permutation_map(QuantType quant_type) { + if (quant_type == QuantType::W8_A16) { + return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + } else if (quant_type == QuantType::W4_A16) { + return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, + 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, 22, 23, 30, 31}; + } else if (quant_type == QuantType::W4_AFP8) { + return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, + 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + } else { + PADDLE_THROW("Invalid quantization type for LDSM permutation"); + } } -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type, int64_t const arch_version) -{ - // We only want to run this step for weight only quant. - std::vector row_permutation = get_permutation_map(quant_type); +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type, + int64_t const arch_version) { + // We only want to run this step for weight only quant. + std::vector row_permutation = get_permutation_map(quant_type); - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - int const BITS_PER_ELT = get_weight_quant_bits(quant_type); - int const K = 16 / BITS_PER_ELT; - int const ELTS_PER_BYTE = 8 / BITS_PER_ELT; - int const ELTS_PER_REG = 32 / BITS_PER_ELT; + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const K = 16 / BITS_PER_ELT; + int const ELTS_PER_BYTE = 8 / BITS_PER_ELT; + int const ELTS_PER_REG = 32 / BITS_PER_ELT; - uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); + uint32_t const* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; - int const elts_in_int32 = 32 / BITS_PER_ELT; + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + int const elts_in_int32 = 32 / BITS_PER_ELT; - int const num_vec_cols = num_cols / elts_in_int32; + int const num_vec_cols = num_cols / elts_in_int32; - PADDLE_ENFORCE( - arch_version >= 75, "Unsupported Arch. Pre-volta not supported. Column interleave not needed on Volta."); + PADDLE_ENFORCE(arch_version >= 75, + "Unsupported Arch. Pre-volta not supported. Column interleave " + "not needed on Volta."); - PADDLE_ENFORCE(num_rows % B_ROWS_PER_MMA == 0, - "Invalid shape for quantized tensor. Number of rows of quantized matrix must be a multiple of %d", - B_ROWS_PER_MMA); - PADDLE_ENFORCE(num_cols % MMA_SHAPE_N == 0, - "Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of %d.", - MMA_SHAPE_N); + PADDLE_ENFORCE(num_rows % B_ROWS_PER_MMA == 0, + "Invalid shape for quantized tensor. Number of rows of " + "quantized matrix must be a multiple of %d", + B_ROWS_PER_MMA); + PADDLE_ENFORCE(num_cols % MMA_SHAPE_N == 0, + "Invalid shape for quantized tensor. On turing/Ampere, the " + "number of cols must be a multiple of %d.", + MMA_SHAPE_N); - PADDLE_ENFORCE(size_t(B_ROWS_PER_MMA) == row_permutation.size(), "Unexpected number of LDSM rows permuted."); + PADDLE_ENFORCE(size_t(B_ROWS_PER_MMA) == row_permutation.size(), + "Unexpected number of LDSM rows permuted."); - for (int expert = 0; expert < num_experts; ++expert) - { - const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols); - for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) - { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) - { + for (int expert = 0; expert < num_experts; ++expert) { + const int64_t matrix_offset = + expert * int64_t(num_rows) * int64_t(num_vec_cols); + for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + int const write_row = base_row + tile_row; + int const tile_read_row = row_permutation[tile_row]; + int const read_row = base_row + tile_read_row; + int const read_col = write_col; - for (int write_col = 0; write_col < num_vec_cols; ++write_col) - { - int const write_row = base_row + tile_row; - int const tile_read_row = row_permutation[tile_row]; - int const read_row = base_row + tile_read_row; - int const read_col = write_col; + const int64_t read_offset = + matrix_offset + int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = + matrix_offset + int64_t(write_row) * num_vec_cols + write_col; - const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; - const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; - - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } - } + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; } + } } + } } // We need to use this transpose to correctly handle packed int4 and int8 data -// The reason this code is relatively complex is that the "trivial" loops took a substantial -// amount of time to transpose leading to long preprocessing times. This seemed to be a big -// issue for relatively large models. +// The reason this code is relatively complex is that the "trivial" loops took a +// substantial amount of time to transpose leading to long preprocessing times. +// This seemed to be a big issue for relatively large models. template -void subbyte_transpose_impl( - int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector const& shape) -{ - constexpr int bits_per_elt = get_weight_quant_bits(quant_type); +void subbyte_transpose_impl(int8_t* transposed_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape) { + constexpr int bits_per_elt = get_weight_quant_bits(quant_type); - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - const size_t col_bytes = num_cols * bits_per_elt / 8; - const size_t col_bytes_trans = num_rows * bits_per_elt / 8; - const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; - uint8_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); + uint8_t const* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = + reinterpret_cast(transposed_quantized_tensor); - static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt; + static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt; - static constexpr int M_TILE_L1 = 64; - static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; - uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; - static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); - // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples - // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it - // allows GCC to emit vector instructions. - PADDLE_ENFORCE(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), - "Number of bytes for rows and cols must be a multiple of %d. However, num_rows_bytes = %ld and " - "num_col_bytes = %ld.", - VECTOR_WIDTH, col_bytes_trans, col_bytes); + // We assume the dims are a multiple of vector width. Our kernels only handle + // dims which are multiples of 64 for weight-only quantization. As a result, + // this seemed like a reasonable tradeoff because it allows GCC to emit vector + // instructions. + PADDLE_ENFORCE( + !(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), + "Number of bytes for rows and cols must be a multiple of %d. However, " + "num_rows_bytes = %ld and " + "num_col_bytes = %ld.", + VECTOR_WIDTH, + col_bytes_trans, + col_bytes); - int const num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; - int const num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; + int const num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; + int const num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; - for (size_t expert = 0; expert < num_experts; ++expert) - { - const size_t matrix_offset = expert * num_rows * col_bytes; - for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) - { - for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) - { + for (size_t expert = 0; expert < num_experts; ++expert) { + const size_t matrix_offset = expert * num_rows * col_bytes; + for (size_t row_tile_start = 0; row_tile_start < num_rows; + row_tile_start += M_TILE_L1) { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; + col_tile_start_byte += N_TILE_L1) { + int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + int const col_limit = + std::min(col_tile_start_byte + N_TILE_L1, col_bytes); - int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); - int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start + ii; - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - int const row = row_tile_start + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte + jj; - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) - { - int const col = col_tile_start_byte + jj; + const size_t logical_src_offset = + matrix_offset + row * col_bytes + col; - const size_t logical_src_offset = matrix_offset + row * col_bytes + col; - - if (row < row_limit && col < col_limit) - { - for (int v = 0; v < VECTOR_WIDTH; ++v) - { - cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; - } - } - } - } - - if constexpr (bits_per_elt == 8) - { - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - for (int jj = ii + 1; jj < N_TILE_L1; ++jj) - { - std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); - } - } - } - else if constexpr (bits_per_elt == 4) - { - - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - // Using M_TILE_L1 here is deliberate since we assume that the cache tile - // is square in the number of elements (not necessarily the number of bytes). - for (int jj = ii + 1; jj < M_TILE_L1; ++jj) - { - int const ii_byte = ii / ELTS_PER_BYTE; - int const ii_bit_offset = ii % ELTS_PER_BYTE; - - int const jj_byte = jj / ELTS_PER_BYTE; - int const jj_bit_offset = jj % ELTS_PER_BYTE; - - uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); - uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); - } - } - } - else - { - PADDLE_ENFORCE(false, "Unsupported quantization type."); - } - - const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; - const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; - - int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); - int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); - - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - int const row = row_tile_start_trans + ii; - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) - { - int const col = col_tile_start_byte_trans + jj; - - const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; - - if (row < row_limit_trans && col < col_limit_trans) - { - for (int v = 0; v < VECTOR_WIDTH; ++v) - { - output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; - } - } - } - } + if (row < row_limit && col < col_limit) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } } + } } - } -} -void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type) -{ - if (quant_type == QuantType::W8_A16) - { - subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); - } - else if (quant_type == QuantType::W4_A16) - { - subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); - } - else if (quant_type == QuantType::W4_AFP8) - { - subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); - } - else - { - PADDLE_ENFORCE(false, "Invalid quant_type"); - } -} - -void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num_elts) -{ - for (int ii = 0; ii < num_elts; ++ii) - { - int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); - } - - // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no - // performance benefit and is purely so that int4 and int8 have the same layout. - // Pictorially, this does the following: - // bit 32 0 - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - - PADDLE_ENFORCE(num_elts % 4 == 0, "Dimensions of int8 tensor must be a multiple of 4 for register relayout"); - for (size_t base = 0; base < num_elts; base += 4) - { - std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); - } -} - -void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) -{ - int const num_bytes = num_elts / 2; - - // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little - // instructions as possible in the CUDA code. - for (size_t ii = 0; ii < num_bytes; ++ii) - { - int8_t transformed_packed_int4s = 0; - int8_t transformed_first_elt - = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8; // The double shift here is to ensure sign extension - int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; - - PADDLE_ENFORCE( - transformed_first_elt >= 0 && transformed_first_elt <= 15, "Illegal result for int4 transform (first elt)"); - PADDLE_ENFORCE(transformed_second_elt >= 0 && transformed_second_elt <= 15, - "Illegal result for int4 transform (second elt)"); - - // We don't need to mask in these ops since everything should be in the range 0-15 - transformed_packed_int4s |= transformed_first_elt; - transformed_packed_int4s |= (transformed_second_elt << 4); - packed_int4_tensor[ii] = transformed_packed_int4s; - } - - // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical - // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the - // following: Take as input a 32 bit register with layout: bit 32 0 - // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) - - PADDLE_ENFORCE(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); - const size_t num_registers = num_bytes / 4; - - uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); - for (size_t ii = 0; ii < num_registers; ++ii) - { - const uint32_t current_register = register_ptr[ii]; - uint32_t transformed_register = 0; - - for (int dest_idx = 0; dest_idx < 8; ++dest_idx) - { - int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; - int const src_shift = 4 * src_idx; - int const dest_shift = 4 * dest_idx; - - const uint32_t src_bits = (current_register >> src_shift) & 0xF; - transformed_register |= (src_bits << dest_shift); - } - register_ptr[ii] = transformed_register; - } -} - -void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) -{ - if (quant_type == QuantType::W8_A16) - { - add_bias_and_interleave_int8s_inplace(tensor, num_elts); - } - else if (quant_type == QuantType::W4_A16 || quant_type == QuantType::W4_AFP8) - { - // W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must - // be converted to FP16 before the scales can be applied using CUDA cores. - // As a result, we still want permute the data so that it is well aligned - // for conversion to FP16. - add_bias_and_interleave_int4s_inplace(tensor, num_elts); - } - else - { - PADDLE_ENFORCE(false, "Invalid quantization type for interleaving."); - } -} - -void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type, LayoutDetails details) -{ - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - int const BITS_PER_ELT = get_weight_quant_bits(quant_type); - int const elts_in_int32 = 32 / BITS_PER_ELT; - - int const rows_per_tile = details.rows_per_column_tile; - - PADDLE_ENFORCE(!(num_rows % elts_in_int32), - "The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows); - - uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); - - PADDLE_ENFORCE(!(num_rows % rows_per_tile), - "The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows); - - int const num_vec_rows = num_rows / elts_in_int32; - int const vec_rows_per_tile = rows_per_tile / elts_in_int32; - int const interleave = details.columns_interleaved; - - for (int expert = 0; expert < num_experts; ++expert) - { - const int64_t matrix_offset = expert * int64_t(num_vec_rows) * int64_t(num_cols); - for (int read_col = 0; read_col < num_cols; ++read_col) - { - const int64_t write_col = read_col / interleave; - for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) - { - for (int vec_read_row = base_vec_row; - vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); ++vec_read_row) - { - const int64_t vec_write_row = interleave * base_vec_row - + vec_rows_per_tile * (read_col % interleave) + vec_read_row % vec_rows_per_tile; - - const int64_t read_offset = matrix_offset + int64_t(read_col) * num_vec_rows + vec_read_row; - const int64_t write_offset - = matrix_offset + int64_t(write_col) * num_vec_rows * interleave + vec_write_row; - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } + if constexpr (bits_per_elt == 8) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + for (int jj = ii + 1; jj < N_TILE_L1; ++jj) { + std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); } + } + } else if constexpr (bits_per_elt == 4) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + // Using M_TILE_L1 here is deliberate since we assume that the cache + // tile is square in the number of elements (not necessarily the + // number of bytes). + for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { + int const ii_byte = ii / ELTS_PER_BYTE; + int const ii_bit_offset = ii % ELTS_PER_BYTE; + + int const jj_byte = jj / ELTS_PER_BYTE; + int const jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = + 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = + 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + } else { + PADDLE_ENFORCE(false, "Unsupported quantization type."); } + + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + + int const row_limit_trans = + std::min(row_tile_start_trans + M_TILE_L1, num_cols); + int const col_limit_trans = + std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte_trans + jj; + + const size_t logical_tgt_offset = + matrix_offset + row * col_bytes_trans + col; + + if (row < row_limit_trans && col < col_limit_trans) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } + } + } + } + } } + } } -void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, - std::vector const& shape, QuantType quant_type, bool force_interleave) -{ - int arch = 89; - if (force_interleave && arch == 90) - { - // Workaround for MOE which doesn't have specialised Hopper kernels yet - arch = 80; +void subbyte_transpose(int8_t* transposed_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type) { + if (quant_type == QuantType::W8_A16) { + subbyte_transpose_impl( + transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_A16) { + subbyte_transpose_impl( + transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_AFP8) { + subbyte_transpose_impl( + transposed_quantized_tensor, quantized_tensor, shape); + } else { + PADDLE_ENFORCE(false, "Invalid quant_type"); + } +} + +void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, + const size_t num_elts) { + for (int ii = 0; ii < num_elts; ++ii) { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // match the int4 layout. This has no performance benefit and is purely so + // that int4 and int8 have the same layout. Pictorially, this does the + // following: bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + PADDLE_ENFORCE(num_elts % 4 == 0, + "Dimensions of int8 tensor must be a multiple of 4 for " + "register relayout"); + for (size_t base = 0; base < num_elts; base += 4) { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } +} + +void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, + const size_t num_elts) { + int const num_bytes = num_elts / 2; + + // Step 1 will be to transform all the int4s to unsigned in order to make the + // dequantize take as little instructions as possible in the CUDA code. + for (size_t ii = 0; ii < num_bytes; ++ii) { + int8_t transformed_packed_int4s = 0; + int8_t transformed_first_elt = + (int8_t(packed_int4_tensor[ii] << 4) >> 4) + + 8; // The double shift here is to ensure sign extension + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; + + PADDLE_ENFORCE(transformed_first_elt >= 0 && transformed_first_elt <= 15, + "Illegal result for int4 transform (first elt)"); + PADDLE_ENFORCE(transformed_second_elt >= 0 && transformed_second_elt <= 15, + "Illegal result for int4 transform (second elt)"); + + // We don't need to mask in these ops since everything should be in the + // range 0-15 + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // minimize the number of shift & logical instructions That are needed to + // extract the int4s in the GEMM main loop. Pictorially, the loop below will + // do the following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt + // occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt + // occupies 4 bits) + + PADDLE_ENFORCE(num_bytes % 4 == 0, + "Dimensions of int4 tensor must be a multiple of 8 for " + "register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { + int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; + int const src_shift = 4 * src_idx; + int const dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); } - LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); + register_ptr[ii] = transformed_register; + } +} - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, + const size_t num_elts, + QuantType quant_type) { + if (quant_type == QuantType::W8_A16) { + add_bias_and_interleave_int8s_inplace(tensor, num_elts); + } else if (quant_type == QuantType::W4_A16 || + quant_type == QuantType::W4_AFP8) { + // W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must + // be converted to FP16 before the scales can be applied using CUDA cores. + // As a result, we still want permute the data so that it is well aligned + // for conversion to FP16. + add_bias_and_interleave_int4s_inplace(tensor, num_elts); + } else { + PADDLE_ENFORCE(false, "Invalid quantization type for interleaving."); + } +} - size_t num_elts = 1; - for (auto const& dim : shape) - { - num_elts *= dim; +void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type, + LayoutDetails details) { + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const rows_per_tile = details.rows_per_column_tile; + + PADDLE_ENFORCE(!(num_rows % elts_in_int32), + "The number of rows must be a multiple of %d but the number " + "of rows is %ld.", + elts_in_int32, + num_rows); + + uint32_t const* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(interleaved_quantized_tensor); + + PADDLE_ENFORCE(!(num_rows % rows_per_tile), + "The number of rows must be a multiple of %d but the number " + "of rows is %ld.", + rows_per_tile, + num_rows); + + int const num_vec_rows = num_rows / elts_in_int32; + int const vec_rows_per_tile = rows_per_tile / elts_in_int32; + int const interleave = details.columns_interleaved; + + for (int expert = 0; expert < num_experts; ++expert) { + const int64_t matrix_offset = + expert * int64_t(num_vec_rows) * int64_t(num_cols); + for (int read_col = 0; read_col < num_cols; ++read_col) { + const int64_t write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; + base_vec_row += vec_rows_per_tile) { + for (int vec_read_row = base_vec_row; + vec_read_row < + std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); + ++vec_read_row) { + const int64_t vec_write_row = + interleave * base_vec_row + + vec_rows_per_tile * (read_col % interleave) + + vec_read_row % vec_rows_per_tile; + + const int64_t read_offset = + matrix_offset + int64_t(read_col) * num_vec_rows + vec_read_row; + const int64_t write_offset = + matrix_offset + int64_t(write_col) * num_vec_rows * interleave + + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } } + } +} - const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8; +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, + int8_t const* row_major_quantized_weight, + std::vector const& shape, + QuantType quant_type, + bool force_interleave) { + int arch = 89; + if (force_interleave && arch == 90) { + // Workaround for MOE which doesn't have specialised Hopper kernels yet + arch = 80; + } + LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); - std::vector src_buf(num_bytes); - std::vector dst_buf(num_bytes); - std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); - // Works on row major data, so issue this permutation first. - if (details.uses_imma_ldsm) - { - permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch); - src_buf.swap(dst_buf); - } + size_t num_elts = 1; + for (auto const& dim : shape) { + num_elts *= dim; + } - if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) - { - subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); - src_buf.swap(dst_buf); - } + const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8; - if (details.columns_interleaved > 1) - { - interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); - src_buf.swap(dst_buf); - } + std::vector src_buf(num_bytes); + std::vector dst_buf(num_bytes); + std::copy(row_major_quantized_weight, + row_major_quantized_weight + num_bytes, + src_buf.begin()); - if (arch >= 70 && arch < 90) - { - add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); - } - std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); + // Works on row major data, so issue this permutation first. + if (details.uses_imma_ldsm) { + permute_B_rows_for_mixed_gemm( + dst_buf.data(), src_buf.data(), shape, quant_type, arch); + src_buf.swap(dst_buf); + } + + if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) { + subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.columns_interleaved > 1) { + interleave_column_major_tensor( + dst_buf.data(), src_buf.data(), shape, quant_type, details); + src_buf.swap(dst_buf); + } + + if (arch >= 70 && arch < 90) { + add_bias_and_interleave_quantized_tensor_inplace( + src_buf.data(), num_elts, quant_type); + } + std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); } /* Arguments: - input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. + input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D +and of type FP16. quant_type - the type of the output quantization weight. - This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the - zero-point is zero and will automatically construct the scales. + This function does symmetric quantization on 2-D or 3-D tensors. It uses the +full int range and assumes the zero-point is zero and will automatically +construct the scales. - It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is - viewed as a stack of matrices and a scale is produced for each column of every matrix. + It always quantizes the last axis of the tensor. For 3-D tensors, it +operates in "batched" mode where the tensor is viewed as a stack of matrices and +a scale is produced for each column of every matrix. Outputs - processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM - unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. - scale_ptr - scales for the quantized weight. + processed_quantized_weight - quantized AND processed weight for GEMM. This +MUST be used with the CUTLASS GEMM unprocessed_quantized_weight - quantized but +unprocessed weights. Useful for reference checking. scale_ptr - scales for the +quantized weight. - Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data - layout may not make sense if printed. + Note that the returned quantized_weights will be preprocessed in a way to +accelerate the mixed type GEMM. The data layout may not make sense if printed. Shapes: quant_type == int8: - If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] - If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] - quant_type == int4: - If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] - If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape - [b,n] + If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and +scales of shape [n] If weight is a [b,m,n] tensor, unprocessed_quantized_weight +will have shape [b,m,n] and scales of shape [b,n] quant_type == int4: If weight +is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales +of shape [n] If weight is a [b,m,n] tensor, unprocessed_quantized_weight will +have shape [b,m, ceil(n/2)] and scales of shape [b,n] - The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the - reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind - of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors - must have a dimension of 1, which breaks the semantics we need for batched weights. + The quantized_weight will be of type torch.int8 and have two int4 values +packed in a single byte. This is the reason for halving the shape. At the time +of writing this code, there was not an elegant way to handle this kind of +batched quantization using torch's quantized tensors (to the best of the +author's knowledge). Scale tensors must have a dimension of 1, which breaks the +semantics we need for batched weights. */ template -void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, - bool force_interleave) -{ +void symmetric_quantize(int8_t* processed_quantized_weight, + int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, + WeightType const* input_weight_ptr, + std::vector const& shape, + QuantType quant_type, + bool force_interleave) { + PADDLE_ENFORCE(processed_quantized_weight, + "Processed quantized tensor is NULL"); + PADDLE_ENFORCE(scale_ptr, "Scale output pointer is NULL"); + PADDLE_ENFORCE(input_weight_ptr, "Input weight pointer is NULL"); - PADDLE_ENFORCE(processed_quantized_weight, "Processed quantized tensor is NULL"); - PADDLE_ENFORCE(scale_ptr, "Scale output pointer is NULL"); - PADDLE_ENFORCE(input_weight_ptr, "Input weight pointer is NULL"); + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + int const bits_in_type = get_weight_quant_bits(quant_type); + int const bytes_per_out_col = num_cols * bits_in_type / 8; - int const bits_in_type = get_weight_quant_bits(quant_type); - int const bytes_per_out_col = num_cols * bits_in_type / 8; + int const bits_per_weigtht_element = get_weight_quant_bits(quant_type); - int const bits_per_weigtht_element = get_weight_quant_bits(quant_type); + std::vector weight_buf; + if (unprocessed_quantized_weight == nullptr) { + weight_buf.resize(num_experts * num_rows * num_cols); + unprocessed_quantized_weight = weight_buf.data(); + } - std::vector weight_buf; - if (unprocessed_quantized_weight == nullptr) - { - weight_buf.resize(num_experts * num_rows * num_cols); - unprocessed_quantized_weight = weight_buf.data(); + int const input_mat_size = num_rows * num_cols; + int const quantized_mat_size = num_rows * bytes_per_out_col; + float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); + + std::vector per_col_max(num_cols); + + for (int expert = 0; expert < num_experts; ++expert) { + WeightType const* current_weight = + input_weight_ptr + expert * input_mat_size; + int8_t* current_quantized_weight = + unprocessed_quantized_weight + expert * quantized_mat_size; + + // First we find the per column max for this expert weight. + for (int jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = 0.f; } - int const input_mat_size = num_rows * num_cols; - int const quantized_mat_size = num_rows * bytes_per_out_col; - float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); - - std::vector per_col_max(num_cols); - - for (int expert = 0; expert < num_experts; ++expert) - { - WeightType const* current_weight = input_weight_ptr + expert * input_mat_size; - int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; - - // First we find the per column max for this expert weight. - for (int jj = 0; jj < num_cols; ++jj) - { - per_col_max[jj] = 0.f; - } - - for (int ii = 0; ii < num_rows; ++ii) - { - WeightType const* current_weight_row = current_weight + ii * num_cols; - for (int jj = 0; jj < num_cols; ++jj) - { - per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); - } - } - - // Then, we construct the scales - ComputeType* current_scales = scale_ptr + expert * num_cols; - for (int jj = 0; jj < num_cols; ++jj) - { - per_col_max[jj] *= quant_range_scale; - current_scales[jj] = ComputeType(per_col_max[jj]); - } - - // Finally, construct the weights. - for (int ii = 0; ii < num_rows; ++ii) - { - int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; - WeightType const* current_weight_row = current_weight + ii * num_cols; - for (int jj = 0; jj < bytes_per_out_col; ++jj) - { - - if (bits_per_weigtht_element == 8) - { - float const col_scale = per_col_max[jj]; - float const weight_elt = float(current_weight_row[jj]); - float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; - const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); - current_quantized_weight_row[jj] = clipped_weight; - } - else if (bits_per_weigtht_element == 4) - { - - // We will pack two int4 elements per iteration of the inner loop. - int8_t packed_int4s = 0; - for (int packed_idx = 0; packed_idx < 2; ++packed_idx) - { - int const input_idx = 2 * jj + packed_idx; - if (input_idx < num_cols) - { - float const col_scale = per_col_max[input_idx]; - float const weight_elt = float(current_weight_row[input_idx]); - float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; - int int_weight = int(scaled_weight); - const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); - - // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits - // if packing the second int4 and or the bits into the final result. - packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); - } - } - current_quantized_weight_row[jj] = packed_int4s; - } - else - { - PADDLE_ENFORCE(false, "Unsupported quantization type"); - } - } - } + for (int ii = 0; ii < num_rows; ++ii) { + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = + std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); + } } - preprocess_weights_for_mixed_gemm( - processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave); + // Then, we construct the scales + ComputeType* current_scales = scale_ptr + expert * num_cols; + for (int jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] *= quant_range_scale; + current_scales[jj] = ComputeType(per_col_max[jj]); + } + + // Finally, construct the weights. + for (int ii = 0; ii < num_rows; ++ii) { + int8_t* current_quantized_weight_row = + current_quantized_weight + ii * bytes_per_out_col; + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < bytes_per_out_col; ++jj) { + if (bits_per_weigtht_element == 8) { + float const col_scale = per_col_max[jj]; + float const weight_elt = float(current_weight_row[jj]); + float const scaled_weight = + (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + const int8_t clipped_weight = + int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } else if (bits_per_weigtht_element == 4) { + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + int const input_idx = 2 * jj + packed_idx; + if (input_idx < num_cols) { + float const col_scale = per_col_max[input_idx]; + float const weight_elt = float(current_weight_row[input_idx]); + float const scaled_weight = + (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + int int_weight = int(scaled_weight); + const int8_t clipped_weight = + std::max(-8, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to + // upper bits if packing the second int4 and or the bits into the + // final result. + packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + current_quantized_weight_row[jj] = packed_int4s; + } else { + PADDLE_ENFORCE(false, "Unsupported quantization type"); + } + } + } + } + + preprocess_weights_for_mixed_gemm(processed_quantized_weight, + unprocessed_quantized_weight, + shape, + quant_type, + force_interleave); } -template void symmetric_quantize( - int8_t*, int8_t*, half*, float const*, std::vector const&, QuantType, bool); +template void symmetric_quantize(int8_t*, + int8_t*, + half*, + float const*, + std::vector const&, + QuantType, + bool); -template void symmetric_quantize( - int8_t*, int8_t*, half*, half const*, std::vector const&, QuantType, bool); +template void symmetric_quantize(int8_t*, + int8_t*, + half*, + half const*, + std::vector const&, + QuantType, + bool); #ifdef ENABLE_BF16 template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( - int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + int8_t*, + int8_t*, + __nv_bfloat16*, + __nv_bfloat16 const*, + std::vector const&, + QuantType, + bool); template void symmetric_quantize<__nv_bfloat16, float>( - int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + int8_t*, + int8_t*, + __nv_bfloat16*, + float const*, + std::vector const&, + QuantType, + bool); #endif template -void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, - std::vector const& shape, QuantType quant_type, bool force_interleave) -{ - symmetric_quantize( - processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave); +void symmetric_quantize(int8_t* processed_quantized_weight, + ComputeType* scale_ptr, + WeightType const* input_weight_ptr, + std::vector const& shape, + QuantType quant_type, + bool force_interleave) { + symmetric_quantize(processed_quantized_weight, + nullptr, + scale_ptr, + input_weight_ptr, + shape, + quant_type, + force_interleave); } template void symmetric_quantize( @@ -773,21 +802,42 @@ template void symmetric_quantize( template void symmetric_quantize( int8_t*, half*, float const*, std::vector const&, QuantType, bool); -template void symmetric_quantize(int8_t*, half*, half const*, std::vector const&, QuantType, bool); +template void symmetric_quantize( + int8_t*, half*, half const*, std::vector const&, QuantType, bool); #ifdef ENABLE_BF16 template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( - int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + int8_t*, + __nv_bfloat16*, + __nv_bfloat16 const*, + std::vector const&, + QuantType, + bool); template void symmetric_quantize<__nv_bfloat16, half>( - int8_t*, __nv_bfloat16*, half const*, std::vector const&, QuantType, bool); + int8_t*, + __nv_bfloat16*, + half const*, + std::vector const&, + QuantType, + bool); template void symmetric_quantize( - int8_t*, half*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + int8_t*, + half*, + __nv_bfloat16 const*, + std::vector const&, + QuantType, + bool); template void symmetric_quantize<__nv_bfloat16, float>( - int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + int8_t*, + __nv_bfloat16*, + float const*, + std::vector const&, + QuantType, + bool); #endif -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.h b/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.h index 8d025c1289..292d7d8cbe 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.h +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.h @@ -20,52 +20,67 @@ #include #include -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { -enum class QuantType -{ - W8_A16, - W4_A16, - W4_AFP8 -}; +enum class QuantType { W8_A16, W4_A16, W4_AFP8 }; -constexpr int get_weight_quant_bits(QuantType quant_type) -{ - switch (quant_type) - { - case QuantType::W8_A16: return 8; - case QuantType::W4_A16: return 4; - case QuantType::W4_AFP8: return 4; - default: PADDLE_THROW("Invalid quant_type"); return -1; - } +constexpr int get_weight_quant_bits(QuantType quant_type) { + switch (quant_type) { + case QuantType::W8_A16: + return 8; + case QuantType::W4_A16: + return 4; + case QuantType::W4_AFP8: + return 4; + default: + PADDLE_THROW("Invalid quant_type"); + return -1; + } } // Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols] // 3-D shapes are [num_experts, num_rows, num_cols] -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type, const int64_t arch_version); +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type, + const int64_t arch_version); -void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type); +void subbyte_transpose(int8_t* transposed_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type); -void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, + const size_t num_elts, + QuantType quant_type); -void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, - std::vector const& shape, QuantType quant_type, bool force_interleave = false); +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, + int8_t const* row_major_quantized_weight, + std::vector const& shape, + QuantType quant_type, + bool force_interleave = false); template -void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, - std::vector const& shape, QuantType quant_type, bool force_interleave); +void symmetric_quantize(int8_t* processed_quantized_weight, + ComputeType* scale_ptr, + WeightType const* input_weight_ptr, + std::vector const& shape, + QuantType quant_type, + bool force_interleave); -// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight -// to implement a simple reference implementation. +// This is exposed so that we can write tests that use the processed weights for +// CUTLASS but the unprocessed weight to implement a simple reference +// implementation. template -void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, - bool force_interleave); +void symmetric_quantize(int8_t* processed_quantized_weight, + int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, + WeightType const* input_weight_ptr, + std::vector const& shape, + QuantType quant_type, + bool force_interleave); -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_type_conversion.h b/custom_ops/gpu_ops/cutlass_kernels/cutlass_type_conversion.h index cf344772a6..a10f49cc77 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_type_conversion.h +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_type_conversion.h @@ -24,45 +24,38 @@ #include "cutlass/float8.h" #include "cutlass/half.h" -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { /////////////////////////////////////////////////////////////////////////////////////////////////// // Cuda to Cutlass template -struct CudaToCutlassTypeAdapter -{ - using type = T; +struct CudaToCutlassTypeAdapter { + using type = T; }; template <> -struct CudaToCutlassTypeAdapter -{ - using type = cutlass::half_t; +struct CudaToCutlassTypeAdapter { + using type = cutlass::half_t; }; #if defined(ENABLE_BF16) template <> -struct CudaToCutlassTypeAdapter<__nv_bfloat16> -{ - using type = cutlass::bfloat16_t; +struct CudaToCutlassTypeAdapter<__nv_bfloat16> { + using type = cutlass::bfloat16_t; }; #endif #if defined(ENABLE_FP8) template <> -struct CudaToCutlassTypeAdapter<__nv_fp8_e4m3> -{ - using type = cutlass::float_e4m3_t; +struct CudaToCutlassTypeAdapter<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; }; template <> -struct CudaToCutlassTypeAdapter<__nv_fp8_e5m2> -{ - using type = cutlass::float_e5m2_t; +struct CudaToCutlassTypeAdapter<__nv_fp8_e5m2> { + using type = cutlass::float_e5m2_t; }; #endif @@ -70,40 +63,35 @@ struct CudaToCutlassTypeAdapter<__nv_fp8_e5m2> // Cutlass to Cuda template -struct CutlassToCudaTypeAdapter -{ - using type = T; +struct CutlassToCudaTypeAdapter { + using type = T; }; template <> -struct CutlassToCudaTypeAdapter -{ - using type = half; +struct CutlassToCudaTypeAdapter { + using type = half; }; #if defined(ENABLE_BF16) template <> -struct CutlassToCudaTypeAdapter -{ - using type = __nv_bfloat16; +struct CutlassToCudaTypeAdapter { + using type = __nv_bfloat16; }; #endif #if defined(ENABLE_FP8) template <> -struct CutlassToCudaTypeAdapter -{ - using type = __nv_fp8_e4m3; +struct CutlassToCudaTypeAdapter { + using type = __nv_fp8_e4m3; }; template <> -struct CutlassToCudaTypeAdapter -{ - using type = __nv_fp8_e5m2; +struct CutlassToCudaTypeAdapter { + using type = __nv_fp8_e5m2; }; #endif /////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h index 9dfddf83fc..a4f421a01f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h @@ -67,87 +67,87 @@ template < ElementOutput_, ///< Data type used to compute linear combination FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> class LeftGELUAndMul { - public: - using ElementOutput = ElementOutput_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; - static int const kCount = Count; - using FragmentOutput = Array; - using FragmentAccumulator = Array; - using ComputeFragment = Array; + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; - static FloatRoundStyle const kRound = Round; + static FloatRoundStyle const kRound = Round; - struct Params { - ElementCompute alpha; - - CUTLASS_HOST_DEVICE - Params() : alpha(ElementCompute(1)) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute alpha) : alpha(alpha) {} // NOLINT - }; - - private: - // - // Data members - // - - ElementCompute alpha_; - ElementCompute beta_; - - public: - /// Constructs the function object, possibly loading from pointers in host - /// memory - CUTLASS_HOST_DEVICE - LeftGELUAndMul(Params const ¶ms) { alpha_ = params.alpha; } // NOLINT - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const { return true; } - - /// Functionally required for serial reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) { - assert(false); - } - - /// Computes linear scaling: D = alpha * accumulator + beta * source - CUTLASS_HOST_DEVICE - FragmentOutput operator()(FragmentAccumulator const &lhs, - FragmentAccumulator const &rhs) const { - // Convert source to internal compute numeric type - NumericArrayConverter - accumulator_to_compute; - - // Convert to destination numeric type - NumericArrayConverter - compute_to_output; - - ComputeFragment converted_lhs = accumulator_to_compute(lhs); - ComputeFragment converted_rhs = accumulator_to_compute(rhs); - - cutlass::epilogue::thread::GELU_taylor gelu; - cutlass::multiplies mul; - auto gelu_lhs = gelu(converted_lhs); - // return compute_to_output(mul(gelu_lhs, converted_rhs)); - auto tmp = mul(gelu_lhs, converted_rhs); - return compute_to_output(mul(alpha_, tmp)); - } + struct Params { + ElementCompute alpha; CUTLASS_HOST_DEVICE - ElementOutput operator()(ElementAccumulator const &lhs, - ElementAccumulator const &rhs) const { - ElementCompute convert_lhs(lhs); - ElementCompute convert_rhs(rhs); - cutlass::epilogue::thread::GELU_taylor gelu; - cutlass::multiplies mul; - auto gelu_lhs = gelu(convert_lhs); - // return ElementOutput(mul(gelu_lhs, convert_rhs)); - auto tmp = mul(gelu_lhs, convert_rhs); - return compute_to_output(mul(alpha_, tmp)); - } + Params() : alpha(ElementCompute(1)) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha) : alpha(alpha) {} // NOLINT + }; + + private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + LeftGELUAndMul(Params const ¶ms) { alpha_ = params.alpha; } // NOLINT + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + // Convert source to internal compute numeric type + NumericArrayConverter + accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter + compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::GELU_taylor gelu; + cutlass::multiplies mul; + auto gelu_lhs = gelu(converted_lhs); + // return compute_to_output(mul(gelu_lhs, converted_rhs)); + auto tmp = mul(gelu_lhs, converted_rhs); + return compute_to_output(mul(alpha_, tmp)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()(ElementAccumulator const &lhs, + ElementAccumulator const &rhs) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::GELU_taylor gelu; + cutlass::multiplies mul; + auto gelu_lhs = gelu(convert_lhs); + // return ElementOutput(mul(gelu_lhs, convert_rhs)); + auto tmp = mul(gelu_lhs, convert_rhs); + return compute_to_output(mul(alpha_, tmp)); + } }; } // namespace thread diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h index 7a433bccd6..51da87e898 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h @@ -67,87 +67,87 @@ template < ElementOutput_, ///< Data type used to compute linear combination FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> class LeftSiLUAndMul { - public: - using ElementOutput = ElementOutput_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; - static int const kCount = Count; - using FragmentOutput = Array; - using FragmentAccumulator = Array; - using ComputeFragment = Array; + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; - static FloatRoundStyle const kRound = Round; + static FloatRoundStyle const kRound = Round; - struct Params { - ElementCompute alpha; - - CUTLASS_HOST_DEVICE - Params() : alpha(ElementCompute(1)) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute alpha) : alpha(alpha) {} // NOLINT - }; - - private: - // - // Data members - // - - ElementCompute alpha_; - ElementCompute beta_; - - public: - /// Constructs the function object, possibly loading from pointers in host - /// memory - CUTLASS_HOST_DEVICE - LeftSiLUAndMul(Params const ¶ms) { alpha_ = params.alpha; } // NOLINT - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const { return true; } - - /// Functionally required for serial reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) { - assert(false); - } - - /// Computes linear scaling: D = alpha * accumulator + beta * source - CUTLASS_HOST_DEVICE - FragmentOutput operator()(FragmentAccumulator const &lhs, - FragmentAccumulator const &rhs) const { - // Convert source to internal compute numeric type - NumericArrayConverter - accumulator_to_compute; - - // Convert to destination numeric type - NumericArrayConverter - compute_to_output; - - ComputeFragment converted_lhs = accumulator_to_compute(lhs); - ComputeFragment converted_rhs = accumulator_to_compute(rhs); - - cutlass::epilogue::thread::SiLu silu; - cutlass::multiplies mul; - auto silu_lhs = silu(converted_lhs); - // return compute_to_output(mul(silu_lhs, converted_rhs)); - auto tmp = mul(silu_lhs, converted_rhs); - return compute_to_output(mul(alpha_, tmp)); - } + struct Params { + ElementCompute alpha; CUTLASS_HOST_DEVICE - ElementOutput operator()(ElementAccumulator const &lhs, - ElementAccumulator const &rhs) const { - ElementCompute convert_lhs(lhs); - ElementCompute convert_rhs(rhs); - cutlass::epilogue::thread::SiLu silu; - cutlass::multiplies mul; - auto silu_lhs = silu(convert_lhs); - // return ElementOutput(mul(silu_lhs, convert_rhs)); - auto tmp = mul(silu_lhs, convert_rhs); - return ElementOutput(mul(alpha_, tmp)); - } + Params() : alpha(ElementCompute(1)) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha) : alpha(alpha) {} // NOLINT + }; + + private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + LeftSiLUAndMul(Params const ¶ms) { alpha_ = params.alpha; } // NOLINT + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + // Convert source to internal compute numeric type + NumericArrayConverter + accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter + compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(converted_lhs); + // return compute_to_output(mul(silu_lhs, converted_rhs)); + auto tmp = mul(silu_lhs, converted_rhs); + return compute_to_output(mul(alpha_, tmp)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()(ElementAccumulator const &lhs, + ElementAccumulator const &rhs) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(convert_lhs); + // return ElementOutput(mul(silu_lhs, convert_rhs)); + auto tmp = mul(silu_lhs, convert_rhs); + return ElementOutput(mul(alpha_, tmp)); + } }; } // namespace thread diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_epilogue.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_epilogue.h index 7d679c341d..2be6408013 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_epilogue.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_epilogue.h @@ -98,357 +98,357 @@ template ::value)> class DualEpilogue { - public: - using Base = EpilogueBase; + public: + using Base = EpilogueBase; - using Shape = Shape_; - using WarpMmaOperator = WarpMmaOperator_; - static int const kPartitionsK = PartitionsK; - static bool constexpr kStoreD0 = StoreD0; - static bool constexpr kStoreD1 = StoreD1; - using OutputTileIterator = OutputTileIterator_; - using OutputTileIterator2 = OutputTileIterator2_; - using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; - using WarpTileIterator = WarpTileIterator_; - using SharedLoadIterator = SharedLoadIterator_; - using OutputOp0 = OutputOp0_; - using OutputOp1 = OutputOp1_; - using OutputOp2 = OutputOp2_; - using Padding = Padding_; + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + using OutputTileIterator = OutputTileIterator_; + using OutputTileIterator2 = OutputTileIterator2_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp0 = OutputOp0_; + using OutputOp1 = OutputOp1_; + using OutputOp2 = OutputOp2_; + using Padding = Padding_; - using Layout = layout::RowMajor; - using LongIndex = typename Layout::LongIndex; + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; - // The complete warp-level accumulator tile - using AccumulatorTile = typename Base::AccumulatorTile; + // The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; - // Accumulator element - using ElementAccumulator = typename WarpTileIterator::Element; + // Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; - // Output element - using ElementOutput = typename OutputTileIterator::Element; + // Output element + using ElementOutput = typename OutputTileIterator::Element; - // Output access size - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + // Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - // Tensor reference to destination tensor - using TensorRef = typename OutputTileIterator::TensorRef; + // Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; - // Tensor reference to sync tensor - using SyncTensorRef = - typename cutlass::TensorRef; + // Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; - // Const tensor reference to source tensor - using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + // Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; - // Array type used to output - using OutputAccessType = Array; + // Array type used to output + using OutputAccessType = Array; - // Array type used to output - using OutputAccessType2 = Array; + // Array type used to output + using OutputAccessType2 = Array; - // Array type used by output functor - using AccumulatorAccessType = Array; + // Array type used by output functor + using AccumulatorAccessType = Array; - // Number of warps - using WarpCount = typename Base::WarpCount; + // Number of warps + using WarpCount = typename Base::WarpCount; - struct SharedStorage { - using Element = typename WarpTileIterator::Element; + struct SharedStorage { + using Element = typename WarpTileIterator::Element; - // Tensor reference to shared memory allocation - using TensorRef = typename WarpTileIterator::TensorRef; + // Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; - // Logical shape of the shared memory tile written to by all warps. - using Shape = typename Base::Shape; + // Logical shape of the shared memory tile written to by all warps. + using Shape = typename Base::Shape; - // Shape of the shared memory allocation for the epilogue - using StorageShape = typename Base::SharedStorage::StorageShape; + // Shape of the shared memory allocation for the epilogue + using StorageShape = typename Base::SharedStorage::StorageShape; - // - // Data members - // + // + // Data members + // - AlignedBuffer storage[2]; + AlignedBuffer storage[2]; - // - // Methods - // + // + // Methods + // - // Returns a tensor reference to the shared memory buffer - CUTLASS_DEVICE - TensorRef reference(int i) { - return TensorRef( - storage[i].data(), - Layout::packed({StorageShape::kRow, StorageShape::kColumn})); - } - }; - - static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 - ? Base::kFragmentsPerIteration - : kPartitionsK; - static int constexpr kSmemPointerOffset = - SharedStorage::StorageShape::kCount / kSmemTiles; - - public: - static_assert( - SharedLoadIterator::Fragment::kElements == - OutputTileIterator::Fragment::kElements, - "Mismatch between shared load iterator and output tile iterator."); - - static_assert(OutputTileIterator::kElementsPerAccess, - "OutputTileIterator::kElementsPerAccess must not be zero."); - - static_assert(!(OutputTileIterator::Fragment::kElements % - OutputTileIterator::kElementsPerAccess), - "Divisibility"); - - private: - // Loads fragment from shared memory aligned with output tensor - SharedLoadIterator shared_load_iterator0_; - SharedLoadIterator shared_load_iterator1_; - - // Stores a warp's fragment of accumulators to SMEM - WarpTileIterator warp_tile_iterator0_; - WarpTileIterator warp_tile_iterator1_; - - public: - // Constructor + // Returns a tensor reference to the shared memory buffer CUTLASS_DEVICE - DualEpilogue( - SharedStorage &shared_storage, // Shared storage object // NOLINT - int thread_idx, // ID of a thread within the threadblock - int warp_idx, // ID of warp within threadblock - int lane_idx // Id of thread within warp - ) - : shared_load_iterator0_(shared_storage.reference(0), thread_idx), - shared_load_iterator1_(shared_storage.reference(1), thread_idx), - warp_tile_iterator0_(shared_storage.reference(0), lane_idx), - warp_tile_iterator1_(shared_storage.reference(1), lane_idx) { - int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); - int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); - int warp_m = warp_mn % WarpCount::kM; - int warp_n = warp_mn / WarpCount::kM; + TensorRef reference(int i) { + return TensorRef( + storage[i].data(), + Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; - MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + SharedStorage::StorageShape::kCount / kSmemTiles; - warp_tile_iterator0_.add_tile_offset(warp_offset); - warp_tile_iterator1_.add_tile_offset(warp_offset); + public: + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + // Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator0_; + SharedLoadIterator shared_load_iterator1_; + + // Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator0_; + WarpTileIterator warp_tile_iterator1_; + + public: + // Constructor + CUTLASS_DEVICE + DualEpilogue( + SharedStorage &shared_storage, // Shared storage object // NOLINT + int thread_idx, // ID of a thread within the threadblock + int warp_idx, // ID of warp within threadblock + int lane_idx // Id of thread within warp + ) + : shared_load_iterator0_(shared_storage.reference(0), thread_idx), + shared_load_iterator1_(shared_storage.reference(1), thread_idx), + warp_tile_iterator0_(shared_storage.reference(0), lane_idx), + warp_tile_iterator1_(shared_storage.reference(1), lane_idx) { + int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); + int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); + int warp_m = warp_mn % WarpCount::kM; + int warp_n = warp_mn / WarpCount::kM; + + MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + + warp_tile_iterator0_.add_tile_offset(warp_offset); + warp_tile_iterator1_.add_tile_offset(warp_offset); + } + + // Streams the result to global memory + CUTLASS_DEVICE + void operator()(OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + OutputTileIterator dest0, + OutputTileIterator dest1, + OutputTileIterator2 dest2, + AccumulatorTile const &accumulator0, + AccumulatorTile const &accumulator1, + OutputTileIterator source_iterator[2], + bool writeToD2 // true if it's the final split-k + ) { + // TODO: Implement when no source is needed // NOLINT + typename OutputTileIterator::Fragment source_fragment[2]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_fragment[i].clear(); } - // Streams the result to global memory - CUTLASS_DEVICE - void operator()(OutputOp0 const &output_op0, - OutputOp1 const &output_op1, - OutputOp2 const &output_op2, - OutputTileIterator dest0, - OutputTileIterator dest1, - OutputTileIterator2 dest2, - AccumulatorTile const &accumulator0, - AccumulatorTile const &accumulator1, - OutputTileIterator source_iterator[2], - bool writeToD2 // true if it's the final split-k - ) { - // TODO: Implement when no source is needed // NOLINT - typename OutputTileIterator::Fragment source_fragment[2]; + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, + accumulator1}; + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + // + // Load the source + // + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 2; ++i) { - source_fragment[i].clear(); + source_iterator[i].load(source_fragment[i]); + ++source_iterator[i]; } // - // Iterator over warp-level accumulator fragment + // Convert and store fragment // - AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, - accumulator1}; + __syncthreads(); + + acc2smem_source_needed>::push(iter, + accum_fragment_iterator[0], + this->warp_tile_iterator0_); + acc2smem_source_needed>::push(iter, + accum_fragment_iterator[1], + this->warp_tile_iterator1_); + + __syncthreads(); // - // Iterate over accumulator tile + // Load fragments from shared memory // - #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) - for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { - // - // Load the source - // + typename SharedLoadIterator::Fragment + aligned_accum_fragment0[kPartitionsK]; + typename SharedLoadIterator::Fragment + aligned_accum_fragment1[kPartitionsK]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - source_iterator[i].load(source_fragment[i]); - ++source_iterator[i]; - } + shared_load_iterator0_.load(aligned_accum_fragment0[0]); + shared_load_iterator1_.load(aligned_accum_fragment1[0]); - // - // Convert and store fragment - // - - __syncthreads(); - - acc2smem_source_needed>::push(iter, - accum_fragment_iterator[0], - this->warp_tile_iterator0_); - acc2smem_source_needed>::push(iter, - accum_fragment_iterator[1], - this->warp_tile_iterator1_); - - __syncthreads(); - - // - // Load fragments from shared memory - // - - typename SharedLoadIterator::Fragment - aligned_accum_fragment0[kPartitionsK]; - typename SharedLoadIterator::Fragment - aligned_accum_fragment1[kPartitionsK]; - - shared_load_iterator0_.load(aligned_accum_fragment0[0]); - shared_load_iterator1_.load(aligned_accum_fragment1[0]); - - // If the number of k-slices is > 1 - perform a reduction amongst the - // k-slices - if (kPartitionsK > 1) { - plus add_fragments; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < kPartitionsK; ++i) { - shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); - shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); - shared_load_iterator0_.load(aligned_accum_fragment0[i]); - shared_load_iterator1_.load(aligned_accum_fragment1[i]); - aligned_accum_fragment0[0] = add_fragments( - aligned_accum_fragment0[0], aligned_accum_fragment0[i]); - aligned_accum_fragment1[0] = add_fragments( - aligned_accum_fragment1[0], aligned_accum_fragment1[i]); - } - - shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * - kSmemPointerOffset); - shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * - kSmemPointerOffset); - } - - // - // Compute the output result - // - - typename OutputTileIterator::Fragment output_fragment[2]; - typename OutputTileIterator2::Fragment output_fragment_final; - - apply_output_operator_(output_fragment, - output_fragment_final, - output_op0, - output_op1, - output_op2, - aligned_accum_fragment0[0], - aligned_accum_fragment1[0], - source_fragment); - - // - // Store the final result - // - - if (kStoreD0) { - dest0.store(output_fragment[0]); - ++dest0; - } - if (kStoreD1) { - dest1.store(output_fragment[1]); - ++dest1; - } - if (writeToD2) { - dest2.store(output_fragment_final); - ++dest2; - } - } - } - - private: - static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, - "One of these must be exactly 1."); - - template - struct acc2smem_source_needed; - - template - struct acc2smem_source_needed> { - template - CUTLASS_DEVICE static void helper( - AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator &warp_tile_iterator) { // NOLINT - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } - - typename AccumulatorFragmentIterator::Fragment accum_fragment; - accum_fragment_iterator.load(accum_fragment); - warp_tile_iterator.store(accum_fragment); - } - - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const &iterator_begin, - WarpTileIterator &warp_tile_iterator) { // NOLINT - int dummy[] = {(pos == Seq) && - (helper(iterator_begin, warp_tile_iterator), 0)...}; - } - }; - - // Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator_( - typename OutputTileIterator::Fragment (&output_fragment)[2], - typename OutputTileIterator2::Fragment &output_fragment_final, // NOLINT - OutputOp0 const &output_op0, - OutputOp1 const &output_op1, - OutputOp2 const &output_op2, - typename SharedLoadIterator::Fragment const &aligned_accum_fragment0, - typename SharedLoadIterator::Fragment const &aligned_accum_fragment1, - typename OutputTileIterator::Fragment const (&source_fragment)[2]) { - OutputAccessType *output_frag_ptr[2] = { - reinterpret_cast(&output_fragment[0]), - reinterpret_cast(&output_fragment[1])}; - - OutputAccessType2 *output_frag_final_ptr = - reinterpret_cast(&output_fragment_final); - - AccumulatorAccessType const *compute_frag_ptr[2] = { - reinterpret_cast( - &aligned_accum_fragment0), - reinterpret_cast( - &aligned_accum_fragment1)}; - - OutputAccessType const *source_frag_ptr[2] = { - reinterpret_cast(&source_fragment[0]), - reinterpret_cast(&source_fragment[1])}; - - int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / - OutputTileIterator::kElementsPerAccess; + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - // Call the output operators - output_frag_ptr[0][i] = - output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); - output_frag_ptr[1][i] = - output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); - output_frag_final_ptr[i] = - output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator0_.load(aligned_accum_fragment0[i]); + shared_load_iterator1_.load(aligned_accum_fragment1[i]); + aligned_accum_fragment0[0] = add_fragments( + aligned_accum_fragment0[0], aligned_accum_fragment0[i]); + aligned_accum_fragment1[0] = add_fragments( + aligned_accum_fragment1[0], aligned_accum_fragment1[i]); } + + shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * + kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * + kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment[2]; + typename OutputTileIterator2::Fragment output_fragment_final; + + apply_output_operator_(output_fragment, + output_fragment_final, + output_op0, + output_op1, + output_op2, + aligned_accum_fragment0[0], + aligned_accum_fragment1[0], + source_fragment); + + // + // Store the final result + // + + if (kStoreD0) { + dest0.store(output_fragment[0]); + ++dest0; + } + if (kStoreD1) { + dest1.store(output_fragment[1]); + ++dest1; + } + if (writeToD2) { + dest2.store(output_fragment_final); + ++dest2; + } } + } + + private: + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { // NOLINT + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { // NOLINT + int dummy[] = {(pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + // Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment (&output_fragment)[2], + typename OutputTileIterator2::Fragment &output_fragment_final, // NOLINT + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment0, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment1, + typename OutputTileIterator::Fragment const (&source_fragment)[2]) { + OutputAccessType *output_frag_ptr[2] = { + reinterpret_cast(&output_fragment[0]), + reinterpret_cast(&output_fragment[1])}; + + OutputAccessType2 *output_frag_final_ptr = + reinterpret_cast(&output_fragment_final); + + AccumulatorAccessType const *compute_frag_ptr[2] = { + reinterpret_cast( + &aligned_accum_fragment0), + reinterpret_cast( + &aligned_accum_fragment1)}; + + OutputAccessType const *source_frag_ptr[2] = { + reinterpret_cast(&source_fragment[0]), + reinterpret_cast(&source_fragment[1])}; + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operators + output_frag_ptr[0][i] = + output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); + output_frag_ptr[1][i] = + output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); + output_frag_final_ptr[i] = + output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); + } + } }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_block_gemm_act_template_3x.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_block_gemm_act_template_3x.h index 1a5b838b81..f95f20bfbf 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_block_gemm_act_template_3x.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_block_gemm_act_template_3x.h @@ -34,34 +34,35 @@ using namespace cute; template < - typename InputType = phi::dtype::float8_e4m3fn, - typename OutType = phi::dtype::float16, - bool hasbias = false, - template typename Activation = cutlass::epilogue::thread::Identity, - typename TileShape = Shape<_128, _128, _128>, - typename ClusterShape = Shape<_1, _2, _1>, - typename KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>, - typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative, - typename TileSchedule = cutlass::gemm::PersistentScheduler, - typename SM = cutlass::arch::Sm90 -> -bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ - using ElementA = typename std::conditional_t, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; + typename InputType = phi::dtype::float8_e4m3fn, + typename OutType = phi::dtype::float16, + bool hasbias = false, + template typename Activation = cutlass::epilogue::thread::Identity, + typename TileShape = Shape<_128, _128, _128>, + typename ClusterShape = Shape<_1, _2, _1>, + typename KernelSchedule = cutlass::gemm:: + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>, + typename EpilogueSchedule = + cutlass::epilogue::TmaWarpSpecializedCooperative, + typename TileSchedule = cutlass::gemm::PersistentScheduler, + typename SM = cutlass::arch::Sm90> +bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params) { + using ElementA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; using ElementB = ElementA; - using ElementD = typename std::conditional_t, + using ElementD = + typename std::conditional_t, cutlass::bfloat16_t, cutlass::half_t>; - using ElementC = std::conditional_t< - hasbias, - ElementD, - void>; + using ElementC = std::conditional_t; constexpr int ScaleMsPerTile = size<0>(TileShape{}); constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile; - static constexpr bool IsStreamK = cute::is_same_v; + static constexpr bool IsStreamK = + cute::is_same_v; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -80,37 +81,54 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - using FusionOperation = cutlass::epilogue::fusion::LinCombEltAct; + using FusionOperation = + cutlass::epilogue::fusion::LinCombEltAct; - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - SM, cutlass::arch::OpClassTensorOp, - TileShape, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - EpilogueSchedule, - FusionOperation - >::CollectiveOp; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + SM, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueSchedule, + FusionOperation>::CollectiveOp; - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - SM, cutlass::arch::OpClassTensorOp, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule - >::CollectiveOp; + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + SM, + cutlass::arch::OpClassTensorOp, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - TileSchedule - >; + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, + CollectiveEpilogue, + TileSchedule>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -132,23 +150,27 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ StrideD stride_D{params.ldd, cute::Int<1>{}, params.M * params.ldd}; auto a_ptr = reinterpret_cast(const_cast(params.A)); - auto a_scale_ptr = reinterpret_cast(const_cast(params.A_scale)); + auto a_scale_ptr = + reinterpret_cast(const_cast(params.A_scale)); auto b_ptr = reinterpret_cast(const_cast(params.B)); - auto b_scale_ptr = reinterpret_cast(const_cast(params.B_scale)); + auto b_scale_ptr = + reinterpret_cast(const_cast(params.B_scale)); auto c_ptr = reinterpret_cast(const_cast(params.bias)); auto d_ptr = reinterpret_cast(params.D); - ProblemShapeType problem_size = ProblemShapeType{params.M, params.N, params.K, params.batch_count}; + ProblemShapeType problem_size = + ProblemShapeType{params.M, params.N, params.K, params.batch_count}; typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {a_ptr, stride_A, b_ptr, stride_B, - a_scale_ptr, b_scale_ptr}, - {{params.scale}, // epilogue.thread - c_ptr, stride_C, d_ptr, stride_D} - }; - if constexpr (hasbias){ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {a_ptr, stride_A, b_ptr, stride_B, a_scale_ptr, b_scale_ptr}, + {{params.scale}, // epilogue.thread + c_ptr, + stride_C, + d_ptr, + stride_D}}; + if constexpr (hasbias) { arguments.epilogue.thread.beta = 1.0; } @@ -162,12 +184,12 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ arguments.scheduler.reduction_mode = ReductionMode::Nondeterministic; } - Gemm gemm_op; cutlass::Status status = gemm_op.can_implement(arguments); if (status != cutlass::Status::kSuccess) { - std::cout << "Gemm::can_implement() failed. " << cutlassGetStatusString(status) << std::endl; + std::cout << "Gemm::can_implement() failed. " + << cutlassGetStatusString(status) << std::endl; return false; } size_t workspace_size = Gemm::get_workspace_size(arguments); @@ -176,7 +198,8 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ status = gemm_op(arguments, workspace->ptr(), params.stream); if (status != cutlass::Status::kSuccess) { - std::cout << "Gemm::run() failed." << cutlassGetStatusString(status) << std::endl; + std::cout << "Gemm::run() failed." << cutlassGetStatusString(status) + << std::endl; return false; } return true; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h index 632cdc296a..ec484ac743 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h @@ -28,60 +28,67 @@ #include "cutlass_extensions/gemm/collective/collective_builder_gated.hpp" #include "cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp" -template class Activation = - cutlass::epilogue::thread::SiLu, + template + class Activation = cutlass::epilogue::thread::SiLu, bool SwapAB = true> bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) { using namespace cute; using ElementA = typename std::conditional_t< std::is_same_v, - cutlass::float_e4m3_t, cutlass::float_e5m2_t>; - using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using LayoutA = + cutlass::layout::RowMajor; // Layout type for A matrix operand static constexpr int AlignmentA = 128 / cutlass::sizeof_bits< - ElementA>::value; // Memory access granularity/alignment of A - // matrix in units of elements (up to 16 bytes) + ElementA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) // B matrix configuration - using ElementB = ElementA; // Element type for B matrix operand + using ElementB = ElementA; // Element type for B matrix operand using LayoutB = - cutlass::layout::ColumnMajor; // Layout type for B matrix operand + cutlass::layout::ColumnMajor; // Layout type for B matrix operand static constexpr int AlignmentB = 128 / cutlass::sizeof_bits< - ElementB>::value; // Memory access granularity/alignment of B - // matrix in units of elements (up to 16 bytes) + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) - using ElementC = ElementA; // Element type for C matrix operands + using ElementC = ElementA; // Element type for C matrix operands - using LayoutC = cute::conditional_t; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits< - ElementC>::value; // Memory access granularity/alignment of C matrices - // in units of elements (up to 16 bytes) + ElementC>::value; // Memory access granularity/alignment of C + // matrices in units of elements (up to 16 bytes) // Output matrix configuration - using ElementOutput = ElementA; // Element type for output matrix operands + using ElementOutput = ElementA; // Element type for output matrix operands // using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output // matrix operands - using LayoutOutput = cute::conditional_t; static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; // Multiply-accumulate blocking/pipelining details - using ElementAccumulator = float; // Element type for internal accumulation - using ElementCompute = float; // Element type for compute - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag - using TileShape = CTAShape; // Threadblock-level tile size + using ElementAccumulator = float; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size using KernelSchedule = MainloopScheduleType; using EpilogueSchedule = EpilogueScheduleType; using TileScheduler = TileSchedulerType; @@ -94,22 +101,46 @@ bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) { using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC, - ElementOutput, LayoutOutput, AlignmentOutput, EpilogueSchedule, + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + EpilogueTileType, + ElementAccumulator, + ElementAccumulator, + ElementC, + LayoutC, + AlignmentC, + ElementOutput, + LayoutOutput, + AlignmentOutput, + EpilogueSchedule, FusionOperation>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderGated< - ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, - LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, + ArchTag, + OperatorClass, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule, Activation, SwapAB>::CollectiveOp; + KernelSchedule, + Activation, + SwapAB>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated< - Shape, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -141,7 +172,7 @@ bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) { cutlass::gemm::GemmUniversalMode::kGemm, {arg_m, arg_n, params.K, params.batch_count}, {ptr_A, stride_A, ptr_B0, ptr_B1, stride_B, params.scale0, params.scale1}, - {{}, // epilogue.thread + {{}, // epilogue.thread nullptr, stride_C, reinterpret_cast(params.D), diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h index 0e18c2a389..5d2a9638b3 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h @@ -21,9 +21,15 @@ #include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" #include "fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h" -template +template bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -73,8 +79,8 @@ bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< ElementInputC, // <- data type of output matrix @@ -86,7 +92,7 @@ bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function using EpilogueOp1 = cutlass::epilogue::thread::LeftGELUAndMul< ElementOutput, @@ -143,11 +149,23 @@ bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) { params.lda}, {reinterpret_cast(const_cast(params.B0)), params.ldb}, - hasbias? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias0)), 0} : nullptr_ref, + hasbias + ? typename cutlass::TensorRef< + typename Gemm::ElementC, + typename Gemm::LayoutC>{reinterpret_cast( + const_cast(params.bias0)), + 0} + : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.B1)), params.ldb}, - hasbias? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias1)), 0} : nullptr_ref, + hasbias + ? typename cutlass::TensorRef< + typename Gemm::ElementC, + typename Gemm::LayoutC>{reinterpret_cast( + const_cast(params.bias1)), + 0} + : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.D)), params.ldd}, diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h index b5de12e7e9..dc78839750 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h @@ -21,9 +21,15 @@ #include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" #include "fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h" -template +template bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -73,8 +79,8 @@ bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< ElementInputC, // <- data type of output matrix @@ -86,7 +92,7 @@ bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function using EpilogueOp1 = cutlass::epilogue::thread::LeftSiLUAndMul< ElementOutput, @@ -129,9 +135,9 @@ bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{params.M, params.N, params.K}; - cutlass::gemm::DualGemmMode mode = params.batch_count > 1 ? - cutlass::gemm::DualGemmMode::kBatched : - cutlass::gemm::DualGemmMode::kGemm; + cutlass::gemm::DualGemmMode mode = params.batch_count > 1 + ? cutlass::gemm::DualGemmMode::kBatched + : cutlass::gemm::DualGemmMode::kGemm; typename cutlass::TensorRef nullptr_ref{}; @@ -144,11 +150,23 @@ bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { params.lda}, {reinterpret_cast(const_cast(params.B0)), params.ldb}, - hasbias ? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias0)), 0} : nullptr_ref, + hasbias + ? typename cutlass::TensorRef< + typename Gemm::ElementC, + typename Gemm::LayoutC>{reinterpret_cast( + const_cast(params.bias0)), + 0} + : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.B1)), params.ldb}, - hasbias ? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias1)), 0} : nullptr_ref, + hasbias + ? typename cutlass::TensorRef< + typename Gemm::ElementC, + typename Gemm::LayoutC>{reinterpret_cast( + const_cast(params.bias1)), + 0} + : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.D)), params.ldd}, diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h index c470151070..83ded56270 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h @@ -26,26 +26,28 @@ #include "cutlass/gemm/kernel/tile_scheduler.hpp" #include "cutlass/util/packed_stride.hpp" -template < - typename InputType, - typename OutType, - bool hasbias, - template typename Activation, - typename TileShape, - typename ClusterShape, - typename KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, - typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized, - typename SM = cutlass::arch::Sm90> +template + typename Activation, + typename TileShape, + typename ClusterShape, + typename KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, + typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized, + typename SM = cutlass::arch::Sm90> bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) { using namespace cute; using ElementA = typename std::conditional_t< std::is_same_v, - cutlass::float_e4m3_t, cutlass::float_e5m2_t>; + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; using ElementB = ElementA; using ElementD = typename std::conditional_t, - cutlass::bfloat16_t, cutlass::half_t>; + cutlass::bfloat16_t, + cutlass::half_t>; using ElementC = std::conditional_t; using LayoutA = cutlass::layout::RowMajor; @@ -66,29 +68,53 @@ bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) { static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; using FusionOperation = - cutlass::epilogue::fusion::LinCombEltAct; + cutlass::epilogue::fusion::LinCombEltAct; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - SM, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, - AlignmentD, EpilogueSchedule, FusionOperation>::CollectiveOp; + SM, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueSchedule, + FusionOperation>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - SM, cutlass::arch::OpClassTensorOp, ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape, + SM, + cutlass::arch::OpClassTensorOp, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -120,7 +146,7 @@ bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) { typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {a_ptr, stride_A, b_ptr, stride_B}, - {{params.scale}, // epilogue.thread + {{params.scale}, // epilogue.thread c_ptr, stride_C, d_ptr, diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h index 32b8a132e8..50793eff0a 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h @@ -20,9 +20,14 @@ #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_splitk_parallel.h" -template +template bool dispatch_fuse_gemm_gelu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -64,9 +69,8 @@ bool dispatch_fuse_gemm_gelu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; // <- MMA Op tile static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; - + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = @@ -82,7 +86,7 @@ bool dispatch_fuse_gemm_gelu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; @@ -164,10 +168,15 @@ bool dispatch_fuse_gemm_gelu(GemmEpilogueAllParams params) { return true; } - -template +template bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -209,9 +218,8 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; // <- MMA Op tile static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; - + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = @@ -227,7 +235,7 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; @@ -309,10 +317,14 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { return true; } - -template +template bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -354,8 +366,8 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU< ElementOutput, // <- data type of output matrix @@ -367,50 +379,55 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; - using ConvertScaledOp = cutlass::epilogue::thread::Convert< - ElementAccumulator, - cutlass::gemm::device::DefaultGemmConfiguration::EpilogueOutputOp::kCount, - ElementAccumulator>; + using ConvertScaledOp = cutlass::epilogue::thread::Convert< + ElementAccumulator, + cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, + SmArch, + ElementInputA, + ElementInputB, + ElementAccumulator, + ElementAccumulator>::EpilogueOutputOp::kCount, + ElementAccumulator>; - /// Reduction operator - using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, typename EpilogueOp::ElementAccumulator, - EpilogueOp::kCount>; + /// Reduction operator + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOp::ElementAccumulator, + EpilogueOp::kCount>; - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastAccum; - - using Gemm = cutlass::gemm::device::GemmSplitKParallel; + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastAccum; + using Gemm = cutlass::gemm::device::GemmSplitKParallel; cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{params.M, params.N, params.K}; @@ -421,15 +438,18 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { // Split K dimension into 16 partitions int split_k_slices = params.split_k; - // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch - // instantiated CUTLASS kernel - typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication - {reinterpret_cast(const_cast(params.A)),params.lda}, - {reinterpret_cast(const_cast(params.B)),params.ldb}, - {reinterpret_cast(const_cast(params.bias)),0}, - {reinterpret_cast(params.D),params.ldd}, - {alpha, beta}, // <- tuple of alpha and beta - split_k_slices}; // <- k-dimension split factor + // Create a tuple of gemm kernel arguments. This is later passed as arguments + // to launch instantiated CUTLASS kernel + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + {reinterpret_cast(const_cast(params.A)), + params.lda}, + {reinterpret_cast(const_cast(params.B)), + params.ldb}, + {reinterpret_cast(const_cast(params.bias)), 0}, + {reinterpret_cast(params.D), params.ldd}, + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor Gemm gemm_op; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h index 31d2a21e2a..f7f2bbf396 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h @@ -20,9 +20,14 @@ #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_splitk_parallel.h" -template +template bool dispatch_fuse_gemm_noact(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -64,9 +69,8 @@ bool dispatch_fuse_gemm_noact(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; // <- MMA Op tile static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; - + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = @@ -82,7 +86,7 @@ bool dispatch_fuse_gemm_noact(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; @@ -164,10 +168,14 @@ bool dispatch_fuse_gemm_noact(GemmEpilogueAllParams params) { return true; } - -template +template bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -209,8 +217,8 @@ bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // <- data type of output matrix @@ -222,50 +230,55 @@ bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; - using ConvertScaledOp = cutlass::epilogue::thread::Convert< - ElementAccumulator, - cutlass::gemm::device::DefaultGemmConfiguration::EpilogueOutputOp::kCount, - ElementAccumulator>; + using ConvertScaledOp = cutlass::epilogue::thread::Convert< + ElementAccumulator, + cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, + SmArch, + ElementInputA, + ElementInputB, + ElementAccumulator, + ElementAccumulator>::EpilogueOutputOp::kCount, + ElementAccumulator>; - /// Reduction operator - using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, typename EpilogueOp::ElementAccumulator, - EpilogueOp::kCount>; + /// Reduction operator + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOp::ElementAccumulator, + EpilogueOp::kCount>; - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastAccum; - - using Gemm = cutlass::gemm::device::GemmSplitKParallel; + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastAccum; + using Gemm = cutlass::gemm::device::GemmSplitKParallel; cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{params.M, params.N, params.K}; @@ -276,15 +289,18 @@ bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { // Split K dimension into 16 partitions int split_k_slices = params.split_k; - // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch - // instantiated CUTLASS kernel - typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication - {reinterpret_cast(const_cast(params.A)),params.lda}, - {reinterpret_cast(const_cast(params.B)),params.ldb}, - {reinterpret_cast(const_cast(params.bias)),0}, - {reinterpret_cast(params.D),params.ldd}, - {alpha, beta}, // <- tuple of alpha and beta - split_k_slices}; // <- k-dimension split factor + // Create a tuple of gemm kernel arguments. This is later passed as arguments + // to launch instantiated CUTLASS kernel + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + {reinterpret_cast(const_cast(params.A)), + params.lda}, + {reinterpret_cast(const_cast(params.B)), + params.ldb}, + {reinterpret_cast(const_cast(params.bias)), 0}, + {reinterpret_cast(params.D), params.ldd}, + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor Gemm gemm_op; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h index d2aa189eed..f59b84f59f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h @@ -20,9 +20,14 @@ #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_splitk_parallel.h" -template +template bool dispatch_fuse_gemm_relu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -64,9 +69,8 @@ bool dispatch_fuse_gemm_relu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; // <- MMA Op tile static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; - + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = @@ -82,7 +86,7 @@ bool dispatch_fuse_gemm_relu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; @@ -164,10 +168,14 @@ bool dispatch_fuse_gemm_relu(GemmEpilogueAllParams params) { return true; } - -template +template bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -209,8 +217,8 @@ bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, // <- data type of output matrix @@ -222,50 +230,55 @@ bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; - using ConvertScaledOp = cutlass::epilogue::thread::Convert< - ElementAccumulator, - cutlass::gemm::device::DefaultGemmConfiguration::EpilogueOutputOp::kCount, - ElementAccumulator>; + using ConvertScaledOp = cutlass::epilogue::thread::Convert< + ElementAccumulator, + cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, + SmArch, + ElementInputA, + ElementInputB, + ElementAccumulator, + ElementAccumulator>::EpilogueOutputOp::kCount, + ElementAccumulator>; - /// Reduction operator - using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, typename EpilogueOp::ElementAccumulator, - EpilogueOp::kCount>; + /// Reduction operator + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOp::ElementAccumulator, + EpilogueOp::kCount>; - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastAccum; - - using Gemm = cutlass::gemm::device::GemmSplitKParallel; + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastAccum; + using Gemm = cutlass::gemm::device::GemmSplitKParallel; cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{params.M, params.N, params.K}; @@ -276,15 +289,18 @@ bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { // Split K dimension into 16 partitions int split_k_slices = params.split_k; - // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch - // instantiated CUTLASS kernel - typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication - {reinterpret_cast(const_cast(params.A)),params.lda}, - {reinterpret_cast(const_cast(params.B)),params.ldb}, - {reinterpret_cast(const_cast(params.bias)),0}, - {reinterpret_cast(params.D),params.ldd}, - {alpha, beta}, // <- tuple of alpha and beta - split_k_slices}; // <- k-dimension split factor + // Create a tuple of gemm kernel arguments. This is later passed as arguments + // to launch instantiated CUTLASS kernel + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + {reinterpret_cast(const_cast(params.A)), + params.lda}, + {reinterpret_cast(const_cast(params.B)), + params.ldb}, + {reinterpret_cast(const_cast(params.bias)), 0}, + {reinterpret_cast(params.D), params.ldd}, + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor Gemm gemm_op; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/per_channel_fp8_fp8_half_gemm.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/per_channel_fp8_fp8_half_gemm.h index 3ea4d7c33e..82373f8693 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/per_channel_fp8_fp8_half_gemm.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/per_channel_fp8_fp8_half_gemm.h @@ -17,10 +17,10 @@ #include "fp8_common.h" -#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#ifdef __GNUC__ // Check if the compiler is GCC or Clang #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // __GNUC__ +#endif // __GNUC__ // clang-format off #include "cutlass/cutlass.h" @@ -30,91 +30,134 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" // clang-format on -#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#ifdef __GNUC__ // Check if the compiler is GCC or Clang #pragma GCC diagnostic pop -#endif // __GNUC__ +#endif // __GNUC__ -template -struct DeviceGemmFp8RowwiseSm89 -{ - using ElementInput = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementA = ElementInput; - using LayoutA = cutlass::layout::RowMajor; - static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; +template +struct DeviceGemmFp8RowwiseSm89 { + using ElementInput = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementA = ElementInput; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - using ElementB = ElementInput; - using LayoutB = cutlass::layout::ColumnMajor; - static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + using ElementB = ElementInput; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + using ElementOutput = typename std::conditional_t< + std::is_same_v, + cutlass::bfloat16_t, + cutlass::half_t>; - using ElementOutput = - typename std::conditional_t, - cutlass::bfloat16_t, - cutlass::half_t>; + using ElementC = ElementOutput; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - using ElementC = ElementOutput; - using LayoutC = cutlass::layout::RowMajor; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + using LayoutOutput = cutlass::layout::RowMajor; + static constexpr int AlignmentOutput = + 128 / cutlass::sizeof_bits::value; - using LayoutOutput = cutlass::layout::RowMajor; - static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + using ElementAccumulator = AccumElementType; + using ElementComputeEpilogueScale = float; + using ArchTag = cutlass::arch::Sm89; + using OperatorClass = cutlass::arch::OpClassTensorOp; - using ElementAccumulator = AccumElementType; - using ElementComputeEpilogueScale = float; - using ArchTag = cutlass::arch::Sm89; - using OperatorClass = cutlass::arch::OpClassTensorOp; + // Number of epilogue stages in EVT + static constexpr int EVTEpilogueStages = 1; - // Number of epilogue stages in EVT - static constexpr int EVTEpilogueStages = 1; + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout; - using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout; + // Definition of EVT + using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; - // Definition of EVT - using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; + using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, + ElementComputeEpilogueScale, + ElementComputeEpilogueScale, + cutlass::FloatRoundStyle::round_to_nearest>; + using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, + ElementComputeEpilogueScale, + cute::Stride>; + using EpilogueBScale = + cutlass::epilogue::threadblock::Sm80EVT; - using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute; - using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; - using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT; + using ComputeAScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, + ElementComputeEpilogueScale, + ElementComputeEpilogueScale, + cutlass::FloatRoundStyle::round_to_nearest>; + using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, + ElementComputeEpilogueScale, + cute::Stride>; + using EpilogueAScale = cutlass::epilogue::threadblock:: + Sm80EVT; - using ComputeAScale = cutlass::epilogue::threadblock::VisitorCompute; - using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast>; - using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; + using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, + ElementC, + cute::Stride // StrideMNL + >; - using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, ElementC, - cute::Stride // StrideMNL - >; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, + ElementC, + ElementComputeEpilogueScale, + cutlass::FloatRoundStyle::round_to_nearest>; - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::plus, ElementC, ElementComputeEpilogueScale, - cutlass::FloatRoundStyle::round_to_nearest - >; + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; - using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT< - Compute0, - EpilogueAScale, - Bias>; + using dTar = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, + ElementOutput, + cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride>; + using EpilogueStore = + cutlass::epilogue::threadblock::Sm80EVT; - using dTar = cutlass::epilogue::threadblock::VisitorAuxStore>; - using EpilogueStore = cutlass::epilogue::threadblock::Sm80EVT; + using EpilogueOp = EpilogueStore; - using EpilogueOp = EpilogueStore; + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, + LayoutA, + cutlass::ComplexTransform::kNone, + AlignmentA, + ElementB, + LayoutB, + cutlass::ComplexTransform::kNone, + AlignmentB, + ElementC, + LayoutC, + AlignmentC, + ElementAccumulator, + ElementComputeEpilogueScale, + OperatorClass, + ArchTag, + CtaShape, + WarpShape, + InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + Stages, + cutlass::arch::OpMultiplyAddFastAccum, + EVTEpilogueStages>::GemmKernel; - using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor::GemmKernel; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused_template.h index a073c9bff3..111b2cfd15 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused_template.h @@ -16,56 +16,68 @@ #include "per_channel_fp8_fp8_half_gemm.h" // NOLINT template -typename Gemm::Arguments prepar_gemm_args_sm89(void* D, void const* A, void const* B, void const* C_bias, - int m, int n, int k, float const* scale_d0, float const* scale_d1) -{ - using ElementT = typename Gemm::ElementA; - using ElementOutput = typename Gemm::ElementD; - using ElementComputeEpilogue = float; +typename Gemm::Arguments prepar_gemm_args_sm89(void* D, + void const* A, + void const* B, + void const* C_bias, + int m, + int n, + int k, + float const* scale_d0, + float const* scale_d1) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; - int const lda = k; - int const ldb = k; - int const ldc = n; + int const lda = k; + int const ldb = k; + int const ldc = n; - typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode - {m, n, k}, // Problem size - 1, // Split-k factor - {}, // Epilogue args - reinterpret_cast(A), // a pointer - reinterpret_cast(B), // b pointer - nullptr, // c pointer (unused) - nullptr, // d pointer (unused) - m * k, // batch stride a (unused) - n * k, // batch stride b (unused) - m * n, // batch stride c (unused) - m * n, // batch stride d (unused) - lda, // stride a - ldb, // stride b - ldc, // stride c (unused) - ldc); // stride d (unused) + typename Gemm::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, // Mode + {m, n, k}, // Problem size + 1, // Split-k factor + {}, // Epilogue args + reinterpret_cast(A), // a pointer + reinterpret_cast(B), // b pointer + nullptr, // c pointer (unused) + nullptr, // d pointer (unused) + m * k, // batch stride a (unused) + n * k, // batch stride b (unused) + m * n, // batch stride c (unused) + m * n, // batch stride d (unused) + lda, // stride a + ldb, // stride b + ldc, // stride c (unused) + ldc); // stride d (unused) - args.epilogue = { - { - { - { - {}, // Accumulator - {reinterpret_cast(scale_d1), ElementComputeEpilogue(0), - {cute::_0{}, cute::_1{}, cute::_0{}}}, - {} // Multiplies - }, - {reinterpret_cast(scale_d0), ElementComputeEpilogue(0), {cute::_0{}, cute::_0{}, cute::_0{}}}, - {} // Multiplies - }, // Accum - {reinterpret_cast(C_bias), ElementOutput(0), {cute::_0{}, cute::_1{}, cute::_0{}}}, // Bias - {} // Compute0 - }, - {reinterpret_cast(D), {n, cute::_1{}, cute::_0{}}} - }; - return args; + args.epilogue = { + { + { + { + {}, // Accumulator + {reinterpret_cast(scale_d1), + ElementComputeEpilogue(0), + {cute::_0{}, cute::_1{}, cute::_0{}}}, + {} // Multiplies + }, + {reinterpret_cast(scale_d0), + ElementComputeEpilogue(0), + {cute::_0{}, cute::_0{}, cute::_0{}}}, + {} // Multiplies + }, // Accum + {reinterpret_cast(C_bias), + ElementOutput(0), + {cute::_0{}, cute::_1{}, cute::_0{}}}, // Bias + {} // Compute0 + }, + {reinterpret_cast(D), {n, cute::_1{}, cute::_0{}}}}; + return args; } template -bool per_channel_fp8_fp8_gemm_scale_bias(GemmEpilogueAllParams params, typename Gemm::Arguments args) { +bool per_channel_fp8_fp8_gemm_scale_bias(GemmEpilogueAllParams params, + typename Gemm::Arguments args) { Gemm per_channel_fp8_gemm; cutlass::Status status = per_channel_fp8_gemm.can_implement(args); @@ -89,14 +101,31 @@ bool per_channel_fp8_fp8_gemm_scale_bias(GemmEpilogueAllParams params, typename return true; } - -template +template bool dispatch_visitor_fuse_gemm(GemmEpilogueAllParams params) { - using AccumElementType = float; - using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; - auto args = prepar_gemm_args_sm89(params.D, params.A, params.B, params.bias, params.M, params.N, params.K, params.scalar_scale, params.channel_scale); - per_channel_fp8_fp8_gemm_scale_bias(params, args); + using AccumElementType = float; + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + auto args = prepar_gemm_args_sm89(params.D, + params.A, + params.B, + params.bias, + params.M, + params.N, + params.K, + params.scalar_scale, + params.channel_scale); + per_channel_fp8_fp8_gemm_scale_bias(params, args); } diff --git a/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h index 6b1ab209e3..44d5ccc220 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h @@ -21,10 +21,8 @@ #include #include -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { /* This runner only supports: @@ -32,64 +30,105 @@ namespace cutlass_kernels Activations, biases, scales and outputs are all assumed to be row-major. - However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. - In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor - will instantiate the layout and preprocess based on the instantiation, so layout changes should only require - modifications to mix_gemm_B_layout.h. + However, it is assumed that B is in a special format governed by + cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. In this case, B must be + preprocessed using the cutlass weight only quant preprocessors. The weight + preprocessor will instantiate the layout and preprocess based on the + instantiation, so layout changes should only require modifications to + mix_gemm_B_layout.h. */ -class CutlassFpAIntBGemmRunnerInterface -{ -public: - CutlassFpAIntBGemmRunnerInterface() {} +class CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunnerInterface() {} - virtual ~CutlassFpAIntBGemmRunnerInterface() {} + virtual ~CutlassFpAIntBGemmRunnerInterface() {} - virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, - void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, - cutlass_extensions::CutlassGemmConfig gemmConfig, void* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream - ) = 0; + virtual void gemm(void const* A, + void const* B, + void const* weight_scales, + void const* weight_zero_points, + void const* biases, + float const alpha, + void* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemmConfig, + void* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) = 0; - // Returns desired workspace size in bytes. - virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; + // Returns desired workspace size in bytes. + virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; - virtual std::vector getConfigs(int k) const = 0; + virtual std::vector getConfigs( + int k) const = 0; -protected: - static constexpr int SPLIT_K_LIMIT = 7; - static constexpr int MIN_M_TILE = 16; - static constexpr int MIN_N_TILE = 64; + protected: + static constexpr int SPLIT_K_LIMIT = 7; + static constexpr int MIN_M_TILE = 16; + static constexpr int MIN_N_TILE = 64; }; -template -class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterface -{ -public: - CutlassFpAIntBGemmRunner(); - ~CutlassFpAIntBGemmRunner(); +template +class CutlassFpAIntBGemmRunner + : public virtual CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunner(); + ~CutlassFpAIntBGemmRunner(); - void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, - void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, - cutlass_extensions::CutlassGemmConfig gemmConfig, void* workspace_ptr, const size_t workspace_bytes, - cudaStream_t stream) override; + void gemm(void const* A, + void const* B, + void const* weight_scales, + void const* weight_zero_points, + void const* biases, + float const alpha, + void* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemmConfig, + void* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) override; - // Returns desired workspace size in bytes. - size_t getWorkspaceSize(int const m, int const n, int const k) override; + // Returns desired workspace size in bytes. + size_t getWorkspaceSize(int const m, int const n, int const k) override; - std::vector getConfigs(int k) const override; + std::vector getConfigs( + int k) const override; -private: - template - void dispatch_to_arch(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, - ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, - int k, int const group_size, cutlass_extensions::CutlassGemmConfig gemm_config, void* workspace_ptr, - const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); + private: + template + void dispatch_to_arch(ActivationType const* A, + WeightType const* B, + ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, + BiasType const* biases, + float const alpha, + OutputType* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemm_config, + void* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr); -private: - int sm_; - int multi_processor_count_; + private: + int sm_; + int multi_processor_count_; }; -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 2e02658d26..bc62bb34bf 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -43,311 +43,559 @@ namespace kernels { namespace cutlass_kernels { -template < - typename ActivationType, - typename WeightType, - typename ScaleZeroType, - typename BiasType, - typename OutputType, - typename arch, - cutlass::WeightOnlyQuantOp QuantOp, - typename EpilogueTag, - typename ThreadblockShape, - typename WarpShape, - int Stages> +template void generic_mixed_gemm_kernelLauncher( - ActivationType const* A, - WeightType const* B, - ScaleZeroType const* weight_scales, - ScaleZeroType const* weight_zero_points, - BiasType const* biases, - float const alpha, - OutputType* C, - int m, - int n, - int k, - int const group_size, - cutlass_extensions::CutlassGemmConfig gemm_config, - void* workspace, - size_t workspace_bytes, - cudaStream_t stream, - int* occupancy = nullptr) { - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if - // necessary. - using CutlassActivationType = typename CudaToCutlassTypeAdapter::type; - using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; - using CutlassScaleZeroType = typename CudaToCutlassTypeAdapter::type; - using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; - using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; + ActivationType const* A, + WeightType const* B, + ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, + BiasType const* biases, + float const alpha, + OutputType* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemm_config, + void* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr) { + // The cutlass type for the input elements. This is needed to convert to + // cutlass::half_t if necessary. + using CutlassActivationType = + typename CudaToCutlassTypeAdapter::type; + using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; + using CutlassScaleZeroType = + typename CudaToCutlassTypeAdapter::type; + using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; + using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; - // We need separate config for each architecture since we will target different tensorcore - // instructions. For float, we do not target TCs. - using MixedGemmArchTraits = cutlass::gemm::kernel:: - MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; + // We need separate config for each architecture since we will target + // different tensorcore instructions. For float, we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel:: + MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; - constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using EpilogueOp = typename cutlass_extensions:: - Epilogue::Op; + constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using EpilogueOp = typename cutlass_extensions::Epilogue::Op; - using Operator = typename MixedGemmArchTraits::Operator; - using TaggedOperator = typename cutlass::arch::TagOperator::TaggedOperator; + using Operator = typename MixedGemmArchTraits::Operator; + using TaggedOperator = + typename cutlass::arch::TagOperator::TaggedOperator; - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< - CutlassActivationType, - cutlass::layout::RowMajor, - MixedGemmArchTraits::ElementsPerAccessA, - CutlassWeightType, - typename MixedGemmArchTraits::LayoutB, - MixedGemmArchTraits::ElementsPerAccessB, - CutlassOutputType, - cutlass::layout::RowMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - arch, - ThreadblockShape, - WarpShape, - typename MixedGemmArchTraits::InstructionShape, - EpilogueOp, - typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - Stages, - true, - TaggedOperator>::GemmKernel; + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< + CutlassActivationType, + cutlass::layout::RowMajor, + MixedGemmArchTraits::ElementsPerAccessA, + CutlassWeightType, + typename MixedGemmArchTraits::LayoutB, + MixedGemmArchTraits::ElementsPerAccessB, + CutlassOutputType, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + arch, + ThreadblockShape, + WarpShape, + typename MixedGemmArchTraits::InstructionShape, + EpilogueOp, + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + Stages, + true, + TaggedOperator>::GemmKernel; - using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB< - typename GemmKernel_::Mma, - typename GemmKernel_::Epilogue, - typename GemmKernel_::ThreadblockSwizzle, - arch, // Ensure top level arch is used for dispatch - GemmKernel_::kSplitKSerial>; + using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB< + typename GemmKernel_::Mma, + typename GemmKernel_::Epilogue, + typename GemmKernel_::ThreadblockSwizzle, + arch, // Ensure top level arch is used for dispatch + GemmKernel_::kSplitKSerial>; - if (occupancy != nullptr) { - *occupancy = cutlass_extensions::compute_occupancy_for_kernel(); - return; - } + if (occupancy != nullptr) { + *occupancy = cutlass_extensions::compute_occupancy_for_kernel(); + return; + } - using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; + using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; - int const ldb = - cutlass::platform:: - is_same::value - ? n - : k * GemmKernel::kInterleave; + int const ldb = + cutlass::platform::is_same::value + ? n + : k * GemmKernel::kInterleave; - if (weight_scales == nullptr) { - throw std::runtime_error("Weight scales must always be set to a non-null value."); - } + if (weight_scales == nullptr) { + throw std::runtime_error( + "Weight scales must always be set to a non-null value."); + } - if constexpr (cutlass::isFinegrained(QuantOp)) { - if constexpr (cutlass::platform::is_same:: - value) { - if (group_size != 128) { - throw std::runtime_error( - "Only group size 128 supported for fine grained W4A(fp)8 kernels."); - } - } - if (group_size != 64 && group_size != 128) { - throw std::runtime_error( - "Only group size 64 and 128 supported for fine grained kernels."); - } - - if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { - if (weight_zero_points != nullptr) { - throw std::runtime_error( - "Weight zero pointer must be a nullptr for scale only fine grained"); - } - } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { - if (weight_zero_points == nullptr) { - throw std::runtime_error( - "Weight zero pointer must be valid for scale and bias fine grained"); - } - } - } else { - if (group_size != k) { - throw std::runtime_error("Invalid group size for per column scaling kernels."); - } - - if (weight_zero_points != nullptr) { - throw std::runtime_error( - "Weight zero-points must be null when running per column scaling"); - } - } - - int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; - ElementAccumulator output_op_beta = - (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f); - typename Gemm::Arguments args( - {m, n, k}, - group_size, - {reinterpret_cast(const_cast(A)), k}, - {reinterpret_cast(const_cast(B)), ldb}, - {reinterpret_cast(const_cast(weight_scales)), - ld_scale_zero}, - {reinterpret_cast( - const_cast(weight_zero_points)), - ld_scale_zero}, - {reinterpret_cast(const_cast(biases)), 0}, - {reinterpret_cast(C), n}, - gemm_config.split_k_factor, - {ElementAccumulator(alpha), output_op_beta}); - - // This assertion is enabled because because for the column interleaved layout, K MUST be a - // multiple of threadblockK. The reason for this is that the default pitchlinear iterators are - // used to handle walking over the interleaved matrix. The way masking in handled in these do - // not map to the interleaved layout. We need to write our own predicated iterator in order to - // relax this limitation. - if (GemmKernel::kInterleave > 1 && - ((k % MixedGemmArchTraits::ThreadblockK) || - ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) { + if constexpr (cutlass::isFinegrained(QuantOp)) { + if constexpr (cutlass::platform::is_same::value) { + if (group_size != 128) { throw std::runtime_error( - "Assertion: k[" + std::to_string(k) + "] must be multiple of threadblockK[" + - std::to_string(MixedGemmArchTraits::ThreadblockK) + "]"); + "Only group size 128 supported for fine grained W4A(fp)8 kernels."); + } + } + if (group_size != 64 && group_size != 128) { + throw std::runtime_error( + "Only group size 64 and 128 supported for fine grained kernels."); } - Gemm gemm; - - if (gemm.get_workspace_size(args) > workspace_bytes) { - std::cerr << "Requested split-k but workspace size insufficient. Falling back to " - "non-split-k implementation." - << std::endl; - // If requested split-k factor will require more workspace bytes, revert to standard gemm. - args.batch_count = 1; + if constexpr (QuantOp == + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { + if (weight_zero_points != nullptr) { + throw std::runtime_error( + "Weight zero pointer must be a nullptr for scale only fine " + "grained"); + } + } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp:: + FINEGRAINED_SCALE_AND_ZEROS) { + if (weight_zero_points == nullptr) { + throw std::runtime_error( + "Weight zero pointer must be valid for scale and bias fine " + "grained"); + } + } + } else { + if (group_size != k) { + throw std::runtime_error( + "Invalid group size for per column scaling kernels."); } - auto can_implement = gemm.can_implement(args); - if (can_implement != cutlass::Status::kSuccess) { - std::string err_msg = "fp8_int4 cutlass kernel will fail for params. Error: " + - std::string(cutlassGetStatusString(can_implement)); - throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + if (weight_zero_points != nullptr) { + throw std::runtime_error( + "Weight zero-points must be null when running per column scaling"); } + } - auto init_status = gemm.initialize(args, workspace, stream); - if (init_status != cutlass::Status::kSuccess) { - std::string err_msg = "Failed to initialize cutlass fp8_int4 gemm. Error: " + - std::string(cutlassGetStatusString(init_status)); - throw std::runtime_error("[fp8_int4 Runner] " + err_msg); - } + int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; + ElementAccumulator output_op_beta = + (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f); + typename Gemm::Arguments args( + {m, n, k}, + group_size, + {reinterpret_cast(const_cast(A)), + k}, + {reinterpret_cast(const_cast(B)), ldb}, + {reinterpret_cast( + const_cast(weight_scales)), + ld_scale_zero}, + {reinterpret_cast( + const_cast(weight_zero_points)), + ld_scale_zero}, + {reinterpret_cast(const_cast(biases)), 0}, + {reinterpret_cast(C), n}, + gemm_config.split_k_factor, + {ElementAccumulator(alpha), output_op_beta}); - auto run_status = gemm.run(stream); - if (run_status != cutlass::Status::kSuccess) { - std::string err_msg = "Failed to run cutlass fp8_int4 gemm. Error: " + - std::string(cutlassGetStatusString(run_status)); - throw std::runtime_error("[fp8_int4 Runner] " + err_msg); - } + // This assertion is enabled because because for the column interleaved + // layout, K MUST be a multiple of threadblockK. The reason for this is that + // the default pitchlinear iterators are used to handle walking over the + // interleaved matrix. The way masking in handled in these do not map to the + // interleaved layout. We need to write our own predicated iterator in order + // to relax this limitation. + if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || + ((k / gemm_config.split_k_factor) % + MixedGemmArchTraits::ThreadblockK))) { + throw std::runtime_error("Assertion: k[" + std::to_string(k) + + "] must be multiple of threadblockK[" + + std::to_string(MixedGemmArchTraits::ThreadblockK) + + "]"); + } + + Gemm gemm; + + if (gemm.get_workspace_size(args) > workspace_bytes) { + std::cerr + << "Requested split-k but workspace size insufficient. Falling back to " + "non-split-k implementation." + << std::endl; + // If requested split-k factor will require more workspace bytes, revert to + // standard gemm. + args.batch_count = 1; + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "fp8_int4 cutlass kernel will fail for params. Error: " + + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to initialize cutlass fp8_int4 gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to run cutlass fp8_int4 gemm. Error: " + + std::string(cutlassGetStatusString(run_status)); + throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + } } -template < - typename ActivationType, - typename WeightType, - typename ScaleZeroType, - typename BiasType, - typename OutputType, - typename arch, - cutlass::WeightOnlyQuantOp QuantOp, - typename EpilogueTag, - typename ThreadblockShape, - typename WarpShape> -void dispatch_gemm_config( - ActivationType const* A, - WeightType const* B, - ScaleZeroType const* weight_scales, - ScaleZeroType const* weight_zero_points, - BiasType const* biases, - float const alpha, - OutputType* C, - int m, - int n, - int k, - int const group_size, - cutlass_extensions::CutlassGemmConfig gemm_config, - void* workspace, - size_t workspace_bytes, - cudaStream_t stream, - int* occupancy = nullptr) { - switch (gemm_config.stages) { +template +void dispatch_gemm_config(ActivationType const* A, + WeightType const* B, + ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, + BiasType const* biases, + float const alpha, + OutputType* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemm_config, + void* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.stages) { case 2: - throw std::runtime_error( - "[filter_and_run_mixed_gemm] Cutlass fp8_int4 gemm not supported for arch " + - std::to_string(arch::kMinComputeCapability) + " with stages set to 2"); - break; + throw std::runtime_error( + "[filter_and_run_mixed_gemm] Cutlass fp8_int4 gemm not supported for " + "arch " + + std::to_string(arch::kMinComputeCapability) + + " with stages set to 2"); + break; case 3: - generic_mixed_gemm_kernelLauncher< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - ThreadblockShape, - WarpShape, - 3>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; + generic_mixed_gemm_kernelLauncher(A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; case 4: - generic_mixed_gemm_kernelLauncher< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - ThreadblockShape, - WarpShape, - 4>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; + generic_mixed_gemm_kernelLauncher(A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; default: - std::string err_msg = "dispatch_gemm_config does not support stages " + - std::to_string(gemm_config.stages); - throw std::runtime_error("[dispatch_gemm_config] " + err_msg); - break; - } + std::string err_msg = "dispatch_gemm_config does not support stages " + + std::to_string(gemm_config.stages); + throw std::runtime_error("[dispatch_gemm_config] " + err_msg); + break; + } } -template < - typename ActivationType, - typename WeightType, - typename ScaleZeroType, - typename BiasType, - typename OutputType, - typename arch, - cutlass::WeightOnlyQuantOp QuantOp, - typename EpilogueTag> -void dispatch_gemm_to_cutlass( +template +void dispatch_gemm_to_cutlass(ActivationType const* A, + WeightType const* B, + ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, + BiasType const* biases, + float const alpha, + OutputType* C, + int m, + int n, + int k, + int const group_size, + void* workspace, + size_t workspace_bytes, + cutlass_extensions::CutlassGemmConfig gemm_config, + cudaStream_t stream, + int* occupancy = nullptr) { + // Note that SIMT configs are omitted here since they are not supported for + // fp8_int4. We also only instantiate configs here where threadblockShapeM == + // warpShapeM since those usually perform the best for mixed type gemms. + constexpr int tile_shape_k = + 128 * 8 / cutlass::sizeof_bits::value; + switch (gemm_config.tile_config) { + case cutlass_extensions::CutlassTileConfig:: + CtaShape16x128x64_WarpShape16x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 32, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig:: + CtaShape16x256x64_WarpShape16x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 64, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig:: + CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig:: + CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig:: + CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: + throw std::runtime_error( + "[fp8_int4][dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[fp8_int4][dispatch_gemm_to_cutlass] gemm config should have " + "already been set by " + "heuristic."); + break; + default: + printf("gemm_config.tile_config: %d", int(gemm_config.tile_config)); + throw std::runtime_error( + "[fp8_int4][dispatch_gemm_to_cutlass] Config is invalid for mixed " + "type GEMM."); + break; + } +} + +template +CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + // printf(__PRETTY_FUNCTION__); + int device{-1}; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); + sm_ = common::getSMVersion(); + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { + // printf(__PRETTY_FUNCTION__); +} + +template +template +void CutlassFpAIntBGemmRunner:: + dispatch_to_arch( ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, @@ -359,424 +607,204 @@ void dispatch_gemm_to_cutlass( int n, int k, int const group_size, - void* workspace, - size_t workspace_bytes, cutlass_extensions::CutlassGemmConfig gemm_config, + void* workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream, - int* occupancy = nullptr) { - // Note that SIMT configs are omitted here since they are not supported for fp8_int4. - // We also only instantiate configs here where threadblockShapeM == warpShapeM since those - // usually perform the best for mixed type gemms. - constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits::value; - switch (gemm_config.tile_config) { - case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<16, 128, tile_shape_k>, - cutlass::gemm::GemmShape<16, 32, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<16, 256, tile_shape_k>, - cutlass::gemm::GemmShape<16, 64, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<32, 128, tile_shape_k>, - cutlass::gemm::GemmShape<32, 32, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<64, 128, tile_shape_k>, - cutlass::gemm::GemmShape<64, 32, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<128, 128, tile_shape_k>, - cutlass::gemm::GemmShape<128, 32, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: - throw std::runtime_error("[fp8_int4][dispatch_gemm_to_cutlass] gemm config undefined."); - break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[fp8_int4][dispatch_gemm_to_cutlass] gemm config should have already been set by " - "heuristic."); - break; - default: - printf("gemm_config.tile_config: %d", int(gemm_config.tile_config)); - throw std::runtime_error( - "[fp8_int4][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); - break; - } + int* occupancy) { + dispatch_gemm_to_cutlass(A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); } -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> -CutlassFpAIntBGemmRunner:: - CutlassFpAIntBGemmRunner() { - // printf(__PRETTY_FUNCTION__); - int device{-1}; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); - sm_ = common::getSMVersion(); - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( - &multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); -} - -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> -CutlassFpAIntBGemmRunner:: - ~CutlassFpAIntBGemmRunner() { - // printf(__PRETTY_FUNCTION__); -} - -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> -template +template void CutlassFpAIntBGemmRunner< - ActivationType, - WeightType, - QuantOp, - ScaleZeroType, - BiasType, - OutputType>:: - dispatch_to_arch( - ActivationType const* A, - WeightType const* B, - ScaleZeroType const* weight_scales, - ScaleZeroType const* weight_zero_points, - BiasType const* biases, - float const alpha, - OutputType* C, - int m, - int n, - int k, - int const group_size, - cutlass_extensions::CutlassGemmConfig gemm_config, - void* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream, - int* occupancy) { - dispatch_gemm_to_cutlass< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - cutlass::arch::Sm89, - QuantOp, - EpilogueTag>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); -} - -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> -void CutlassFpAIntBGemmRunner< - ActivationType, - WeightType, - QuantOp, - ScaleZeroType, - BiasType, - OutputType>:: - gemm(void const* A, - void const* B, - void const* weight_scales, - void const* weight_zero_points, - void const* biases, - float const alpha, - void* C, - int m, - int n, - int k, - int const group_size, - cutlass_extensions::CutlassGemmConfig gemmConfig, - void* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) { - // printf(__PRETTY_FUNCTION__); - if (gemmConfig.tile_config == cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { - std::vector configs = getConfigs(k); - std::vector occupancies(configs.size()); - for (size_t i = 0; i < configs.size(); ++i) { - dispatch_to_arch( - (ActivationType const*)A, - (WeightType const*)B, - (ScaleZeroType const*)weight_scales, - (ScaleZeroType const*)weight_zero_points, - (BiasType const*)biases, - alpha, - (OutputType*)C, - m, - n, - k, - group_size, - configs[i], - workspace_ptr, - workspace_bytes, - stream, - &occupancies[i]); - } - auto best_config = estimate_best_config_from_occupancies( - configs, - occupancies, - m, - n, - k, - 1, - SPLIT_K_LIMIT, - workspace_bytes, - multi_processor_count_, - true); - dispatch_to_arch( - (ActivationType const*)A, - (WeightType const*)B, - (ScaleZeroType const*)weight_scales, - (ScaleZeroType const*)weight_zero_points, - (BiasType const*)biases, - alpha, - (OutputType*)C, - m, - n, - k, - group_size, - best_config, - workspace_ptr, - workspace_bytes, - stream, - nullptr); - } else { - dispatch_to_arch( - (ActivationType const*)A, - (WeightType const*)B, - (ScaleZeroType const*)weight_scales, - (ScaleZeroType const*)weight_zero_points, - (BiasType const*)biases, - alpha, - (OutputType*)C, - m, - n, - k, - group_size, - gemmConfig, - workspace_ptr, - workspace_bytes, - stream, - nullptr); + ActivationType, + WeightType, + QuantOp, + ScaleZeroType, + BiasType, + OutputType>::gemm(void const* A, + void const* B, + void const* weight_scales, + void const* weight_zero_points, + void const* biases, + float const alpha, + void* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemmConfig, + void* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + // printf(__PRETTY_FUNCTION__); + if (gemmConfig.tile_config == + cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { + std::vector configs = getConfigs(k); + std::vector occupancies(configs.size()); + for (size_t i = 0; i < configs.size(); ++i) { + dispatch_to_arch( + (ActivationType const*)A, + (WeightType const*)B, + (ScaleZeroType const*)weight_scales, + (ScaleZeroType const*)weight_zero_points, + (BiasType const*)biases, + alpha, + (OutputType*)C, + m, + n, + k, + group_size, + configs[i], + workspace_ptr, + workspace_bytes, + stream, + &occupancies[i]); } + auto best_config = + estimate_best_config_from_occupancies(configs, + occupancies, + m, + n, + k, + 1, + SPLIT_K_LIMIT, + workspace_bytes, + multi_processor_count_, + true); + dispatch_to_arch( + (ActivationType const*)A, + (WeightType const*)B, + (ScaleZeroType const*)weight_scales, + (ScaleZeroType const*)weight_zero_points, + (BiasType const*)biases, + alpha, + (OutputType*)C, + m, + n, + k, + group_size, + best_config, + workspace_ptr, + workspace_bytes, + stream, + nullptr); + } else { + dispatch_to_arch( + (ActivationType const*)A, + (WeightType const*)B, + (ScaleZeroType const*)weight_scales, + (ScaleZeroType const*)weight_zero_points, + (BiasType const*)biases, + alpha, + (OutputType*)C, + m, + n, + k, + group_size, + gemmConfig, + workspace_ptr, + workspace_bytes, + stream, + nullptr); + } } -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> +template std::vector -CutlassFpAIntBGemmRunner:: - getConfigs(int k) const { - // printf(__PRETTY_FUNCTION__); - cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = - cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER; - config_type_param = - static_cast( - config_type_param | - cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam::WEIGHT_ONLY); - std::vector candidateConfigs = - get_candidate_configs(sm_, SPLIT_K_LIMIT, config_type_param); +CutlassFpAIntBGemmRunner::getConfigs(int k) const { + // printf(__PRETTY_FUNCTION__); + cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam + config_type_param = cutlass_extensions::CutlassGemmConfig:: + CandidateConfigTypeParam::HOPPER; + config_type_param = static_cast< + cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam>( + config_type_param | cutlass_extensions::CutlassGemmConfig:: + CandidateConfigTypeParam::WEIGHT_ONLY); + std::vector candidateConfigs = + get_candidate_configs(sm_, SPLIT_K_LIMIT, config_type_param); - // filter configs that are not supported on sm89 - std::vector rets; - for (auto config : candidateConfigs) { - // sm89 doesn't support stages 2 - if (config.stages == 2) { - continue; - } - - if (config.stages >= 5) { - continue; - } - if (config.split_k_style != cutlass_extensions::SplitKStyle::NO_SPLIT_K) { - int k_size = (k + config.split_k_factor - 1) / config.split_k_factor; - if (k_size % 128) { - continue; - } - } - rets.push_back(config); + // filter configs that are not supported on sm89 + std::vector rets; + for (auto config : candidateConfigs) { + // sm89 doesn't support stages 2 + if (config.stages == 2) { + continue; } - return rets; + + if (config.stages >= 5) { + continue; + } + if (config.split_k_style != cutlass_extensions::SplitKStyle::NO_SPLIT_K) { + int k_size = (k + config.split_k_factor - 1) / config.split_k_factor; + if (k_size % 128) { + continue; + } + } + rets.push_back(config); + } + return rets; } -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> -size_t -CutlassFpAIntBGemmRunner:: - getWorkspaceSize(int const m, int const n, int const k) { - // printf(__PRETTY_FUNCTION__); - // These are the min tile sizes for each config, which would launch the maximum number of blocks - int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); - int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); - // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. - return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); +template +size_t CutlassFpAIntBGemmRunner::getWorkspaceSize(int const m, + int const n, + int const k) { + // printf(__PRETTY_FUNCTION__); + // These are the min tile sizes for each config, which would launch the + // maximum number of blocks + int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); + int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); + // We need 4 bytes per block in the worst case. We launch split_k_limit in z + // dim. + return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); } } // namespace cutlass_kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu index d8496073fa..765321f61f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu @@ -23,7 +23,8 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 template class MoeGemmRunner< - __nv_bfloat16, cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kNone>>; + __nv_bfloat16, + cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kNone>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu index 92d63948c1..991caf6b4f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu @@ -24,7 +24,8 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 template class MoeGemmRunner< __nv_bfloat16, - cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt2>>; + cutlass::WintQuantTraits<__nv_bfloat16, + cutlass::WintQuantMethod::kWeightOnlyInt2>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu index b82fbc107c..c3512aa45a 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu @@ -23,7 +23,8 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 template class MoeGemmRunner< __nv_bfloat16, - cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt4>>; + cutlass::WintQuantTraits<__nv_bfloat16, + cutlass::WintQuantMethod::kWeightOnlyInt4>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu index 97fdd104ba..f7788ca961 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu @@ -24,7 +24,8 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 template class MoeGemmRunner< __nv_bfloat16, - cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt8>>; + cutlass::WintQuantTraits<__nv_bfloat16, + cutlass::WintQuantMethod::kWeightOnlyInt8>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu index a3d34b8e72..40608b1b98 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu @@ -21,7 +21,8 @@ namespace phi { -template class MoeGemmRunner>; +template class MoeGemmRunner< + half, + cutlass::WintQuantTraits>; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu index 5d84c9cfc1..8d0519beca 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu @@ -22,6 +22,7 @@ namespace phi { template class MoeGemmRunner< - half, cutlass::WintQuantTraits>; + half, + cutlass::WintQuantTraits>; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu index 51707ebbb8..ffbfd11b67 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu @@ -22,6 +22,7 @@ namespace phi { template class MoeGemmRunner< - half, cutlass::WintQuantTraits>; + half, + cutlass::WintQuantTraits>; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu index c796f9bbe5..adf9b91814 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu @@ -22,6 +22,7 @@ namespace phi { template class MoeGemmRunner< - half, cutlass::WintQuantTraits>; + half, + cutlass::WintQuantTraits>; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/base64_encode.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/base64_encode.h index cefcd666dc..72ab5c4940 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/base64_encode.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/base64_encode.h @@ -16,11 +16,12 @@ #include // Base64 编码表 -const std::string base64_chars = "Tokp9lA/BjimRVKx32edMPFftOzsbNQ8C15Xn+YUEGc4WD0uLIq7hyJ6vZaHSwrg"; +const std::string base64_chars = + "Tokp9lA/BjimRVKx32edMPFftOzsbNQ8C15Xn+YUEGc4WD0uLIq7hyJ6vZaHSwrg"; // 判断字符是否为有效的 Base64 字符 inline bool is_base64(unsigned char c) { - return (isalnum(c) || (c == '+') || (c == '/')); + return (isalnum(c) || (c == '+') || (c == '/')); } // Base64 编码函数 @@ -29,96 +30,104 @@ std::string base64_encode(const std::string &input); // Base64 解码函数 std::string base64_decode(const std::string &encoded_string); - // Base64 编码函数 std::string base64_encode(const std::string &input) { - std::string ret; - int i = 0; - int j = 0; - unsigned char char_array_3[3]; - unsigned char char_array_4[4]; + std::string ret; + int i = 0; + int j = 0; + unsigned char char_array_3[3]; + unsigned char char_array_4[4]; - for (const auto &c : input) { - char_array_3[i++] = c; - if (i == 3) { - char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - char_array_4[3] = char_array_3[2] & 0x3f; + for (const auto &c : input) { + char_array_3[i++] = c; + if (i == 3) { + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = + ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = + ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; - for (i = 0; i < 4; i++) { - ret += base64_chars[char_array_4[i]]; - } - i = 0; - } + for (i = 0; i < 4; i++) { + ret += base64_chars[char_array_4[i]]; + } + i = 0; + } + } + + if (i) { + for (j = i; j < 3; j++) { + char_array_3[j] = '\0'; } - if (i) { - for (j = i; j < 3; j++) { - char_array_3[j] = '\0'; - } + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = + ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = + ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; - char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - char_array_4[3] = char_array_3[2] & 0x3f; - - for (j = 0; j < i + 1; j++) { - ret += base64_chars[char_array_4[j]]; - } - - while (i++ < 3) { - ret += '='; - } + for (j = 0; j < i + 1; j++) { + ret += base64_chars[char_array_4[j]]; } - return ret; + while (i++ < 3) { + ret += '='; + } + } + + return ret; } // Base64 解码函数 std::string base64_decode(const std::string &encoded_string) { - int in_len = encoded_string.size(); - int i = 0; - int j = 0; - int in_ = 0; - unsigned char char_array_4[4], char_array_3[3]; - std::string ret; + int in_len = encoded_string.size(); + int i = 0; + int j = 0; + int in_ = 0; + unsigned char char_array_4[4], char_array_3[3]; + std::string ret; - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) { - for (i = 0; i < 4; i++) { - char_array_4[i] = base64_chars.find(char_array_4[i]); - } + while (in_len-- && (encoded_string[in_] != '=') && + is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; + in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[0] = + (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = + ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (i = 0; i < 3; i++) { - ret += char_array_3[i]; - } - i = 0; - } + for (i = 0; i < 3; i++) { + ret += char_array_3[i]; + } + i = 0; + } + } + + if (i) { + for (j = i; j < 4; j++) { + char_array_4[j] = 0; } - if (i) { - for (j = i; j < 4; j++) { - char_array_4[j] = 0; - } - - for (j = 0; j < 4; j++) { - char_array_4[j] = base64_chars.find(char_array_4[j]); - } - - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (j = 0; j < i - 1; j++) { - ret += char_array_3[j]; - } + for (j = 0; j < 4; j++) { + char_array_4[j] = base64_chars.find(char_array_4[j]); } - return ret; + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = + ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; j < i - 1; j++) { + ret += char_array_3[j]; + } + } + + return ret; } diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cuda_utils.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cuda_utils.h index 7f45b8fd61..0927d32738 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cuda_utils.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cuda_utils.h @@ -32,338 +32,324 @@ // workspace for cublas gemm : 32MB #define CUBLAS_WORKSPACE_SIZE 33554432 -typedef struct __align__(4) -{ - half x, y, z, w; +typedef struct __align__(4) { + half x, y, z, w; } half4; /* **************************** type definition ***************************** */ enum CublasDataType { - FLOAT_DATATYPE = 0, - HALF_DATATYPE = 1, - BFLOAT16_DATATYPE = 2, - INT8_DATATYPE = 3, - FP8_DATATYPE = 4 + FLOAT_DATATYPE = 0, + HALF_DATATYPE = 1, + BFLOAT16_DATATYPE = 2, + INT8_DATATYPE = 3, + FP8_DATATYPE = 4 }; -enum FtCudaDataType { - FP32 = 0, - FP16 = 1, - BF16 = 2, - INT8 = 3, - FP8 = 4 -}; +enum FtCudaDataType { FP32 = 0, FP16 = 1, BF16 = 2, INT8 = 3, FP8 = 4 }; -enum class OperationType { - FP32, - FP16, - BF16, - INT8, - FP8 -}; +enum class OperationType { FP32, FP16, BF16, INT8, FP8 }; /* **************************** debug tools ********************************* */ -static const char* _cudaGetErrorEnum(cudaError_t error) -{ - return cudaGetErrorString(error); +static const char* _cudaGetErrorEnum(cudaError_t error) { + return cudaGetErrorString(error); } -static const char* _cudaGetErrorEnum(cublasStatus_t error) -{ - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; +static const char* _cudaGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return ""; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return ""; } -template -void check(T result, char const* const func, const char* const file, int const line) -{ - if (result) { - throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " " - + file + ":" + std::to_string(line) + " \n"); - } +template +void check(T result, + char const* const func, + const char* const file, + int const line) { + if (result) { + throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } } #define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) #define check_cuda_error_2(val, file, line) check((val), #val, file, line) -inline void syncAndCheck(const char* const file, int const line) -{ - // When FT_DEBUG_LEVEL=DEBUG, must check error - static char* level_name = std::getenv("FT_DEBUG_LEVEL"); - if (level_name != nullptr) { - static std::string level = std::string(level_name); - if (level == "DEBUG") { - cudaDeviceSynchronize(); - cudaError_t result = cudaGetLastError(); - if (result) { - throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) - + " " + file + ":" + std::to_string(line) + " \n"); - } - std::cout<<"run syncAndCheck at "< -void print_to_file(const T* result, - const int size, - const char* file, - cudaStream_t stream = 0, +template +void print_to_file(const T* result, + const int size, + const char* file, + cudaStream_t stream = 0, std::ios::openmode open_mode = std::ios::out); -template -void print_abs_mean(const T* buf, uint size, cudaStream_t stream, std::string name = ""); +template +void print_abs_mean(const T* buf, + uint size, + cudaStream_t stream, + std::string name = ""); -template +template void print_to_screen(const T* result, const int size); -template +template void printMatrix(T* ptr, int m, int k, int stride, bool is_device_ptr); -void printMatrix(unsigned long long* ptr, int m, int k, int stride, bool is_device_ptr); +void printMatrix( + unsigned long long* ptr, int m, int k, int stride, bool is_device_ptr); void printMatrix(int* ptr, int m, int k, int stride, bool is_device_ptr); void printMatrix(size_t* ptr, int m, int k, int stride, bool is_device_ptr); -template +template void check_max_val(const T* result, const int size); -template +template void check_abs_mean_val(const T* result, const int size); -#define PRINT_FUNC_NAME_() \ - do { \ - std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \ - } while (0) +#define PRINT_FUNC_NAME_() \ + do { \ + std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \ + } while (0) -[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "") -{ - throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":" - + std::to_string(line) + " \n"); +[[noreturn]] inline void throwRuntimeError(const char* const file, + int const line, + std::string const& info = "") { + throw std::runtime_error(std::string("[FT][ERROR] ") + info + + " Assertion fail: " + file + ":" + + std::to_string(line) + " \n"); } -inline void myAssert(bool result, const char* const file, int const line, std::string const& info = "") -{ - if (!result) { - throwRuntimeError(file, line, info); - } +inline void myAssert(bool result, + const char* const file, + int const line, + std::string const& info = "") { + if (!result) { + throwRuntimeError(file, line, info); + } } #define FT_CHECK(val) myAssert(val, __FILE__, __LINE__) -#define FT_CHECK_WITH_INFO(val, info) \ - do { \ - bool is_valid_val = (val); \ - if (!is_valid_val) { \ - paddle::operators::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ - } \ - } while (0) +#define FT_CHECK_WITH_INFO(val, info) \ + do { \ + bool is_valid_val = (val); \ + if (!is_valid_val) { \ + paddle::operators::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ + } \ + } while (0) #define FT_THROW(info) throwRuntimeError(__FILE__, __LINE__, info) #ifdef SPARSITY_ENABLED -#define CHECK_CUSPARSE(func) \ - { \ - cusparseStatus_t status = (func); \ - if (status != CUSPARSE_STATUS_SUCCESS) { \ - throw std::runtime_error(std::string("[FT][ERROR] CUSPARSE API failed at line ") \ - + std::to_string(__LINE__) + " in file " + __FILE__ + ": " \ - + cusparseGetErrorString(status) + " " + std::to_string(status)); \ - } \ - } +#define CHECK_CUSPARSE(func) \ + { \ + cusparseStatus_t status = (func); \ + if (status != CUSPARSE_STATUS_SUCCESS) { \ + throw std::runtime_error( \ + std::string("[FT][ERROR] CUSPARSE API failed at line ") + \ + std::to_string(__LINE__) + " in file " + __FILE__ + ": " + \ + cusparseGetErrorString(status) + " " + std::to_string(status)); \ + } \ + } #endif /*************Time Handling**************/ class CudaTimer { -private: - cudaEvent_t event_start_; - cudaEvent_t event_stop_; - cudaStream_t stream_; + private: + cudaEvent_t event_start_; + cudaEvent_t event_stop_; + cudaStream_t stream_; -public: - explicit CudaTimer(cudaStream_t stream = 0) - { - stream_ = stream; - } - void start() - { - check_cuda_error(cudaEventCreate(&event_start_)); - check_cuda_error(cudaEventCreate(&event_stop_)); - check_cuda_error(cudaEventRecord(event_start_, stream_)); - } - float stop() - { - float time; - check_cuda_error(cudaEventRecord(event_stop_, stream_)); - check_cuda_error(cudaEventSynchronize(event_stop_)); - check_cuda_error(cudaEventElapsedTime(&time, event_start_, event_stop_)); - check_cuda_error(cudaEventDestroy(event_start_)); - check_cuda_error(cudaEventDestroy(event_stop_)); - return time; - } - ~CudaTimer() {} + public: + explicit CudaTimer(cudaStream_t stream = 0) { stream_ = stream; } + void start() { + check_cuda_error(cudaEventCreate(&event_start_)); + check_cuda_error(cudaEventCreate(&event_stop_)); + check_cuda_error(cudaEventRecord(event_start_, stream_)); + } + float stop() { + float time; + check_cuda_error(cudaEventRecord(event_stop_, stream_)); + check_cuda_error(cudaEventSynchronize(event_stop_)); + check_cuda_error(cudaEventElapsedTime(&time, event_start_, event_stop_)); + check_cuda_error(cudaEventDestroy(event_start_)); + check_cuda_error(cudaEventDestroy(event_stop_)); + return time; + } + ~CudaTimer() {} }; -static double diffTime(timeval start, timeval end) -{ - return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001; +static double diffTime(timeval start, timeval end) { + return (end.tv_sec - start.tv_sec) * 1000 + + (end.tv_usec - start.tv_usec) * 0.001; } /* ***************************** common utils ****************************** */ -inline void print_mem_usage(std::string time = "after allocation") -{ - size_t free_bytes, total_bytes; - check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); - float free = static_cast(free_bytes) / 1024.0 / 1024.0 / 1024.0; - float total = static_cast(total_bytes) / 1024.0 / 1024.0 / 1024.0; - float used = total - free; - printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n", time.c_str(), free, total, used); +inline void print_mem_usage(std::string time = "after allocation") { + size_t free_bytes, total_bytes; + check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); + float free = static_cast(free_bytes) / 1024.0 / 1024.0 / 1024.0; + float total = static_cast(total_bytes) / 1024.0 / 1024.0 / 1024.0; + float used = total - free; + printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n", + time.c_str(), + free, + total, + used); } -inline int getSMVersion() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - int sm_major = 0; - int sm_minor = 0; - check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); - check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); - return sm_major * 10 + sm_minor; +inline int getSMVersion() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + check_cuda_error(cudaDeviceGetAttribute( + &sm_major, cudaDevAttrComputeCapabilityMajor, device)); + check_cuda_error(cudaDeviceGetAttribute( + &sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; } -inline int getMaxSharedMemoryPerBlock() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - int max_shared_memory_size = 0; - check_cuda_error(cudaDeviceGetAttribute(&max_shared_memory_size, cudaDevAttrMaxSharedMemoryPerBlock, device)); - return max_shared_memory_size; +inline int getMaxSharedMemoryPerBlock() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int max_shared_memory_size = 0; + check_cuda_error(cudaDeviceGetAttribute( + &max_shared_memory_size, cudaDevAttrMaxSharedMemoryPerBlock, device)); + return max_shared_memory_size; } -inline std::string getDeviceName() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - cudaDeviceProp props; - check_cuda_error(cudaGetDeviceProperties(&props, device)); - return std::string(props.name); +inline std::string getDeviceName() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + cudaDeviceProp props; + check_cuda_error(cudaGetDeviceProperties(&props, device)); + return std::string(props.name); } -inline int div_up(int a, int n) -{ - return (a + n - 1) / n; -} +inline int div_up(int a, int n) { return (a + n - 1) / n; } cudaError_t getSetDevice(int i_device, int* o_device = NULL); -inline int getDevice() -{ - int current_dev_id = 0; - check_cuda_error(cudaGetDevice(¤t_dev_id)); - return current_dev_id; +inline int getDevice() { + int current_dev_id = 0; + check_cuda_error(cudaGetDevice(¤t_dev_id)); + return current_dev_id; } -inline int getDeviceCount() -{ - int count = 0; - check_cuda_error(cudaGetDeviceCount(&count)); - return count; +inline int getDeviceCount() { + int count = 0; + check_cuda_error(cudaGetDeviceCount(&count)); + return count; } -template -CublasDataType getCublasDataType() -{ - if (std::is_same::value) { - return HALF_DATATYPE; - } - else if (std::is_same::value) { - return FLOAT_DATATYPE; - } - else { - FT_CHECK(false); - return FLOAT_DATATYPE; - } +template +CublasDataType getCublasDataType() { + if (std::is_same::value) { + return HALF_DATATYPE; + } else if (std::is_same::value) { + return FLOAT_DATATYPE; + } else { + FT_CHECK(false); + return FLOAT_DATATYPE; + } } -template -cudaDataType_t getCudaDataType() -{ - if (std::is_same::value) { - return CUDA_R_16F; - } +template +cudaDataType_t getCudaDataType() { + if (std::is_same::value) { + return CUDA_R_16F; + } - else if (std::is_same::value) { - return CUDA_R_32F; - } - else { - FT_CHECK(false); - return CUDA_R_32F; - } + else if (std::is_same::value) { + return CUDA_R_32F; + } else { + FT_CHECK(false); + return CUDA_R_32F; + } } -template +template struct getTypeFromCudaDataType { - using Type = float; + using Type = float; }; -template<> +template <> struct getTypeFromCudaDataType { - using Type = half; + using Type = half; }; - // clang-format off template struct packed_type; template <> struct packed_type { using type = float; }; // we don't need to pack float by default @@ -390,59 +376,75 @@ inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } // clang-format on -template -void compareTwoTensor( - const T1* pred, const T2* ref, const int size, const int print_size = 0, const std::string filename = "") -{ - T1* h_pred = new T1[size]; - T2* h_ref = new T2[size]; - check_cuda_error(cudaMemcpy(h_pred, pred, size * sizeof(T1), cudaMemcpyDeviceToHost)); - check_cuda_error(cudaMemcpy(h_ref, ref, size * sizeof(T2), cudaMemcpyDeviceToHost)); +template +void compareTwoTensor(const T1* pred, + const T2* ref, + const int size, + const int print_size = 0, + const std::string filename = "") { + T1* h_pred = new T1[size]; + T2* h_ref = new T2[size]; + check_cuda_error( + cudaMemcpy(h_pred, pred, size * sizeof(T1), cudaMemcpyDeviceToHost)); + check_cuda_error( + cudaMemcpy(h_ref, ref, size * sizeof(T2), cudaMemcpyDeviceToHost)); - FILE* fd = nullptr; - if (filename != "") { - fd = fopen(filename.c_str(), "w"); - fprintf(fd, "| %10s | %10s | %10s | %10s | \n", "pred", "ref", "abs_diff", "rel_diff(%)"); - } + FILE* fd = nullptr; + if (filename != "") { + fd = fopen(filename.c_str(), "w"); + fprintf(fd, + "| %10s | %10s | %10s | %10s | \n", + "pred", + "ref", + "abs_diff", + "rel_diff(%)"); + } - if (print_size > 0) { - std::cout<<" id | pred | ref |abs diff | rel diff (%) |"< 0) { + std::cout << " id | pred | ref |abs diff | rel diff (%) |" + << std::endl; + } + float mean_abs_diff = 0.0f; + float mean_rel_diff = 0.0f; + int count = 0; + for (int i = 0; i < size; i++) { + if (i < print_size) { + std::cout << i << " | " << (float)h_pred[i] << " | " << (float)h_ref[i] + << " | " << (abs((float)h_pred[i] - (float)h_ref[i])) << " | " + << (abs((float)h_pred[i] - (float)h_ref[i]) / + (abs((float)h_ref[i]) + 1e-6f) * 100.f) + << " | " << std::endl; } - float mean_abs_diff = 0.0f; - float mean_rel_diff = 0.0f; - int count = 0; - for (int i = 0; i < size; i++) { - if (i < print_size) { - std::cout< -struct TagOperator -{ - using TaggedOperator = MmaOp; +struct TagOperator { + using TaggedOperator = MmaOp; }; // Specializations below attach more information to the operator template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; }; template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale; +struct TagOperator { + using TaggedOperator = + OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale; }; - -// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original -// operator + the extra information. If no extra info was tagged, the dequant op per column scaling -// as a default. +// Here we instantiate some structs to "detag" the tagged operator. It splits it +// back to the original operator + the extra information. If no extra info was +// tagged, the dequant op per column scaling as a default. template -struct DetagOperator -{ - using Operator = TaggedMmaOp; - static constexpr bool FineGrained = false; +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr bool FineGrained = false; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr bool FineGrained = false; +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr bool FineGrained = false; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr bool FineGrained = true; +struct DetagOperator< + OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale> { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr bool FineGrained = true; }; - -} // namespace arch -} // namespace cutlass +} // namespace arch +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/arch/mma_sm80.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/arch/mma_sm80.h index 4f613a28cd..58d294745a 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/arch/mma_sm80.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/arch/mma_sm80.h @@ -19,18 +19,16 @@ namespace cutlass { namespace arch { template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,16>; +struct Mma, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + using Shape = gemm::GemmShape<16, 8, 16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; @@ -50,13 +48,10 @@ struct Mma< /// Computes multiply-add CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - + void operator()(FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const &B = reinterpret_cast(b); @@ -65,10 +60,16 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, " + "{%4,%5}, {%6}, " "{%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + : "r"(A[0]), + "r"(A[1]), + "r"(B), + "r"(C[0]), + "r"(C[1]), + "r"(C[2]), "r"(C[3])); #else @@ -82,18 +83,16 @@ struct Mma< }; template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,32>; +struct Mma, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + using Shape = gemm::GemmShape<16, 8, 32>; using ElementA = int8_t; using LayoutA = layout::RowMajor; @@ -112,13 +111,10 @@ struct Mma< /// Computes multiply-add CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - + void operator()(FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); @@ -128,11 +124,20 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " "{%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "r"(C[0]), + "r"(C[1]), + "r"(C[2]), + "r"(C[3])); #else assert(0); @@ -140,6 +145,5 @@ struct Mma< } }; - -} -} +} // namespace arch +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/epilogue_quant_helper.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/epilogue_quant_helper.h index 38a253d93d..ced7ae1c1a 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/epilogue_quant_helper.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/epilogue_quant_helper.h @@ -37,10 +37,10 @@ namespace epilogue { // define scaling mode enum class QuantMode { - PerTensorQuant, - PerTokenQuant, - PerChannelQuant, - PerTokenChannelQuant + PerTensorQuant, + PerTokenQuant, + PerChannelQuant, + PerTokenChannelQuant }; } // namespace epilogue diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale_nf4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale_nf4.h index 457416d75d..8f89790a3d 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale_nf4.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale_nf4.h @@ -52,8 +52,9 @@ namespace cutlass { namespace epilogue { namespace threadblock { template -[[gnu::warning("your type here")]] -bool print_type_1111() { return false; } +[[gnu::warning("your type here")]] bool print_type_1111() { + return false; +} template (&fragment_D_)[frag_idx]; - output = output_converter(result); + output = output_converter(result); /* // Convert to the output, with non zero C added */ // NumericArrayConverter // output_converter; @@ -341,7 +344,6 @@ class EpilogueVisitorPerRowPerColNf4 { // OutputVector& output = // reinterpret_cast(&fragment_D_)[frag_idx]; - // OutputVector& vector_c = // reinterpret_cast(&fragment_C_)[frag_idx]; @@ -353,9 +355,9 @@ class EpilogueVisitorPerRowPerColNf4 { /// Called after accumulators have been exchanged for each accumulator vector CUTLASS_DEVICE - void visit(AccumulatorFragment const &accum, + void visit(AccumulatorFragment const& accum, int reduce_fragment_idx, - OutputTileIterator &destination_iterator) { + OutputTileIterator& destination_iterator) { NumericArrayConverter @@ -368,7 +370,8 @@ class EpilogueVisitorPerRowPerColNf4 { ComputeFragment result = source_converter(accum); // if(threadIdx.x<32){ - // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, + // dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // accum[0], @@ -401,7 +404,8 @@ class EpilogueVisitorPerRowPerColNf4 { } // just for bug, pass // if(threadIdx.x<32){ - // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, + // dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // accum[0], @@ -428,7 +432,8 @@ class EpilogueVisitorPerRowPerColNf4 { // auto result_tmp = output_converter(result); // if(threadIdx.x<32){ - // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, + // dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // accum[0], @@ -452,8 +457,10 @@ class EpilogueVisitorPerRowPerColNf4 { typename OutputTileIterator::Fragment output_fragment; CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii(result[ii]); + for (int ii = 0; ii < output_fragment.size(); ++ii) { + output_fragment[ii] = + static_cast( + result[ii]); } // OutputVector& output = // reinterpret_cast(&output_fragment)[0]; @@ -464,7 +471,8 @@ class EpilogueVisitorPerRowPerColNf4 { // } // if(threadIdx.x<32){ - // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, + // dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // accum[0], @@ -545,7 +553,9 @@ class EpilogueVisitorPerRowPerColNf4 { ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) { // if(threadIdx.x<32){ - // printf("#### per_token_channel_scale_accumulator, %d-%d-%d--%d-%d-%d, quanted accu:%f-%f-%f-%f-%f-%f-%f-%f, scale_col:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### per_token_channel_scale_accumulator, %d-%d-%d--%d-%d-%d, + // quanted accu:%f-%f-%f-%f-%f-%f-%f-%f, scale_col:%f-%f-%f-%f-%f-%f-%f-%f + // \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // static_cast(accum[0]), diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h index 0a3c107f14..b2270332bd 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -90,28 +90,23 @@ namespace threadblock { namespace detail { -template < - typename ElementOutput, - typename ElementAccumulator, - int ElementsPerAccess, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename ThreadMap -> +template struct Nf4DefaultIteratorsTensorOp { + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOp; - using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< - WarpShape, - InstructionShape, - ElementAccumulator, - layout::RowMajor - >; - - using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< - ThreadMap, - ElementAccumulator - >; + using SharedLoadIterator = + cutlass::epilogue::threadblock::SharedLoadIterator; static int const kFragmentsPerIteration = 1; }; @@ -123,12 +118,12 @@ template struct Nf4DefaultIteratorsTensorOp { + int32_t, + 8, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp struct Nf4DefaultIteratorsTensorOp { + int32_t, + 8, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp { void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); } }; - - -template < - typename Shape_, - typename WarpMmaTensorOp_, - int PartitionsK, - typename OutputOp_, - int ElementsPerAccess, - bool ScatterD = false, - typename PermuteDLayout = layout::NoPermute -> +template struct DequantEpilogueTensorOp { - using Shape = Shape_; using WarpMmaTensorOp = WarpMmaTensorOp_; static int const kPartitionsK = PartitionsK; @@ -352,76 +342,76 @@ struct DequantEpilogueTensorOp { // Thread map // - using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< - Shape, - typename WarpMmaTensorOp::Shape, - kPartitionsK, - ElementOutput, - kElementsPerAccess - >::Type; + using OutputTileThreadMap = + typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess>::Type; - static bool const UseCUDAStore = platform::is_same::value; + static bool const UseCUDAStore = + platform::is_same::value; - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< - OutputTileThreadMap, - ElementOutput, - ScatterD, - PermuteDLayout, - UseCUDAStore - >; + using OutputTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIterator< + OutputTileThreadMap, + ElementOutput, + ScatterD, + PermuteDLayout, + UseCUDAStore>; - using AccumulatorFragmentIterator = typename platform::conditional::value, - cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< - typename WarpMmaTensorOp::Shape, - typename WarpMmaTensorOp::Policy::Operator::Shape, - typename WarpMmaTensorOp::Policy::Operator::ElementC, - typename WarpMmaTensorOp::Policy::Operator::FragmentC, - LayoutC>, - cutlass::epilogue::warp::FragmentIteratorTensorOp< - typename WarpMmaTensorOp::Shape, - typename WarpMmaTensorOp::Policy::Operator::Shape, - typename WarpMmaTensorOp::Policy::Operator::ElementC, - typename WarpMmaTensorOp::Policy::Operator::FragmentC, - LayoutC> >::type; + using AccumulatorFragmentIterator = typename platform::conditional< + is_complex::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>, + cutlass::epilogue::warp::FragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>>::type; /// Support several implementations depending on structure of epilogue using DefaultIterators = detail::Nf4DefaultIteratorsTensorOp< - ElementOutput, - ElementAccumulator, - kElementsPerAccess, - Shape, - typename WarpMmaTensorOp::Shape, - typename WarpMmaTensorOp::Policy::Operator::Shape, - typename OutputTileThreadMap::CompactedThreadMap - >; + ElementOutput, + ElementAccumulator, + kElementsPerAccess, + Shape, + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename OutputTileThreadMap::CompactedThreadMap>; using WarpTileIterator = typename DefaultIterators::WarpTileIterator; using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; /// Hard-coded padding elements added - using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; + using Padding = + cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; - static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); + static int const kFragmentsPerIteration = + (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); // // Define the epilogue // - using Epilogue = cutlass::epilogue::threadblock::Epilogue< - Shape, - WarpMmaTensorOp, - kPartitionsK, - OutputTileIterator, - AccumulatorFragmentIterator, - WarpTileIterator, - SharedLoadIterator, - OutputOp, - Padding, - kFragmentsPerIteration - >; + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; }; - - ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/ft_gemm_configs.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/ft_gemm_configs.h index 702ff05d40..fa77599f12 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/ft_gemm_configs.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/ft_gemm_configs.h @@ -35,39 +35,39 @@ limitations under the License. */ // in the kernel layout details when doing weight only quantization. enum class CutlassTileConfig { // Signals that we should run heuristics do choose a config - Undefined, // 0 + Undefined, // 0 // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, // 1 + ChooseWithHeuristic, // 1 // SiMT config - CtaShape128x128x8_WarpShape64x64x8, // 2 + CtaShape128x128x8_WarpShape64x64x8, // 2 // TensorCore configs CTA_N = 128, CTA_K = 64 // Warp configs for M=16 - CtaShape16x128x64_WarpShape16x32x64, // 3 - CtaShape16x256x64_WarpShape16x64x64, // 4 + CtaShape16x128x64_WarpShape16x32x64, // 3 + CtaShape16x256x64_WarpShape16x64x64, // 4 // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, // 5 + CtaShape32x128x64_WarpShape32x32x64, // 5 // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, // 6 - CtaShape64x128x64_WarpShape64x32x64, // 7 + CtaShape64x128x64_WarpShape32x64x64, // 6 + CtaShape64x128x64_WarpShape64x32x64, // 7 // Warp configs for M=128 - CtaShape128x128x64_WarpShape64x32x64, // 8 - CtaShape128x128x64_WarpShape128x32x64, // 9 + CtaShape128x128x64_WarpShape64x32x64, // 8 + CtaShape128x128x64_WarpShape128x32x64, // 9 // configs for large M in encoder - CtaShape128x256x64_WarpShape64x64x64, // 10 - CtaShape256x128x64_WarpShape64x64x64, // 11 + CtaShape128x256x64_WarpShape64x64x64, // 10 + CtaShape256x128x64_WarpShape64x64x64, // 11 }; enum class SplitKStyle { - NO_SPLIT_K, //0 - SPLIT_K_SERIAL, //1 - SPLIT_K_STREAM, //2 + NO_SPLIT_K, // 0 + SPLIT_K_SERIAL, // 1 + SPLIT_K_STREAM, // 2 // SPLIT_K_PARALLEL // Not supported yet }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_dequant_gemm_nf4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_dequant_gemm_nf4.h index c3ec38968c..12b78b8f03 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_dequant_gemm_nf4.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_dequant_gemm_nf4.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,25 +18,27 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief - Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with - the appropriate threadblock-scoped epilogue. + Default kernel-level GEMM definitions combine threadblock-scoped matrix + multiply-add with the appropriate threadblock-scoped epilogue. - Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are - accommodated by exchanging A and B operands and assuming transposed layouts. Partial - specializations here choose 'device::GemmTransposed' to implement this functionality. + Note, CUTLASS epilogues universally target row-major outputs. Column-major + outputs are accommodated by exchanging A and B operands and assuming + transposed layouts. Partial specializations here choose + 'device::GemmTransposed' to implement this functionality. */ #pragma once @@ -101,8 +103,7 @@ template < /// Permute result D typename PermuteDLayout = layout::NoPermute, /// - typename Enable = void -> + typename Enable = void> struct DefaultDequantGemm; /////////////////////////////////////////////////// @@ -152,40 +153,77 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout -> -struct DefaultDequantGemm { - - static_assert(platform::is_same::value - || platform::is_same>::value, - "Epilogue in the kernel level must be row major"); + typename PermuteDLayout> +struct DefaultDequantGemm { + static_assert(platform::is_same::value || + platform::is_same>::value, + "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate - using Mma = typename cutlass::gemm::threadblock::DefaultMma< - ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, - ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, - ThreadblockShape, WarpShape, InstructionShape, Stages, - Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; + using Mma = + typename cutlass::gemm::threadblock::DefaultMma::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; /// Define the epilogue - using RegularEpilogue = typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< - ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, - EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< + ThreadblockShape, + typename Mma::Operator, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ScatterD, + PermuteDLayout>::Epilogue; using Epilogue = RegularEpilogue; /// Define the kernel-level GEMM operator. - using GemmKernel = kernel::Gemm; + using GemmKernel = + kernel::Gemm; }; - - template < /// Element type for A matrix operand typename ElementA_, @@ -235,8 +273,7 @@ template < /// Scatter result D by using an index array bool ScatterD = false, /// Permute result D - typename PermuteDLayout = layout::NoPermute -> + typename PermuteDLayout = layout::NoPermute> struct DefaultInt8InterleavedGemm; /// Partial specialization for Ampere @@ -275,54 +312,69 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator -> + typename Operator> struct DefaultInt8InterleavedGemm { - + LayoutA, + kAlignmentA, + int8_t, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + SplitKSerial, + Operator> { static_assert(platform::is_same::value, - "Epilogue in the kernel level must be row major"); + "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaulInt8Nf4InterleavedMma< - int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, - ElementAccumulator, LayoutC, OperatorClass, ArchTag, - ThreadblockShape, WarpShape, InstructionShape, Stages, - Operator, false, SharedMemoryClearOption::kNone>::ThreadblockMma; - - + int8_t, + LayoutA, + kAlignmentA, + int8_t, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + false, + SharedMemoryClearOption::kNone>::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; /// Define the epilogue - using RegularEpilogue = typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< - ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, - EpilogueOutputOp::kCount, false, layout::NoPermute>::Epilogue; + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< + ThreadblockShape, + typename Mma::Operator, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute>::Epilogue; using Epilogue = RegularEpilogue; /// Define the kernel-level GEMM operator. - using GemmKernel = kernel::Gemm; + using GemmKernel = + kernel::Gemm; }; - /// Partial specialization for Ampere template < /// Layout type for A matrix operand @@ -359,51 +411,67 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator -> + typename Operator> struct DefaultInt8InterleavedGemm { - + LayoutA, + kAlignmentA, + cutlass::uint4b_t, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + SplitKSerial, + Operator> { static_assert(platform::is_same::value, - "Epilogue in the kernel level must be row major"); + "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaulInt8Nf4InterleavedMma< - int8_t, LayoutA, kAlignmentA, cutlass::uint4b_t, LayoutB, kAlignmentB, - ElementAccumulator, LayoutC, OperatorClass, ArchTag, - ThreadblockShape, WarpShape, InstructionShape, Stages, - Operator, false, SharedMemoryClearOption::kNone>::ThreadblockMma; - - + int8_t, + LayoutA, + kAlignmentA, + cutlass::uint4b_t, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + false, + SharedMemoryClearOption::kNone>::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; /// Define the epilogue - using RegularEpilogue = typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< - ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, - EpilogueOutputOp::kCount, false, layout::NoPermute>::Epilogue; + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< + ThreadblockShape, + typename Mma::Operator, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute>::Epilogue; using Epilogue = RegularEpilogue; /// Define the kernel-level GEMM operator. - using GemmKernel = kernel::Gemm; + using GemmKernel = + kernel::Gemm; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_intA_nf4B_traits.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_intA_nf4B_traits.h index 2a13073f1f..9c69c037f8 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_intA_nf4B_traits.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_intA_nf4B_traits.h @@ -44,56 +44,63 @@ namespace cutlass { namespace gemm { namespace kernel { -template -struct Int8Nf4GemmArchTraits { -}; +template +struct Int8Nf4GemmArchTraits {}; -template +template struct Int8Nf4GemmArchTraits { - static constexpr int Stages = 2; - using OperatorClass = cutlass::arch::OpClassSimt; - using AccType = float; - using LayoutB = cutlass::layout::RowMajor; + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; - static constexpr int ElementsPerAccessA = 1; - static constexpr int ElementsPerAccessB = 1; - static constexpr int ElementsPerAccessC = 1; - static constexpr int ThreadblockK = 8; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using Operator = cutlass::arch::OpMultiplyAdd; + using Operator = cutlass::arch::OpMultiplyAdd; }; // ======================= Ampere Traits ============================== -template -struct Int8Nf4GemmArchTraits::value || - cutlass::platform::is_same::value>::type> { -private: - using LayoutDetails = LayoutDetailsB; +template +struct Int8Nf4GemmArchTraits< + IntAType, + IntBType, + OutType, + cutlass::arch::Sm80, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = int32_t; - using LayoutB = typename LayoutDetails::Layout; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = int32_t; + using LayoutB = typename LayoutDetails::Layout; - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - // static_assert(cutlass::platform::is_same::value, - // "input type must be int8_t"); - // static_assert((ElementsPerAccessA == 16), "====="); - // static_assert((ElementsPerAccessB == 16), "====="); - // static_assert((ElementsPerAccessC == 8), "====="); - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - // using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; - using Operator = typename LayoutDetails::Operator; + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + // static_assert(cutlass::platform::is_same::value, + // "input type must be int8_t"); + // static_assert((ElementsPerAccessA == 16), "====="); + // static_assert((ElementsPerAccessB == 16), "====="); + // static_assert((ElementsPerAccessC == 8), "====="); + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + // using InstructionShape = cutlass::gemm::GemmShape<16, 8, + // 16>; + using Operator = typename LayoutDetails::Operator; }; } // namespace kernel diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h index 574f9f2056..922e64d6c6 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h @@ -290,7 +290,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { // ceil_div(args.problem_size.k(), args.batch_count), kAlignK); // if (gemm_k_size) { - // grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + // grid_tiled_shape.k() = ceil_div(args.problem_size.k(), + // gemm_k_size); // } // } @@ -318,7 +319,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { /// Determines whether kernel satisfies alignment static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { - CUTLASS_TRACE_HOST("GemmWithEpilogueVisitorInterleavedNf4::can_implement()"); + CUTLASS_TRACE_HOST( + "GemmWithEpilogueVisitorInterleavedNf4::can_implement()"); static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; @@ -440,16 +442,17 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { params.ptr_B)[threadblock_tile_offset.k()]; } #endif - // if(threadIdx.x==0){ - // printf("##### block: %d-%d-%d, offset_k:%d, threadblock_tile_offset.m-n-k():%d-%d-%d, params.gemm_k_size:%d \n", - // blockIdx.x, blockIdx.y, blockIdx.z, - // offset_k, - // threadblock_tile_offset.m(), - // threadblock_tile_offset.n(), - // threadblock_tile_offset.k(), - // params.gemm_k_size - // ); - // } + // if(threadIdx.x==0){ + // printf("##### block: %d-%d-%d, offset_k:%d, + // threadblock_tile_offset.m-n-k():%d-%d-%d, params.gemm_k_size:%d \n", + // blockIdx.x, blockIdx.y, blockIdx.z, + // offset_k, + // threadblock_tile_offset.m(), + // threadblock_tile_offset.n(), + // threadblock_tile_offset.k(), + // params.gemm_k_size + // ); + // } // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ @@ -460,7 +463,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { // cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * // Mma::Shape::kN}; // printf("#### kInterleave:%d \n", kInterleave); - // printf("###### offset_k : %d; params.gemm_k_size:%d; threadblock_tile_offset.k():%d \n", + // printf("###### offset_k : %d; params.gemm_k_size:%d; + // threadblock_tile_offset.k():%d \n", // offset_k, // params.gemm_k_size, // threadblock_tile_offset.k() @@ -470,7 +474,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { offset_k * kInterleave, threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; // if(threadIdx.x==0){ - // printf("##### block: %d-%d-%d, tb_offset_B:%d-%d, kInterleave:%d, Mma::IteratorB::Shape::kRow:%d, Mma::Shape::kK:%d \n", + // printf("##### block: %d-%d-%d, tb_offset_B:%d-%d, kInterleave:%d, + // Mma::IteratorB::Shape::kRow:%d, Mma::Shape::kK:%d \n", // blockIdx.x, blockIdx.y, blockIdx.z, // offset_k * kInterleave, // threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave, @@ -500,13 +505,11 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { thread_idx, tb_offset_B); typename Mma::IteratorNF4LookUpTable iterator_nf4_look_up_table = - Mma::IteratorNF4LookUpTable( - params.params_nf4_look_up_table, - params.ref_nf4_look_up_table.data(), - {0,16}, - threadIdx.x, - {0,0} - ); + Mma::IteratorNF4LookUpTable(params.params_nf4_look_up_table, + params.ref_nf4_look_up_table.data(), + {0, 16}, + threadIdx.x, + {0, 0}); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -530,7 +533,12 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; // printf("#### gemm_k_iterations: %d \n", gemm_k_iterations); // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_nf4_look_up_table, accumulators); + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_nf4_look_up_table, + accumulators); // if(threadIdx.x==0){ // printf("##### block: %d-%d-%d, offset-m-n-k:%d-%d-%d \n", // blockIdx.x, blockIdx.y, blockIdx.z, @@ -552,7 +560,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { threadblock_tile_offset.n() * Mma::Shape::kN); // int block_idx = threadblock_tile_offset.m() + - // threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + // threadblock_tile_offset.n() * + // params.grid_tiled_shape.m(); // // Construct the epilogue visitor diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 35469f7ad9..b09368a848 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -29,9 +29,10 @@ See the License for the specific language governing permissions and limitations under the License. */ /* - This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is - quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices - to be consumed by CUTLASS. + This file exists so that we use the same weight layout for MoE grouped gemm + and regular gemm when the weight is quantized. The preprocessing code reads + this template to know how to organize the quantized weight matrices to be + consumed by CUTLASS. Note that for int4, ThreadBlockK MUST be 64. @@ -53,79 +54,108 @@ namespace cutlass { namespace gemm { namespace kernel { -template -struct LayoutDetailsB { -}; +template +struct LayoutDetailsB {}; -// // Volta specialiations. Volta will dequantize before STS, so we need a different operator -template +// // Volta specialiations. Volta will dequantize before STS, so we need a +// different operator +template struct LayoutDetailsB { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 8; - using Operator = cutlass::arch::OpMultiplyAdd; + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; }; -// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. -// TODO - Switch this to column major for weights since gemms should be more performant. -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +// Specializations for Turing+ when B is FP16. These are currently only used for +// MoE networks. +// TODO - Switch this to column major for weights since gemms should be more +// performant. +template +struct LayoutDetailsB< + half_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +template +struct LayoutDetailsB< + bfloat16_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; -// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, -// which signals that we want to dequantize after loading from smem. -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; +// Specializations for Turing+ when B is quantized. These can use the operator +// OpMultiplyAddDequantizeInterleavedBToA, which signals that we want to +// dequantize after loading from smem. +template +struct LayoutDetailsB< + uint8_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; +template +struct LayoutDetailsB< + uint4b_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; // For int8 int8 int32 Gemm. Author(zhengzekang) -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; +template +struct LayoutDetailsB< + int8_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; } // namespace kernel diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_mma_nf4_int8_interleaved.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_mma_nf4_int8_interleaved.h index 540c9a0931..bc8ab30bd1 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_mma_nf4_int8_interleaved.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_mma_nf4_int8_interleaved.h @@ -74,12 +74,11 @@ template < /// Gather operand A by using an index array bool GatherA = false, /// Gather operand B by using an index array - bool GatherB = false - > + bool GatherB = false> struct DefaulInt8Nf4InterleavedMma; // int8 int8 int32 Gemm specialization. Author(zhengzekang) -template< +template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements @@ -101,24 +100,6 @@ template< /// Operation performed by GEMM typename Operator> struct DefaulInt8Nf4InterleavedMma { - -private: - - using Mma = Int8Nf4InterleavedMma; + Operator> { + private: + using Mma = Int8Nf4InterleavedMma; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; -template< +template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements @@ -174,26 +171,6 @@ template< /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> struct DefaulInt8Nf4InterleavedMma { - -private: - - using Mma = Int8Nf4InterleavedMma; + false, + SharedMemoryClear> { + private: + using Mma = Int8Nf4InterleavedMma; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; - - - // int8 int4 int32 -template< +template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements @@ -250,24 +242,6 @@ template< /// Operation performed by GEMM typename Operator> struct DefaulInt8Nf4InterleavedMma { - -private: - - using Mma = Int8Nf4InterleavedMma; + Operator> { + private: + using Mma = Int8Nf4InterleavedMma; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; -template< +template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements @@ -323,26 +313,6 @@ template< /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> struct DefaulInt8Nf4InterleavedMma { - -private: - - using Mma = Int8Nf4InterleavedMma; + false, + SharedMemoryClear> { + private: + using Mma = Int8Nf4InterleavedMma; -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma.h index e3cb9da1bf..5c48abd7cc 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma.h @@ -38,57 +38,67 @@ namespace gemm { namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -// // We need to distinguish here, since we want volta support. It is too much effort -// // to write shared memory iterators that are probably needed for volta to function -// // properly. As a result, we allow converters both after the LDG (for volta) and after +// // We need to distinguish here, since we want volta support. It is too much +// effort +// // to write shared memory iterators that are probably needed for volta to +// function +// // properly. As a result, we allow converters both after the LDG (for volta) +// and after // // the LDS for Turing+. -template< +template < /// Iterator for B matrix in global memory typename IteratorB, /// Warp level Mma typename MmaOperator, /// Math operation perform by warp level operator typename MathOperator> -struct SetConvertersInt8Nf4Interleaved { -}; +struct SetConvertersInt8Nf4Interleaved {}; // // Dequantize after LDG, so set transforms accordingly -template< +template < /// Iterator for B matrix in global memory typename IteratorB, /// Mma Policy typename MmaOperator> -struct SetConvertersInt8Nf4Interleaved { - using TransformAfterLDG = - FastInterleavedAndBiasedNumericArrayConverterNf4; +struct SetConvertersInt8Nf4Interleaved { + using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverterNf4< + typename MmaOperator::ArchMmaOperator::ElementB, + typename IteratorB::Element, + IteratorB::Fragment::kElements>; - using TransformAfterLDS = NumericArrayConverter; + using TransformAfterLDS = + NumericArrayConverter; }; // Dequantize after LDS, so set transforms accordingly -template< +template < /// Iterator for B matrix in global memory typename IteratorB, /// Mma Policy typename MmaOperator> -struct SetConvertersInt8Nf4Interleaved { - using TransformAfterLDG = - NumericArrayConverter; +struct SetConvertersInt8Nf4Interleaved< + IteratorB, + MmaOperator, + arch::OpMultiplyAddDequantizeInterleavedBToA> { + using TransformAfterLDG = + NumericArrayConverter; - using TransformAfterLDS = - FastInterleavedAndBiasedNumericArrayConverterNf4; + using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverterNf4< + typename MmaOperator::ArchMmaOperator::ElementB, + typename TransformAfterLDG::result_type::Element, + MmaOperator::FragmentB::kElements>; }; //////////////////////////////////////////////////////////////////////////////// -template< +template < /// Element type for A matrix operand typename ElementA_, /// Layout type for A matrix operand diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma_multistage.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma_multistage.h index b2bc5d2511..aa3a04d9ac 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma_multistage.h @@ -46,7 +46,7 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -template< +template < /// Type for elementA typename ElementA, /// Layout type for A matrix operand @@ -78,120 +78,134 @@ template< /// SharedMemoryClearOption SharedMemoryClear> struct Int8Nf4InterleavedMma= 80)>::type> { + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + SharedMemoryClear, + typename platform::enable_if<( + ArchTag::kMinComputeCapability >= 80)>::type> { + static_assert(platform::is_same::value, + "Element A must be in8t"); - static_assert(platform::is_same::value, - "Element A must be in8t"); + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be int8 or uint4"); + static_assert(WarpShape::kK != 0, ""); + static_assert(ThreadblockShape::kK != 0, ""); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be int8 or uint4"); - static_assert(WarpShape::kK!=0,""); - static_assert(ThreadblockShape::kK!=0,""); + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? - cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) ? - cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, - LayoutA, - 1, - ThreadMapA, - AccessTypeA>; + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + ThreadMapB, + AccessTypeB>; + static int const kAlignmentNF4LookUpTable = 128 / sizeof_bits::value; + using AccessTypeNF4LookUpTable = + cutlass::Array; + using IteratorNF4LookUpTableThreadMap = + transform::PitchLinearStripminedThreadMap, + 16 / kAlignmentNF4LookUpTable, + kAlignmentNF4LookUpTable>; + using IteratorNF4LookUpTable = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape<1, 16>, + int32_t, + cutlass::layout::RowMajor, + 0, + IteratorNF4LookUpTableThreadMap, + AccessTypeNF4LookUpTable>; - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, - LayoutB, - 0, - ThreadMapB, - AccessTypeB>; - static int const kAlignmentNF4LookUpTable = 128 / sizeof_bits::value; - using AccessTypeNF4LookUpTable = cutlass::Array; - using IteratorNF4LookUpTableThreadMap = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape<4, 1>, - 16 / kAlignmentNF4LookUpTable, - kAlignmentNF4LookUpTable>; - using IteratorNF4LookUpTable = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape<1, 16>, - int32_t, - cutlass::layout::RowMajor, - 0, - IteratorNF4LookUpTableThreadMap, - AccessTypeNF4LookUpTable>; + using SmemIteratorNF4LookUpTable = IteratorNF4LookUpTable; - using SmemIteratorNF4LookUpTable = IteratorNF4LookUpTable; + using Converter = FastInterleavedAndBiasedNumericArrayConverterNf4< + ElementA, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; - using Converter = FastInterleavedAndBiasedNumericArrayConverterNf4; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Int8Nf4InterleavedMmaMultistage; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::Int8Nf4InterleavedMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorNF4LookUpTable, + SmemIteratorNF4LookUpTable, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + Converter, + SharedMemoryClear>; }; -template< +template < /// Type for element A typename ElementA, /// Layout type for A matrix operand @@ -224,140 +238,156 @@ template< int RowsPerTile, /// int ColumnsInterleaved> -struct Int8Nf4InterleavedMma, - kAlignmentB, - ElementAccumulator, - layout::RowMajor, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - kStages, - Operator, - SharedMemoryClear, - typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> { +struct Int8Nf4InterleavedMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + layout::ColumnMajorTileInterleave, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + SharedMemoryClear, + typename platform::enable_if<(ArchTag::kMinComputeCapability >= + 80)>::type> { + static_assert(platform::is_same::value, "Element A int8_t"); - static_assert(platform::is_same::value , - "Element A int8_t"); + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); + // static_assert(platform::is_same::value || + // platform::is_same::value, + // "Element B must be uint8 or uint4"); + static_assert(platform::is_same::value, + "Element B must be uint4"); - // static_assert(platform::is_same::value || platform::is_same::value, - // "Element B must be uint8 or uint4"); - static_assert(platform::is_same::value, - "Element B must be uint4"); + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? - cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) ? - cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, - LayoutA, - 1, - ThreadMapA, - AccessTypeA>; + private: + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + // static_assert(ColumnsInterleaved==4 || cutlass::platform::is_same::value, "####"); + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = + typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); -private: - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - // static_assert(ColumnsInterleaved==4 || cutlass::platform::is_same::value, "####"); - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape< + OriginalWarpArrangement::kContiguous * ColumnsInterleaved, + OriginalWarpArrangement::kStrided / ColumnsInterleaved>, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; - using GmemIteratorShape = - MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + GmemIteratorShape, + ElementB, + layout::ColumnMajor, + 0, + GmemThreadMapB, + AccessTypeB>; -public: - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock:: - PredicatedTileAccessIterator; + static int const kAlignmentNF4LookUpTable = 128 / sizeof_bits::value; + using AccessTypeNF4LookUpTable = + cutlass::Array; + using IteratorNF4LookUpTableThreadMap = transform:: + PitchLinearStripminedThreadMap, 1, 4>; + using IteratorNF4LookUpTable = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape<1, 16>, + int32_t, + cutlass::layout::RowMajor, + 0, + IteratorNF4LookUpTableThreadMap, + AccessTypeNF4LookUpTable>; - static int const kAlignmentNF4LookUpTable = 128 / sizeof_bits::value; - using AccessTypeNF4LookUpTable = cutlass::Array; - using IteratorNF4LookUpTableThreadMap = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape<4, 1>, - 1, - 4>; - using IteratorNF4LookUpTable = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape<1, 16>, - int32_t, - cutlass::layout::RowMajor, - 0, - IteratorNF4LookUpTableThreadMap, - AccessTypeNF4LookUpTable>; + using SmemIteratorNF4LookUpTable = IteratorNF4LookUpTable; + // static_assert(MmaCore::MmaPolicy::Operator::FragmentB::kElements==64,"MmaCore::MmaPolicy::Operator::FragmentB::kElements + // == 32"); + using Converter = FastInterleavedAndBiasedNumericArrayConverterNf4< + ElementA, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; - using SmemIteratorNF4LookUpTable = IteratorNF4LookUpTable; - - // static_assert(MmaCore::MmaPolicy::Operator::FragmentB::kElements==64,"MmaCore::MmaPolicy::Operator::FragmentB::kElements == 32"); - using Converter = FastInterleavedAndBiasedNumericArrayConverterNf4; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Int8Nf4InterleavedMmaMultistage; + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::Int8Nf4InterleavedMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorNF4LookUpTable, + SmemIteratorNF4LookUpTable, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + Converter, + SharedMemoryClear>; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_base.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_base.h index ac47190447..585d88ec85 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_base.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. @@ -63,46 +64,48 @@ namespace gemm { namespace threadblock { // //////////////////////////////////////////////////////////////////////////////// -// // SFINAE trick so I can keep the same loop code for Volta and dispatch to the -// // correct warp level mma. On volta, all data is stored to shared memory as FP16. -// template -// CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, +// // SFINAE trick so I can keep the same loop code for Volta and dispatch to +// the +// // correct warp level mma. On volta, all data is stored to shared memory as +// FP16. template CUTLASS_DEVICE +// void run_warp_mma(WarpMma& warp_mma, // typename WarpMma::FragmentC& D, // typename WarpMma::FragmentA const& A, // typename WarpMma::FragmentB const& B, // typename WarpMma::FragmentC const& C, -// const int warp_tileB_k_offset) +// const int warp_tileB_k_offset) // { // warp_mma(D, A, B, C); // } // template -// CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, -// typename WarpMma::FragmentC& D, -// typename WarpMma::TransformedFragmentA const& A, -// typename WarpMma::TransformedFragmentB const& B, -// typename WarpMma::FragmentC const& C, -// const int warp_tileB_k_offset) +// CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, +// typename WarpMma::FragmentC& D, typename +// WarpMma::TransformedFragmentA const& A, +// typename WarpMma::TransformedFragmentB +// const& B, typename WarpMma::FragmentC const& +// C, const int warp_tileB_k_offset) // { // warp_mma(D, A, B, C, warp_tileB_k_offset); // } -// TODO(zhengzekang): Since we annotate the first implement, we currently hack to `run_ampere_warp_mma` used in A100. -template -CUTLASS_DEVICE void run_ampere_warp_mma(WarpMma& warp_mma, - typename WarpMma::FragmentC& D, - typename WarpMma::TransformedFragmentA const& A, - typename WarpMma::TransformedFragmentB const& B, - typename WarpMma::FragmentC const& C, - const int warp_tileB_k_offset) -{ - warp_mma(D, A, B, C, warp_tileB_k_offset); +// TODO(zhengzekang): Since we annotate the first implement, we currently hack +// to `run_ampere_warp_mma` used in A100. +template +CUTLASS_DEVICE void run_ampere_warp_mma( + WarpMma& warp_mma, + typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, + typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, + const int warp_tileB_k_offset) { + warp_mma(D, A, B, C, warp_tileB_k_offset); } //////////////////////////////////////////////////////////////////////////////// /// Structure to compute the matrix product targeting CUDA cores and SIMT math /// instructions. -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Policy describing tuning details (concept: MmaPolicy) @@ -112,135 +115,139 @@ template< /// Used for partial specialization typename Enable = bool> class Int8InterleavedMmaBase { -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; - ///< Policy describing tuning details - using Policy = Policy_; + ///< Policy describing tuning details + using Policy = Policy_; + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static_assert(Operator::IteratorB::InstructionShape::kRow >= + Operator::InstructionShape::kK, + ""); + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / + Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = + kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: // - // Dependent types + // Type definitions // - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; - - /// Number of warp-level GEMM operations - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); - static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,""); - static constexpr int kNumKIterationsPerWarpBLoad = - Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; - static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); - static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = - MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - public: - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() - { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() - { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() - { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() - { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - -protected: + public: // // Data members // - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; + /// Buffer for A operand + AlignedBuffer operand_A; - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; + /// Buffer for B operand + AlignedBuffer operand_B; -public: - /// Construct from tensor references + public: + // + // Methods + // + + /// Returns a layout object for the A matrix CUTLASS_DEVICE - Int8InterleavedMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) - { + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8InterleavedMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_multistage.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_multistage.h index 3021605f27..57e19ad92b 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,18 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -72,7 +72,7 @@ namespace threadblock { /// Structure to compute the matrix product targeting CUDA cores and SIMT math /// instructions. -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory @@ -107,489 +107,532 @@ template< SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Used for partial specialization typename Enable = bool> -class Int8InterleavedMmaMultistage: public Int8InterleavedMmaBase { -public: - ///< Base class - using Base = Int8InterleavedMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; +class Int8InterleavedMmaMultistage + : public Int8InterleavedMmaBase { + public: + ///< Base class + using Base = Int8InterleavedMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + using TransformBAfterLDS = TransformBAfterLDS_; - using TransformBAfterLDS = TransformBAfterLDS_; + // + // Dependent types + // - // - // Dependent types - // + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; + /// Warp-level Mma + using Operator = typename Policy::Operator; - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + // static_assert(Base::kWarpGemmIterations==4,"Base::kWarpGemmIterations!=4"); + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; - /// Internal structure exposed for introspection. - struct Detail { + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - // static_assert(Base::kWarpGemmIterations==4,"Base::kWarpGemmIterations!=4"); - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + /// Number of stages + static int const kStages = Stages; - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; - /// Number of stages - static int const kStages = Stages; + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; + private: + // + // Data members + // - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; -private: - // - // Data members - // + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - Int8InterleavedMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - Base(shared_storage, thread_idx, warp_idx, lane_idx), + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8InterleavedMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } } - CUTLASS_DEVICE - void - copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); - // Async Copy for operand A + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + // if((threadIdx.x==0) && threadIdx.y==0 && threadIdx.z==0 && + // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ + // printf("### kSrcBytes:%d, + // IteratorB::kAccessesPerVector:%d\n", kSrcBytes, + // IteratorB::kAccessesPerVector); + // } + // // static_assert(kSrcBytes==32, "kSrcBytes==32"); + // if( threadIdx.y==0 && threadIdx.z==0 && + // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ + // printf("gmem_ptr cp source of thread %d: ",threadIdx.x); + // for(int i=0;i<16;++i){ + // printf("%d, ",iterator_B.get()[i]); + // } + // printf("\n"); + // } CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } + TransformBAfterLDS lds_converter; - ++iterator_A; - } + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group + // as + // the first load of A. - ++this->smem_iterator_A_; - } + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; } - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - // Async Copy for operand B CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - // if((threadIdx.x==0) && threadIdx.y==0 && threadIdx.z==0 && - // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - // printf("### kSrcBytes:%d, IteratorB::kAccessesPerVector:%d\n", kSrcBytes, IteratorB::kAccessesPerVector); - // } - // // static_assert(kSrcBytes==32, "kSrcBytes==32"); - // if( threadIdx.y==0 && threadIdx.z==0 && - // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - // printf("gmem_ptr cp source of thread %d: ",threadIdx.x); - // for(int i=0;i<16;++i){ - // printf("%d, ",iterator_B.get()[i]); - // } - // printf("\n"); - // } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - ++iterator_B; - } - ++this->smem_iterator_B_; - } + ++iterator_B; } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); } - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< initial value of accumulator - FragmentC const& src_accum) - { + // Perform accumulation in the 'd' output operand + accum = src_accum; - // - // Prologue - // + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // - TransformBAfterLDS lds_converter; + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group as - // the first load of A. + typename IteratorA::AccessType zero_A; + zero_A.clear(); - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + last_smem_iterator_A.set_iteration_index(0); - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); + *dst_ptr = zero_A; - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); + ++last_smem_iterator_A; + } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector - / 8; + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); - ++iterator_A; - } + *dst_ptr = zero_B; - ++this->smem_iterator_A_; - } + ++last_smem_iterator_B; + } + } - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector - / 8; + Operator warp_mma; - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); - ++iterator_B; - } + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); - ++this->smem_iterator_B_; - } + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } + // + // Mainloop + // - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; + // static_assert(Base::kNumKIterationsPerWarpBLoad==1,"Base::kNumKIterationsPerWarpBLoad!=1"); + // static_assert(Base::kWarpGemmIterationsForB==4,"Base::kWarpGemmIterationsForB!=4"); + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + // TODO(wangbojun) lds_converter can be remove for int8 B input + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); + // TODO(zhengzekang) + // run_warp_mma( + run_ampere_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + // if(threadIdx.x==0 && threadIdx.y==0 && threadIdx.z==0 && + // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ + // printf("### run_warp_mma: " + // "%d \n", + // reinterpret_cast(accum)); + // } + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - // static_assert(Base::kNumKIterationsPerWarpBLoad==1,"Base::kNumKIterationsPerWarpBLoad!=1"); - // static_assert(Base::kWarpGemmIterationsForB==4,"Base::kWarpGemmIterationsForB!=4"); - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) - % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - // TODO(wangbojun) lds_converter can be remove for int8 B input - typename TransformBAfterLDS::result_type converted_frag_B = - lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - - // TODO(zhengzekang) - // run_warp_mma( - run_ampere_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - // if(threadIdx.x==0 && threadIdx.y==0 && threadIdx.z==0 && - // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - // printf("### run_warp_mma: " - // "%d \n", - // reinterpret_cast(accum)); - // } - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - smem_read_stage_idx = 0; - } - else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); } + } } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_pipelined.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_pipelined.h index 2b7312f365..f7c9cac33f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_pipelined.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_pipelined.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -74,18 +75,21 @@ namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template< +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorA_, /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorB_, /// Iterates over tiles of B operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) @@ -102,250 +106,275 @@ template< typename TransformBAfterLDS_, /// Used for partial specialization typename Enable = bool> -class Int8InterleavedMmaPipelined: public Int8InterleavedMmaBase { -public: - ///< Base class - using Base = Int8InterleavedMmaBase; - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details +class Int8InterleavedMmaPipelined + : public Int8InterleavedMmaBase { + public: + ///< Base class + using Base = Int8InterleavedMmaBase; + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; - // - // Dependent types - // + // + // Dependent types + // - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Warp-level Mma + using Operator = typename Policy::Operator; - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; - // staticaly assert kStages for Int8InterleavedMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "Int8InterleavedMmaPipelined requires kStages set to value 2"); + // staticaly assert kStages for Int8InterleavedMmaPipelined is two + // (Double-buffered pipeline) + static_assert((Base::kStages == 2), + "Int8InterleavedMmaPipelined requires kStages set to value 2"); -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; -public: - /// Construct from tensor references - CUTLASS_DEVICE - Int8InterleavedMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8InterleavedMmaPipelined( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; + using TransformA = NumericArrayConverter; - using TransformA = - NumericArrayConverter; + // These transforms are mainly to handle when we have bfloat activations and + // weights in GMEM and want to issue HMMA on architectures older than + // Ampere. We will convert to FP16 before STS. + TransformA transformA; - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. - TransformA transformA; + // Perform accumulation in the 'd' output operand + accum = src_accum; - // Perform accumulation in the 'd' output operand - accum = src_accum; + FragmentA tb_frag_A; + FragmentB tb_frag_B; - FragmentA tb_frag_A; - FragmentB tb_frag_B; + tb_frag_A.clear(); + tb_frag_B.clear(); + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); - tb_frag_A.clear(); - tb_frag_B.clear(); + ++iterator_A; + ++iterator_B; - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - ++iterator_A; - ++iterator_B; + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + __syncthreads(); - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; - __syncthreads(); + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); + Operator warp_mma; - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; + int smem_write_stage_idx = 1; - Operator warp_mma; + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); - int smem_write_stage_idx = 1; + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); + // + // Mainloop + // - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // - // - // Mainloop - // + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) { - // - // Loop over GEMM K dimension - // + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. + __syncthreads(); - if (warp_mma_k == Base::kWarpGemmIterations - 1) { + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + } - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) - % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B = - lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - // run_warp_mma( - run_ampere_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } + smem_write_stage_idx ^= 1; } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate + // the load for the next fragment. + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + // run_warp_mma( + run_ampere_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + } } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_base.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_base.h index 413e53a9a3..5048bc52e0 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_base.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. @@ -65,7 +66,7 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// /// Structure to compute the matrix product targeting CUDA cores and SIMT math /// instructions. -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Policy describing tuning details (concept: MmaPolicy) @@ -75,138 +76,142 @@ template< /// Used for partial specialization typename Enable = bool> class Int8Nf4InterleavedMmaBase { -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; - ///< Policy describing tuning details - using Policy = Policy_; + ///< Policy describing tuning details + using Policy = Policy_; + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static_assert(Operator::IteratorB::InstructionShape::kRow >= + Operator::InstructionShape::kK, + ""); + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / + Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = + kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: // - // Dependent types + // Type definitions // - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; - - /// Number of warp-level GEMM operations - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); - static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,""); - static constexpr int kNumKIterationsPerWarpBLoad = - Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; - static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); - static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = - MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_nf4_look_up_table; - - public: - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() - { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() - { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() - { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() - { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - -protected: + public: // // Data members // - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; + /// Buffer for A operand + AlignedBuffer operand_A; - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; + /// Buffer for B operand + AlignedBuffer operand_B; -public: - /// Construct from tensor references + /// Buffer to hold scales for threadblock + AlignedBuffer operand_nf4_look_up_table; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix CUTLASS_DEVICE - Int8Nf4InterleavedMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) - { + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8Nf4InterleavedMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_multistage.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_multistage.h index d124b09cb9..c28e0ed844 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,18 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -66,8 +66,9 @@ limitations under the License. */ ///////////////////////////////////////////////////////////////////////////////////////////////// template -[[gnu::warning("your type here")]] -bool print_type() { return false; } +[[gnu::warning("your type here")]] bool print_type() { + return false; +} namespace cutlass { namespace gemm { @@ -77,7 +78,7 @@ namespace threadblock { /// Structure to compute the matrix product targeting CUDA cores and SIMT math /// instructions. -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory @@ -116,467 +117,601 @@ template< SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Used for partial specialization typename Enable = bool> -class Int8Nf4InterleavedMmaMultistage: public Int8Nf4InterleavedMmaBase { -public: - ///< Base class - using Base = Int8Nf4InterleavedMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Iterates over tiles of nf4 look up table in global memory - using IteratorNF4LookUpTable = IteratorNF4LookUpTable_; - using ElementNF4LookUpTable = typename IteratorNF4LookUpTable::Element; - using LayoutNF4LookUpTable = typename IteratorNF4LookUpTable::Layout; +class Int8Nf4InterleavedMmaMultistage + : public Int8Nf4InterleavedMmaBase { + public: + ///< Base class + using Base = Int8Nf4InterleavedMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Iterates over tiles of nf4 look up table in global memory + using IteratorNF4LookUpTable = IteratorNF4LookUpTable_; + using ElementNF4LookUpTable = typename IteratorNF4LookUpTable::Element; + using LayoutNF4LookUpTable = typename IteratorNF4LookUpTable::Layout; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorNF4LookUpTable = SmemIteratorNF4LookUpTable_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + using TransformBAfterLDS = TransformBAfterLDS_; - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorNF4LookUpTable = SmemIteratorNF4LookUpTable_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + // + // Dependent types + // - using TransformBAfterLDS = TransformBAfterLDS_; + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; - // - // Dependent types - // + /// Warp-level Mma + using Operator = typename Policy::Operator; - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + // static_assert(Base::kWarpGemmIterations==4,"Base::kWarpGemmIterations!=4"); + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; - /// Internal structure exposed for introspection. - struct Detail { + /// Number of stages + static int const kStages = Stages; - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - // static_assert(Base::kWarpGemmIterations==4,"Base::kWarpGemmIterations!=4"); - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; - /// Number of stages - static int const kStages = Stages; + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; + private: + // + // Data members + // -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; -private: - // - // Data members - // + SmemIteratorNF4LookUpTable smem_iterator_nf4_look_up_table_; - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - SmemIteratorNF4LookUpTable smem_iterator_nf4_look_up_table_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - Int8Nf4InterleavedMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - Base(shared_storage, thread_idx, warp_idx, lane_idx), + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8Nf4InterleavedMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_iterator_nf4_look_up_table_(LayoutNF4LookUpTable(16), - shared_storage.operand_nf4_look_up_table.data(), - {1, 16}, - thread_idx) -{ - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension + smem_iterator_nf4_look_up_table_( + LayoutNF4LookUpTable(16), + shared_storage.operand_nf4_look_up_table.data(), + {1, 16}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + // if((threadIdx.x % 32) == 0){ + // printf("#### %d-%d-%d-%d-%d-%d, gmem_ptr_nf4_look_up_table:%p, + // kSrcBytesNf4:%d \n", + // blockIdx.x, blockIdx.y, blockIdx.z, + // threadIdx.x, threadIdx.y, threadIdx.z, + // gmem_ptr_nf4_look_up_table, + // kSrcBytesNf4); + // } + // cutlass::arch::cp_async_zfill( + // dst_ptr_nf4_look_up_table, gmem_ptr_nf4_look_up_table, + // iterator_nf4_look_up_table.valid()); + } - // if((threadIdx.x % 32) == 0){ - // printf("#### %d-%d-%d-%d-%d-%d, gmem_ptr_nf4_look_up_table:%p, kSrcBytesNf4:%d \n", - // blockIdx.x, blockIdx.y, blockIdx.z, - // threadIdx.x, threadIdx.y, threadIdx.z, - // gmem_ptr_nf4_look_up_table, - // kSrcBytesNf4); - // } - // cutlass::arch::cp_async_zfill( - // dst_ptr_nf4_look_up_table, gmem_ptr_nf4_look_up_table, iterator_nf4_look_up_table.valid()); + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } } - CUTLASS_DEVICE - void - copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - // Async Copy for operand A CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); + // if(true && (threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // int32_t* print_ptr = + // reinterpret_cast(iterator_B.get()); int32_t* + // print_ptr_smem = reinterpret_cast(dst_ptr+v); if + // (iterator_B.valid()) + // { + // printf("gmem_ptr cp source of thread %d-%d-%d;%d-%d-%d: + // %p:%x-%x-%x-%x=>%x,%x,%x,%x \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // threadIdx.x,threadIdx.y,threadIdx.z, + // iterator_B.get(), + // static_cast(print_ptr[0]), + // static_cast(print_ptr[1]), + // static_cast(print_ptr[2]), + // static_cast(print_ptr[3]), + // static_cast(print_ptr_smem[0]), + // static_cast(print_ptr_smem[1]), + // static_cast(print_ptr_smem[2]), + // static_cast(print_ptr_smem[3])); + // } + // } + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + IteratorNF4LookUpTable iterator_nf4_look_up_table, + ///< initial value of accumulator + FragmentC const& src_accum) { + // printf("gemm_k_iterations:%d\n", gemm_k_iterations); - ++iterator_A; - } + // + // Prologue + // - ++this->smem_iterator_A_; - } + // use share memory to get look_up_table of nf4; + + // __shared__ uint32_t shared_look_up_table[16]; + + // int lane_idx=threadIdx.x%32; + // int warp_idx=threadIdx.x/32; + // if(lane_idx<16){ + // shared_look_up_table[lane_idx]=lane_idx; + // } + + // __shared__ uint32_t shared_look_up_table[32][32]; + // if(warp_idx==0){ + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0;ii<16;++ii){ + // shared_look_up_table[lane_idx][ii]=ii; + // } + // } + + /// load look up table to smem here + // __shared__ int32_t nf4_smem_look_up_table[16]; + + // int32_t* gmem_ptr_nf4_look_up_table = + // reinterpret_cast(iterator_nf4_look_up_table.get()); + // // smem look up table + // int32_t* dst_ptr_nf4_look_up_table = + // reinterpret_cast(nf4_smem_look_up_table); + + // if(lane_idx == 0){ + // int4* dst_ptr_nf4_look_up_table_int4 = + // reinterpret_cast(nf4_smem_look_up_table); + // dst_ptr_nf4_look_up_table_int4[lane_idx] = + // *(reinterpret_cast(gmem_ptr_nf4_look_up_table) + lane_idx); + // } + // __syncthreads(); + // // reg look up table + // cutlass::Array reg_look_up_table; + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0;ii<4;++ii){ + // reg_look_up_table[ii]=*(reinterpret_cast(dst_ptr_nf4_look_up_table) + // + ii); + // } + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group + // as + // the first load of A. + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; } - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + // print_type(); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - // Async Copy for operand B CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - // if(true && (threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // int32_t* print_ptr = reinterpret_cast(iterator_B.get()); - // int32_t* print_ptr_smem = reinterpret_cast(dst_ptr+v); - // if (iterator_B.valid()) - // { - // printf("gmem_ptr cp source of thread %d-%d-%d;%d-%d-%d: %p:%x-%x-%x-%x=>%x,%x,%x,%x \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // threadIdx.x,threadIdx.y,threadIdx.z, - // iterator_B.get(), - // static_cast(print_ptr[0]), - // static_cast(print_ptr[1]), - // static_cast(print_ptr[2]), - // static_cast(print_ptr[3]), - // static_cast(print_ptr_smem[0]), - // static_cast(print_ptr_smem[1]), - // static_cast(print_ptr_smem[2]), - // static_cast(print_ptr_smem[3])); - // } - // } - ++iterator_B; - } - ++this->smem_iterator_B_; - } + // if(true && (threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // int32_t* print_ptr = + // reinterpret_cast(iterator_B.get()); int32_t* + // print_ptr_smem = reinterpret_cast(dst_ptr+v); if + // (iterator_B.valid()) + // { + // printf("gmem_ptr cp source of thread %d-%d-%d;%d-%d-%d: + // %p:%x-%x-%x-%x=>%x,%x,%x,%x \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // threadIdx.x,threadIdx.y,threadIdx.z, + // iterator_B.get(), + // static_cast(print_ptr[0]), + // static_cast(print_ptr[1]), + // static_cast(print_ptr[2]), + // static_cast(print_ptr[3]), + // static_cast(print_ptr_smem[0]), + // static_cast(print_ptr_smem[1]), + // static_cast(print_ptr_smem[2]), + // static_cast(print_ptr_smem[3])); + // } + // } + ++iterator_B; } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); } - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - IteratorNF4LookUpTable iterator_nf4_look_up_table, - ///< initial value of accumulator - FragmentC const& src_accum) - { + // Perform accumulation in the 'd' output operand + accum = src_accum; - // printf("gemm_k_iterations:%d\n", gemm_k_iterations); + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // - // - // Prologue - // + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - // use share memory to get look_up_table of nf4; + typename IteratorA::AccessType zero_A; + zero_A.clear(); - // __shared__ uint32_t shared_look_up_table[16]; + last_smem_iterator_A.set_iteration_index(0); - // int lane_idx=threadIdx.x%32; - // int warp_idx=threadIdx.x/32; - // if(lane_idx<16){ - // shared_look_up_table[lane_idx]=lane_idx; - // } + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); - // __shared__ uint32_t shared_look_up_table[32][32]; - // if(warp_idx==0){ - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0;ii<16;++ii){ - // shared_look_up_table[lane_idx][ii]=ii; - // } - // } + *dst_ptr = zero_A; - /// load look up table to smem here - // __shared__ int32_t nf4_smem_look_up_table[16]; + ++last_smem_iterator_A; + } - // int32_t* gmem_ptr_nf4_look_up_table = reinterpret_cast(iterator_nf4_look_up_table.get()); - // // smem look up table - // int32_t* dst_ptr_nf4_look_up_table = reinterpret_cast(nf4_smem_look_up_table); + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; - // if(lane_idx == 0){ - // int4* dst_ptr_nf4_look_up_table_int4 = reinterpret_cast(nf4_smem_look_up_table); - // dst_ptr_nf4_look_up_table_int4[lane_idx] = *(reinterpret_cast(gmem_ptr_nf4_look_up_table) + lane_idx); - // } - // __syncthreads(); - // // reg look up table - // cutlass::Array reg_look_up_table; - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0;ii<4;++ii){ - // reg_look_up_table[ii]=*(reinterpret_cast(dst_ptr_nf4_look_up_table) + ii); - // } + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); - TransformBAfterLDS lds_converter; + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group as - // the first load of A. + *dst_ptr = zero_B; - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + ++last_smem_iterator_B; + } + } - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename TransformBAfterLDS::result_type converted_frag_B_buffer[2]; + Operator warp_mma; - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector - / 8; + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + converted_frag_B_buffer[0] = lds_converter(warp_frag_B[0]); + this->warp_tile_iterator_A_.load(warp_frag_A[0]); - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[0]); printf("#### + // warp_frag_b_load [0] bid:%d-%d-%d," + // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // frag_b_reg_ptr[0], + // frag_b_reg_ptr[1], + // frag_b_reg_ptr[2], + // frag_b_reg_ptr[3], + // frag_b_reg_ptr[4], + // frag_b_reg_ptr[5], + // frag_b_reg_ptr[6], + // frag_b_reg_ptr[7] + // ); + // } + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); - ++iterator_A; - } + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; - ++this->smem_iterator_A_; - } + // + // Mainloop + // - iterator_B.set_iteration_index(0); - // print_type(); - this->smem_iterator_B_.set_iteration_index(0); + __syncthreads(); + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + // static_assert(Base::kNumKIterationsPerWarpBLoad==1,"Base::kNumKIterationsPerWarpBLoad!=1"); + // static_assert(Base::kWarpGemmIterationsForB==4,"Base::kWarpGemmIterationsForB!=4"); + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + converted_frag_B_buffer[(warp_tileB_k_load_offset + 1) % 2] = + lds_converter(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector - / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - // if(true && (threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // int32_t* print_ptr = reinterpret_cast(iterator_B.get()); - // int32_t* print_ptr_smem = reinterpret_cast(dst_ptr+v); - // if (iterator_B.valid()) - // { - // printf("gmem_ptr cp source of thread %d-%d-%d;%d-%d-%d: %p:%x-%x-%x-%x=>%x,%x,%x,%x \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // threadIdx.x,threadIdx.y,threadIdx.z, - // iterator_B.get(), - // static_cast(print_ptr[0]), - // static_cast(print_ptr[1]), - // static_cast(print_ptr[2]), - // static_cast(print_ptr[3]), - // static_cast(print_ptr_smem[0]), - // static_cast(print_ptr_smem[1]), - // static_cast(print_ptr_smem[2]), - // static_cast(print_ptr_smem[3])); - // } - // } - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); + ++this->warp_tile_iterator_B_; + // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset + // + 1) % 2]); printf("#### warp_frag_b load [%d] bid:%d-%d-%d," + // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // ((warp_tileB_k_load_offset + 1) % 2), + // blockIdx.x,blockIdx.y,blockIdx.z, + // frag_b_reg_ptr[0], + // frag_b_reg_ptr[1], + // frag_b_reg_ptr[2], + // frag_b_reg_ptr[3], + // frag_b_reg_ptr[4], + // frag_b_reg_ptr[5], + // frag_b_reg_ptr[6], + // frag_b_reg_ptr[7] + // ); + // } } + // TODO(wangbojun) lds_converter can be remove for int8 B input + // int4 + // typename TransformBAfterLDS::result_type converted_frag_B = + // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - // Perform accumulation in the 'd' output operand - accum = src_accum; + // typename TransformBAfterLDS::result_type converted_frag_B = + // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], + // reinterpret_cast(nf4_smem_look_up_table)); - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // + // typename TransformBAfterLDS::result_type converted_frag_B = + // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], + // reg_look_up_table); - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename TransformBAfterLDS::result_type converted_frag_B_buffer[2]; - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - converted_frag_B_buffer[0] = - lds_converter(warp_frag_B[0]); - this->warp_tile_iterator_A_.load(warp_frag_A[0]); + // typename TransformBAfterLDS::result_type converted_frag_B = + // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], + // shared_look_up_table, warp_idx, lane_idx); // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = reinterpret_cast(&warp_frag_B[0]); - // printf("#### warp_frag_b_load [0] bid:%d-%d-%d," - // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset) + // % 2]); uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&converted_frag_B); printf("#### + // after lds_converter bid:%d-%d-%d" + // " frag_b_reg_ptr[%d]:%x-%x-%x-%x-%x-%x-%x-%x" + // " converted_frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", // blockIdx.x,blockIdx.y,blockIdx.z, + // ((warp_tileB_k_load_offset) % 2), // frag_b_reg_ptr[0], // frag_b_reg_ptr[1], // frag_b_reg_ptr[2], @@ -584,308 +719,243 @@ public: // frag_b_reg_ptr[4], // frag_b_reg_ptr[5], // frag_b_reg_ptr[6], - // frag_b_reg_ptr[7] - // ); + // frag_b_reg_ptr[7], + // converted_frag_B_reg_ptr[0], + // converted_frag_B_reg_ptr[1], + // converted_frag_B_reg_ptr[2], + // converted_frag_B_reg_ptr[3], + // converted_frag_B_reg_ptr[4], + // converted_frag_B_reg_ptr[5], + // converted_frag_B_reg_ptr[6], + // converted_frag_B_reg_ptr[7] + // ); // } - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); + // bool ::print_type< ::cutlass::Array< + // ::cutlass::integer_subbyte<(int)4, (bool)0> , (int)64, (bool)0> > + // ()") print_type(); bool ::print_type< + // ::cutlass::Array > ()") from a + // print_type(); + // cutlass::Array + // print_type(); - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; + // print_type(); + // TODO(zhengzekang) + // run_warp_mma( - // - // Mainloop - // + // if(true){ + // uint32_t none_zero = 0; + // // uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&converted_frag_B); uint32_t* + // converted_frag_B_reg_ptr = + // reinterpret_cast(&warp_frag_B[warp_mma_k % 2]); + // uint32_t* frag_a_reg_ptr = + // reinterpret_cast(&warp_frag_A[warp_mma_k % 2]); + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0;ii (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // + // CUTLASS_PRAGMA_UNROLL + // for(int none_zero_i = 16; none_zero_i>0;none_zero_i/=2){ + // none_zero|= __shfl_xor_sync(-1,none_zero,none_zero_i); + // } - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // if(none_zero!=0){ + // printf("## before mma ## bidtid:%d-%d-%d-%d-%d-%d, + // warp_mma_k:%d, frag_B_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x; + // frag_a_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x" + // " accu: + // %d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d + // \n", blockIdx.x,blockIdx.y,blockIdx.z, warp_mma_k, + // threadIdx.x,threadIdx.y,threadIdx.z, + // converted_frag_B_reg_ptr[0], + // converted_frag_B_reg_ptr[1], + // converted_frag_B_reg_ptr[2], + // converted_frag_B_reg_ptr[3], + // converted_frag_B_reg_ptr[4], + // converted_frag_B_reg_ptr[5], + // converted_frag_B_reg_ptr[6], + // converted_frag_B_reg_ptr[7], + // frag_a_reg_ptr[0], + // frag_a_reg_ptr[1], + // frag_a_reg_ptr[2], + // frag_a_reg_ptr[3], + // frag_a_reg_ptr[4], + // frag_a_reg_ptr[5], + // frag_a_reg_ptr[6], + // frag_a_reg_ptr[7], + // accum[0], + // accum[1], + // accum[2], + // accum[3], + // accum[4], + // accum[5], + // accum[6], + // accum[7], + // accum[8], + // accum[9], + // accum[10], + // accum[11], + // accum[12], + // accum[13], + // accum[14], + // accum[15], + // accum[16], + // accum[17], + // accum[18], + // accum[19], + // accum[20], + // accum[21], + // accum[22], + // accum[23], + // accum[24], + // accum[25], + // accum[26], + // accum[27], + // accum[28], + // accum[29], + // accum[30], + // accum[31] + // ); + // } + // } + run_ampere_warp_mma( + warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B_buffer[warp_tileB_k_load_offset % 2], + accum, + warp_tileB_k_compute_offset); + // auto tmp = static_cast(warp_frag_B[warp_tileB_k_load_offset + // % 2]); if(threadIdx.x==0 && threadIdx.y==0 && threadIdx.z==0 && + // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ + // printf("### run_warp_mma: " + // "%d \n", + // reinterpret_cast(accum)); + // } + // if(true){ + // uint32_t none_zero = 0; + // uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&converted_frag_B); + // // uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&warp_frag_B[warp_mma_k % 2]); + // uint32_t* frag_a_reg_ptr = + // reinterpret_cast(&warp_frag_A[warp_mma_k % 2]); + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0;ii0;none_zero_i/=2){ + // none_zero|= __shfl_xor_sync(-1,none_zero,none_zero_i); + // } - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - // static_assert(Base::kNumKIterationsPerWarpBLoad==1,"Base::kNumKIterationsPerWarpBLoad!=1"); - // static_assert(Base::kWarpGemmIterationsForB==4,"Base::kWarpGemmIterationsForB!=4"); - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) - % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - converted_frag_B_buffer[(warp_tileB_k_load_offset + 1) % 2] = - lds_converter(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + // // if(none_zero!=0){ + // if((blockIdx.y||blockIdx.z||threadIdx.x||threadIdx.y||threadIdx.z)==0){ - ++this->warp_tile_iterator_B_; - // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - // printf("#### warp_frag_b load [%d] bid:%d-%d-%d," - // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", - // ((warp_tileB_k_load_offset + 1) % 2), - // blockIdx.x,blockIdx.y,blockIdx.z, - // frag_b_reg_ptr[0], - // frag_b_reg_ptr[1], - // frag_b_reg_ptr[2], - // frag_b_reg_ptr[3], - // frag_b_reg_ptr[4], - // frag_b_reg_ptr[5], - // frag_b_reg_ptr[6], - // frag_b_reg_ptr[7] - // ); - // } - } - // TODO(wangbojun) lds_converter can be remove for int8 B input - // int4 - // typename TransformBAfterLDS::result_type converted_frag_B = - // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + // printf("## after mma ## bidtid:%d-%d-%d-%d-%d-%d, + // warp_mma_k:%d, gemm_k_iterations:%d, + // Base::kWarpGemmIterations:%d," + // " converted_frag_B_reg_ptr:%x; frag_a_reg_ptr:%x" + // " accu: %d \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // threadIdx.x,threadIdx.y,threadIdx.z, + // warp_mma_k, + // gemm_k_iterations, + // Base::kWarpGemmIterations, + // converted_frag_B_reg_ptr[0], + // frag_a_reg_ptr[0], + // accum[0] + // ); + // } + // } + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; - // typename TransformBAfterLDS::result_type converted_frag_B = - // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], reinterpret_cast(nf4_smem_look_up_table)); + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - // typename TransformBAfterLDS::result_type converted_frag_B = - // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], reg_look_up_table); - - // typename TransformBAfterLDS::result_type converted_frag_B = - // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], shared_look_up_table, warp_idx, lane_idx); - - // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset) % 2]); - // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&converted_frag_B); - // printf("#### after lds_converter bid:%d-%d-%d" - // " frag_b_reg_ptr[%d]:%x-%x-%x-%x-%x-%x-%x-%x" - // " converted_frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // ((warp_tileB_k_load_offset) % 2), - // frag_b_reg_ptr[0], - // frag_b_reg_ptr[1], - // frag_b_reg_ptr[2], - // frag_b_reg_ptr[3], - // frag_b_reg_ptr[4], - // frag_b_reg_ptr[5], - // frag_b_reg_ptr[6], - // frag_b_reg_ptr[7], - // converted_frag_B_reg_ptr[0], - // converted_frag_B_reg_ptr[1], - // converted_frag_B_reg_ptr[2], - // converted_frag_B_reg_ptr[3], - // converted_frag_B_reg_ptr[4], - // converted_frag_B_reg_ptr[5], - // converted_frag_B_reg_ptr[6], - // converted_frag_B_reg_ptr[7] - // ); - // } - - // bool ::print_type< ::cutlass::Array< ::cutlass::integer_subbyte<(int)4, (bool)0> , (int)64, (bool)0> > ()") - // print_type(); - // bool ::print_type< ::cutlass::Array > ()") from a - // print_type(); - // cutlass::Array - // print_type(); - - // print_type(); - // TODO(zhengzekang) - // run_warp_mma( - - - // if(true){ - // uint32_t none_zero = 0; - // // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&converted_frag_B); - // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&warp_frag_B[warp_mma_k % 2]); - // uint32_t* frag_a_reg_ptr = reinterpret_cast(&warp_frag_A[warp_mma_k % 2]); - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0;ii0;none_zero_i/=2){ - // none_zero|= __shfl_xor_sync(-1,none_zero,none_zero_i); - // } - - // if(none_zero!=0){ - // printf("## before mma ## bidtid:%d-%d-%d-%d-%d-%d, warp_mma_k:%d, frag_B_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x; frag_a_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x" - // " accu: %d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // warp_mma_k, - // threadIdx.x,threadIdx.y,threadIdx.z, - // converted_frag_B_reg_ptr[0], - // converted_frag_B_reg_ptr[1], - // converted_frag_B_reg_ptr[2], - // converted_frag_B_reg_ptr[3], - // converted_frag_B_reg_ptr[4], - // converted_frag_B_reg_ptr[5], - // converted_frag_B_reg_ptr[6], - // converted_frag_B_reg_ptr[7], - // frag_a_reg_ptr[0], - // frag_a_reg_ptr[1], - // frag_a_reg_ptr[2], - // frag_a_reg_ptr[3], - // frag_a_reg_ptr[4], - // frag_a_reg_ptr[5], - // frag_a_reg_ptr[6], - // frag_a_reg_ptr[7], - // accum[0], - // accum[1], - // accum[2], - // accum[3], - // accum[4], - // accum[5], - // accum[6], - // accum[7], - // accum[8], - // accum[9], - // accum[10], - // accum[11], - // accum[12], - // accum[13], - // accum[14], - // accum[15], - // accum[16], - // accum[17], - // accum[18], - // accum[19], - // accum[20], - // accum[21], - // accum[22], - // accum[23], - // accum[24], - // accum[25], - // accum[26], - // accum[27], - // accum[28], - // accum[29], - // accum[30], - // accum[31] - // ); - // } - // } - run_ampere_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_buffer[warp_tileB_k_load_offset % 2], accum, warp_tileB_k_compute_offset); - // auto tmp = static_cast(warp_frag_B[warp_tileB_k_load_offset % 2]); - // if(threadIdx.x==0 && threadIdx.y==0 && threadIdx.z==0 && - // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - // printf("### run_warp_mma: " - // "%d \n", - // reinterpret_cast(accum)); - // } - // if(true){ - // uint32_t none_zero = 0; - // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&converted_frag_B); - // // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&warp_frag_B[warp_mma_k % 2]); - // uint32_t* frag_a_reg_ptr = reinterpret_cast(&warp_frag_A[warp_mma_k % 2]); - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0;ii0;none_zero_i/=2){ - // none_zero|= __shfl_xor_sync(-1,none_zero,none_zero_i); - // } - - // // if(none_zero!=0){ - // if((blockIdx.y||blockIdx.z||threadIdx.x||threadIdx.y||threadIdx.z)==0){ - - // printf("## after mma ## bidtid:%d-%d-%d-%d-%d-%d, warp_mma_k:%d, gemm_k_iterations:%d, Base::kWarpGemmIterations:%d," - // " converted_frag_B_reg_ptr:%x; frag_a_reg_ptr:%x" - // " accu: %d \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // threadIdx.x,threadIdx.y,threadIdx.z, - // warp_mma_k, - // gemm_k_iterations, - // Base::kWarpGemmIterations, - // converted_frag_B_reg_ptr[0], - // frag_a_reg_ptr[0], - // accum[0] - // ); - // } - // } - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - smem_read_stage_idx = 0; - } - else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); } + } } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h index 97fca6da67..1e11b8e10f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -122,10 +122,9 @@ struct DefaultMmaTensorOp; }; - - -// Specialization for int8 int8 int32 gemm, we use instruction shape<16, 8, 32> for better performance. -template< +// Specialization for int8 int8 int32 gemm, we use instruction shape<16, 8, 32> +// for better performance. +template < /// Shape of one matrix production operation (concept: GemmShape) typename WarpShape_, /// Shape of one matrix production operation (concept: GemmShape) @@ -152,48 +151,51 @@ struct DefaultMmaTensorOp { -private: - // Shape for computing the FP16s - using ComputeInstructionShape = InstructionShape_; + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; - // Chosen so we get K=32. - static constexpr int LoadInstructionK = 32 * sizeof_bits::value / sizeof_bits::value; + // Chosen so we get K=32. + static constexpr int LoadInstructionK = + 32 * sizeof_bits::value / sizeof_bits::value; - // Shape for loading the narrow data type from shared memory - using LoadInstructionShape = GemmShape; + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = + GemmShape; -public: - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1>>; + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + int32_t, + cutlass::layout::RowMajor, + // cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<32,128>, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1>>; - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; + // Define the warp-level tensor op + using Type = + cutlass::gemm::warp::MmaTensorOpComputeBWithF16; }; - - - -// Specialization for int8 int8 int32 gemm, we use instruction shape<16, 8, 32> for better performance. -template< +// Specialization for int8 int8 int32 gemm, we use instruction shape<16, 8, 32> +// for better performance. +template < /// Shape of one matrix production operation (concept: GemmShape) typename WarpShape_, /// Shape of one matrix production operation (concept: GemmShape) @@ -220,41 +222,46 @@ struct DefaultMmaTensorOp { -private: - // Shape for computing the FP16s - using ComputeInstructionShape = InstructionShape_; + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; - // Chosen so we get K=64. - static constexpr int LoadInstructionK = 16 * sizeof_bits::value / sizeof_bits::value; + // Chosen so we get K=64. + static constexpr int LoadInstructionK = + 16 * sizeof_bits::value / sizeof_bits::value; - // Shape for loading the narrow data type from shared memory - using LoadInstructionShape = GemmShape; + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = + GemmShape; -public: - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1>>; + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + int32_t, + cutlass::layout::RowMajor, + // cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<32,128>, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1>>; - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; + // Define the warp-level tensor op + using Type = + cutlass::gemm::warp::MmaTensorOpComputeBWithF16; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index 7279541fe9..3476b3816c 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. + \brief Templates implementing warp-level matrix multiply-accumulate + operations targeting Tensor Cores. */ #pragma once @@ -63,8 +64,9 @@ namespace gemm { namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template< +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Data type of A elements @@ -91,222 +93,229 @@ template< /// Used for partial specialization typename Enable = bool> class MmaTensorOpComputeBWithF16 { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; - /// Data type of multiplicand A - using ElementA = ElementA_; + /// Data type of multiplicand A + using ElementA = ElementA_; - /// Layout of multiplicand A - using LayoutA = LayoutA_; + /// Layout of multiplicand A + using LayoutA = LayoutA_; - /// Data type of multiplicand B - using ElementB = ElementB_; + /// Data type of multiplicand B + using ElementB = ElementB_; - /// Layout of multiplicand B - using LayoutB = LayoutB_; + /// Layout of multiplicand B + using LayoutB = LayoutB_; - /// Data type of accumulator matrix C - using ElementC = ElementC_; + /// Data type of accumulator matrix C + using ElementC = ElementC_; - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports underlying HMMA"); + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA"); - static_assert(platform::is_same::value - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + static_assert(platform::is_same::value || + (platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); - // Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; + // Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, - "M dimension of compute instruction must match load"); - static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, - "N dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; -public: - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; - /// Storage for transformed A tile - using TransformedFragmentA = Array; + /// Storage for transformed A tile + using TransformedFragmentA = + Array; - /// Iterates over the B operand in memory - using IteratorB = - MmaTensorOpMultiplicandTileIterator, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; - /// Storage for transformed B tile - using TransformedFragmentB = Array; + /// Storage for transformed B tile + using TransformedFragmentB = + Array; - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; -public: - // - // Methods - // + public: + // + // Methods + // - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, - TransformedFragmentA const& A, - TransformedFragmentB const& B, - FragmentC const& C, - const int warp_tileB_k_offset) const - { + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + const int warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; + static_assert( + TransformedFragmentB::kElements == + MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column " + "iteration AND for the expanded K dim of B"); - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B"); + D = C; - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } -#else - assert(0); -#endif + } } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = + warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } }; // Specialization for int8 int8 int32. Author(zhengzekang) -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Layout of A matrix (concept: MatrixLayout) @@ -328,233 +337,235 @@ template< bool AccumulatorsInRowMajor, /// Used for partial specialization typename Enable> -class MmaTensorOpComputeBWithF16< - Shape_, - int8_t, - LayoutA_, - int8_t, - LayoutB_, - ElementC_, - LayoutC_, - Policy_, - SharedMemoryInstructionShape_, - PartitionsK_, - AccumulatorsInRowMajor, - Enable> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; - /// Data type of multiplicand A - using ElementA = int8_t; + /// Data type of multiplicand A + using ElementA = int8_t; - /// Layout of multiplicand A - using LayoutA = LayoutA_; + /// Layout of multiplicand A + using LayoutA = LayoutA_; - /// Data type of multiplicand B - using ElementB = int8_t; + /// Data type of multiplicand B + using ElementB = int8_t; - /// Layout of multiplicand B - using LayoutB = LayoutB_; + /// Layout of multiplicand B + using LayoutB = LayoutB_; - /// Data type of accumulator matrix C - using ElementC = ElementC_; + /// Data type of accumulator matrix C + using ElementC = ElementC_; - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value), - "MmaTensorOpCvtBToA only supports underlying iMMA"); + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value), + "MmaTensorOpCvtBToA only supports underlying iMMA"); - // static_assert(platform::is_same::value, - // "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + // static_assert(platform::is_same::value, + // "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on + // Ampere+"); - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, - "M dimension of compute instruction must match load"); - static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, - "N dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; -public: - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; - /// Storage for transformed A tile - using TransformedFragmentA = Array; + /// Storage for transformed A tile + using TransformedFragmentA = + Array; - /// Iterates over the B operand in memory - using IteratorB = - MmaTensorOpMultiplicandTileIterator, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; - /// Storage for transformed B tile - using TransformedFragmentB = Array; + /// Storage for transformed B tile + using TransformedFragmentB = + Array; - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; -public: - // - // Methods - // + public: + // + // Methods + // - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, - TransformedFragmentA const& A, - TransformedFragmentB const& B, - FragmentC const& C, - const int warp_tileB_k_offset) const - { + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + const int warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; + static_assert( + TransformedFragmentB::kElements == + MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column " + "iteration AND for the expanded K dim of B"); - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B"); + D = C; - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } -#else - assert(0); -#endif + } } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = + warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } }; - - // Specialization for int8 int8 int32. Author(zhengzekang) -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Layout of A matrix (concept: MatrixLayout) @@ -576,235 +587,250 @@ template< bool AccumulatorsInRowMajor, /// Used for partial specialization typename Enable> -class MmaTensorOpComputeBWithF16< - Shape_, - int8_t, - LayoutA_, - cutlass::uint4b_t, - LayoutB_, - ElementC_, - LayoutC_, - Policy_, - SharedMemoryInstructionShape_, - PartitionsK_, - AccumulatorsInRowMajor, - Enable> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; - /// Data type of multiplicand A - using ElementA = int8_t; + /// Data type of multiplicand A + using ElementA = int8_t; - /// Layout of multiplicand A - using LayoutA = LayoutA_; + /// Layout of multiplicand A + using LayoutA = LayoutA_; - /// Data type of multiplicand B - using ElementB = cutlass::uint4b_t; + /// Data type of multiplicand B + using ElementB = cutlass::uint4b_t; - /// Layout of multiplicand B - using LayoutB = LayoutB_; + /// Layout of multiplicand B + using LayoutB = LayoutB_; - /// Data type of accumulator matrix C - using ElementC = ElementC_; + /// Data type of accumulator matrix C + using ElementC = ElementC_; - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value), - "MmaTensorOpCvtBToA only supports underlying iMMA"); + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value), + "MmaTensorOpCvtBToA only supports underlying iMMA"); - // static_assert(platform::is_same::value, - // "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + // static_assert(platform::is_same::value, + // "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on + // Ampere+"); - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, - "M dimension of compute instruction must match load"); - static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, - "N dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; -public: - /// Iterates over the A operand in memory + public: + /// Iterates over the A operand in memory + // ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + // ::cutlass::MatrixShape<(int)32, (int)64> , ( ::cutlass::gemm::Operand)0, + // signed char, + // ::cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<(int)8, (int)64> , + // ::cutlass::MatrixShape<(int)16, (int)32> , (int)1, (int)32, (int)1> > ()") + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; - // ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< ::cutlass::MatrixShape<(int)32, (int)64> , ( ::cutlass::gemm::Operand)0, signed char, ::cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<(int)8, (int)64> , ::cutlass::MatrixShape<(int)16, (int)32> , (int)1, (int)32, (int)1> > ()") - using IteratorA = MmaTensorOpMultiplicandTileIterator, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + // bool ::print_type< + // ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + // ::cutlass::MatrixShape<(int)64, (int)32> , ( ::cutlass::gemm::Operand)1, + // ::cutlass::integer_subbyte<(int)4, (bool)0> , + // ::cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<(int)4, + // (int)64> , ::cutlass::MatrixShape<(int)64, (int)8> , (int)1, (int)32, + // (int)1> > ()") + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + // ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + // ::cutlass::MatrixShape<(int)64, (int)32> , ( ::cutlass::gemm::Operand)1, + // ::cutlass::integer_subbyte<(int)4, (bool)0> , + // ::cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<(int)4, + // (int)64> , ::cutlass::MatrixShape<(int)64, (int)8> , (int)1, (int)32, + // (int)1> > print_type(); + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; - /// Storage for transformed A tile - using TransformedFragmentA = Array; -// bool ::print_type< ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< ::cutlass::MatrixShape<(int)64, (int)32> , ( ::cutlass::gemm::Operand)1, ::cutlass::integer_subbyte<(int)4, (bool)0> , ::cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<(int)4, (int)64> , ::cutlass::MatrixShape<(int)64, (int)8> , (int)1, (int)32, (int)1> > ()") - /// Iterates over the B operand in memory - using IteratorB = - MmaTensorOpMultiplicandTileIterator, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; -// ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< ::cutlass::MatrixShape<(int)64, (int)32> , ( ::cutlass::gemm::Operand)1, ::cutlass::integer_subbyte<(int)4, (bool)0> , ::cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<(int)4, (int)64> , ::cutlass::MatrixShape<(int)64, (int)8> , (int)1, (int)32, (int)1> > - // print_type(); - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; + /// Storage for transformed B tile + using TransformedFragmentB = + Array; - /// Storage for transformed B tile - using TransformedFragmentB = Array; + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; + public: + // + // Methods + // -public: - // - // Methods - // + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + const int warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, - TransformedFragmentA const& A, - TransformedFragmentB const& B, - FragmentC const& C, - const int warp_tileB_k_offset) const - { + static_assert( + TransformedFragmentB::kElements == + MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column " + "iteration AND for the expanded K dim of B"); - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; + D = C; - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B"); - - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } -#else - assert(0); -#endif + } } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = + warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } }; - - ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace warp diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h index 73a1122433..73c8fbb457 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -187,10 +187,9 @@ class MmaTensorOpDequantizer< // Adds a pointer offset in units of elements. CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_ += offset; + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_ += offset; } private: @@ -297,10 +296,9 @@ class MmaTensorOpDequantizer< } // Adds a pointer offset in units of elements. CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_ += offset; + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_ += offset; } private: diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion.h index 6f09be73e3..d68a683f4a 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion.h @@ -435,180 +435,197 @@ struct FastInterleavedAndBiasedNumericArrayConverter { result_type operator()(source_type const& s) { return convert(s); } }; -template<> +template <> struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt = 0x3120; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "r"(i8s), "n"(mask_for_elt)); + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt = 0x3120; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "r"(i8s), "n"(mask_for_elt)); - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 4; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - return result; + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 4; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; - -template<> +template <> struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; - uint32_t* h = reinterpret_cast(&result); - const uint32_t * i8s = reinterpret_cast(&source); + uint32_t* h = reinterpret_cast(&result); + const uint32_t* i8s = reinterpret_cast(&source); - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt_1 = 0x7120; - static constexpr uint32_t mask_for_elt_2 = 0x3654; + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt_1 = 0x7120; + static constexpr uint32_t mask_for_elt_2 = 0x3654; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_1)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_2)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_1)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_2)); - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 8; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - return result; + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; - -template<> +template <> struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; - uint32_t* h = reinterpret_cast(&result); - const uint32_t * i8s = reinterpret_cast(&source); + uint32_t* h = reinterpret_cast(&result); + const uint32_t* i8s = reinterpret_cast(&source); - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt_1 = 0x3120; - static constexpr uint32_t mask_for_elt_2 = 0x3120; - static constexpr uint32_t mask_for_elt_3 = 0x3120; - static constexpr uint32_t mask_for_elt_4 = 0x3120; + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt_1 = 0x3120; + static constexpr uint32_t mask_for_elt_2 = 0x3120; + static constexpr uint32_t mask_for_elt_3 = 0x3120; + static constexpr uint32_t mask_for_elt_4 = 0x3120; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s[0]), "r"(i8s[2]), "n"(mask_for_elt_1)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s[1]), "r"(i8s[3]), "n"(mask_for_elt_2)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[2]) : "r"(i8s[2]), "r"(i8s[0]), "n"(mask_for_elt_3)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[3]) : "r"(i8s[3]), "r"(i8s[1]), "n"(mask_for_elt_4)); - - - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 16; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - - return result; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s[0]), "r"(i8s[2]), "n"(mask_for_elt_1)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s[1]), "r"(i8s[3]), "n"(mask_for_elt_2)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[2]) + : "r"(i8s[2]), "r"(i8s[0]), "n"(mask_for_elt_3)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[3]) + : "r"(i8s[3]), "r"(i8s[1]), "n"(mask_for_elt_4)); + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 16; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; -template +template struct FastInterleavedAndBiasedNumericArrayConverter { - static constexpr int VEC_WIDTH = 4; - static_assert(N == 32,"N must be 32"); - static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + static constexpr int VEC_WIDTH = 4; + static_assert(N == 32, "N must be 32"); + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; - result_type result; - using vec_result = Array; - using vec_source = Array; + result_type result; + using vec_result = Array; + using vec_source = Array; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int32_t i = 0; i < N / VEC_WIDTH; ++i) { - result_ptr[i] = convert_vector_((source_ptr)[i]); - } - Array temp_rearrange_array; - auto lane_idx_div4 = threadIdx.x%4; + CUTLASS_PRAGMA_UNROLL + for (int32_t i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_((source_ptr)[i]); + } + Array temp_rearrange_array; + auto lane_idx_div4 = threadIdx.x % 4; - CUTLASS_PRAGMA_UNROLL - for (int32_t i=0;i< N/8;++i){ - uint32_t* temp_rearrange_array_ptr = reinterpret_cast(&temp_rearrange_array); - uint32_t* result_reg_ptr = reinterpret_cast(result_ptr)+i * 2; - temp_rearrange_array_ptr[0] = __shfl_xor_sync(0xFFFFFFFF,reinterpret_cast(result_reg_ptr)[0],3); - temp_rearrange_array_ptr[1] = __shfl_xor_sync(0xFFFFFFFF,reinterpret_cast(result_reg_ptr)[1],3); - if( lane_idx_div4==1 || lane_idx_div4==2 ){ - result_reg_ptr[0]=temp_rearrange_array_ptr[0]; - result_reg_ptr[1]=temp_rearrange_array_ptr[1]; - } - temp_rearrange_array_ptr[0] = __shfl_xor_sync(0xFFFFFFFF,reinterpret_cast(result_reg_ptr)[0],2); - temp_rearrange_array_ptr[1] = __shfl_xor_sync(0xFFFFFFFF,reinterpret_cast(result_reg_ptr)[1],2); - if(lane_idx_div4<2){ - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_reg_ptr[0]) : "r"(result_reg_ptr[0]), "r"(temp_rearrange_array_ptr[0]), "n"(0x5410)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_reg_ptr[1]) : "r"(result_reg_ptr[1]), "r"(temp_rearrange_array_ptr[1]), "n"(0x5410)); - } - else{ - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_reg_ptr[0]) : "r"(result_reg_ptr[0]), "r"(temp_rearrange_array_ptr[0]), "n"(0x3276)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_reg_ptr[1]) : "r"(result_reg_ptr[1]), "r"(temp_rearrange_array_ptr[1]), "n"(0x3276)); - } - } - - return result; + CUTLASS_PRAGMA_UNROLL + for (int32_t i = 0; i < N / 8; ++i) { + uint32_t* temp_rearrange_array_ptr = + reinterpret_cast(&temp_rearrange_array); + uint32_t* result_reg_ptr = + reinterpret_cast(result_ptr) + i * 2; + temp_rearrange_array_ptr[0] = __shfl_xor_sync( + 0xFFFFFFFF, reinterpret_cast(result_reg_ptr)[0], 3); + temp_rearrange_array_ptr[1] = __shfl_xor_sync( + 0xFFFFFFFF, reinterpret_cast(result_reg_ptr)[1], 3); + if (lane_idx_div4 == 1 || lane_idx_div4 == 2) { + result_reg_ptr[0] = temp_rearrange_array_ptr[0]; + result_reg_ptr[1] = temp_rearrange_array_ptr[1]; + } + temp_rearrange_array_ptr[0] = __shfl_xor_sync( + 0xFFFFFFFF, reinterpret_cast(result_reg_ptr)[0], 2); + temp_rearrange_array_ptr[1] = __shfl_xor_sync( + 0xFFFFFFFF, reinterpret_cast(result_reg_ptr)[1], 2); + if (lane_idx_div4 < 2) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_reg_ptr[0]) + : "r"(result_reg_ptr[0]), + "r"(temp_rearrange_array_ptr[0]), + "n"(0x5410)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_reg_ptr[1]) + : "r"(result_reg_ptr[1]), + "r"(temp_rearrange_array_ptr[1]), + "n"(0x5410)); + } else { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_reg_ptr[0]) + : "r"(result_reg_ptr[0]), + "r"(temp_rearrange_array_ptr[0]), + "n"(0x3276)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_reg_ptr[1]) + : "r"(result_reg_ptr[1]), + "r"(temp_rearrange_array_ptr[1]), + "n"(0x3276)); + } } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion_nf4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion_nf4.h index b21b204c85..c4f03583ce 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion_nf4.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion_nf4.h @@ -104,8 +104,8 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverterNf4 + scalar_source_type, + VEC_WIDTH> convert_vector_; result_type result; @@ -128,7 +128,9 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 { }; template <> -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { using result_type = Array; using source_type = Array; @@ -179,7 +181,9 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 }; template -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { static constexpr int VEC_WIDTH = 4; static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); @@ -191,8 +195,8 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverterNf4 + scalar_source_type, + VEC_WIDTH> convert_vector_; result_type result; @@ -314,8 +318,8 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverterNf4 + scalar_source_type, + VEC_WIDTH> convert_vector_; result_type result; @@ -338,7 +342,9 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 { }; template <> -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { using result_type = Array; using source_type = Array; @@ -400,7 +406,9 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 }; template -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { static constexpr int VEC_WIDTH = 8; static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); @@ -412,8 +420,8 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverterNf4 + scalar_source_type, + VEC_WIDTH> convert_vector_; result_type result; @@ -435,388 +443,405 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 result_type operator()(source_type const& s) { return convert(s); } }; -template<> +template <> struct FastInterleavedAndBiasedNumericArrayConverterNf4 { - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt = 0x3120; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "r"(i8s), "n"(mask_for_elt)); + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt = 0x3120; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "r"(i8s), "n"(mask_for_elt)); - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 4; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - return result; + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 4; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; - -template<> +template <> struct FastInterleavedAndBiasedNumericArrayConverterNf4 { - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; - uint32_t* h = reinterpret_cast(&result); - const uint32_t * i8s = reinterpret_cast(&source); + uint32_t* h = reinterpret_cast(&result); + const uint32_t* i8s = reinterpret_cast(&source); - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt_1 = 0x7120; - static constexpr uint32_t mask_for_elt_2 = 0x3654; + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt_1 = 0x7120; + static constexpr uint32_t mask_for_elt_2 = 0x3654; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_1)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_2)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_1)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_2)); - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 8; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - return result; + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; - -template<> +template <> struct FastInterleavedAndBiasedNumericArrayConverterNf4 { - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; - uint32_t* h = reinterpret_cast(&result); - const uint32_t * i8s = reinterpret_cast(&source); + uint32_t* h = reinterpret_cast(&result); + const uint32_t* i8s = reinterpret_cast(&source); - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt_1 = 0x3120; - static constexpr uint32_t mask_for_elt_2 = 0x3120; - static constexpr uint32_t mask_for_elt_3 = 0x3120; - static constexpr uint32_t mask_for_elt_4 = 0x3120; + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt_1 = 0x3120; + static constexpr uint32_t mask_for_elt_2 = 0x3120; + static constexpr uint32_t mask_for_elt_3 = 0x3120; + static constexpr uint32_t mask_for_elt_4 = 0x3120; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s[0]), "r"(i8s[2]), "n"(mask_for_elt_1)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s[1]), "r"(i8s[3]), "n"(mask_for_elt_2)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[2]) : "r"(i8s[2]), "r"(i8s[0]), "n"(mask_for_elt_3)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[3]) : "r"(i8s[3]), "r"(i8s[1]), "n"(mask_for_elt_4)); - - - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 16; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - - return result; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s[0]), "r"(i8s[2]), "n"(mask_for_elt_1)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s[1]), "r"(i8s[3]), "n"(mask_for_elt_2)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[2]) + : "r"(i8s[2]), "r"(i8s[0]), "n"(mask_for_elt_3)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[3]) + : "r"(i8s[3]), "r"(i8s[1]), "n"(mask_for_elt_4)); + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 16; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; -template +template struct FastInterleavedAndBiasedNumericArrayConverterNf4 { - static constexpr int VEC_WIDTH = 4; - static_assert(N == 32,"N must be 32"); - static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + static constexpr int VEC_WIDTH = 4; + static_assert(N == 32, "N must be 32"); + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - return source; - } + CUTLASS_DEVICE + static result_type convert(source_type const& source) { return source; } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; -template -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { - static constexpr int VEC_WIDTH = 8; - // static_assert(N == 64,"N must be 64"); - static_assert(!(N % VEC_WIDTH), "N must be multiple of VEC_WIDTH."); +template +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { + static constexpr int VEC_WIDTH = 8; + // static_assert(N == 64,"N must be 64"); + static_assert(!(N % VEC_WIDTH), "N must be multiple of VEC_WIDTH."); - using result_type = Array; - using source_type = Array; - using vec_source = Array; - using vec_result = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - //nf4 - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - result_type result; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) { - vec_result index; - uint32_t* result_i_ptr = reinterpret_cast(&(result_ptr[i])); - uint32_t const i4s = reinterpret_cast(source_ptr[i]); - static constexpr uint32_t up_int4_mask = 0xf0f0f0f0; - static constexpr uint32_t immLut_0 = 0x40; - static constexpr uint32_t immLut_1 = 0x80; - asm volatile( - "lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(result_i_ptr[1]) - : "r"(i4s), "r"(i4s), "n"(up_int4_mask), "n"(immLut_1)); - asm volatile( - "lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(result_i_ptr[0]) - : "r"(i4s), "r"(i4s), "n"(up_int4_mask), "n"(immLut_0)); - result_i_ptr[0]=(result_i_ptr[0]<<4); - } - return result; + using result_type = Array; + using source_type = Array; + using vec_source = Array; + using vec_result = Array; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + // nf4 + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + result_type result; + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + vec_result index; + uint32_t* result_i_ptr = reinterpret_cast(&(result_ptr[i])); + uint32_t const i4s = reinterpret_cast(source_ptr[i]); + static constexpr uint32_t up_int4_mask = 0xf0f0f0f0; + static constexpr uint32_t immLut_0 = 0x40; + static constexpr uint32_t immLut_1 = 0x80; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(result_i_ptr[1]) + : "r"(i4s), "r"(i4s), "n"(up_int4_mask), "n"(immLut_1)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(result_i_ptr[0]) + : "r"(i4s), "r"(i4s), "n"(up_int4_mask), "n"(immLut_0)); + result_i_ptr[0] = (result_i_ptr[0] << 4); } + return result; + } - CUTLASS_DEVICE - static result_type convert(source_type const& source, int32_t* shared_look_up_table) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - result_type result; - // static constexpr uint32_t loop_up_table[4]{0x03020100,0x07060504,0x0B0A0908,0x0F0E0D0C}; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - // FastInterleavedAndBiasedNumericArrayConverterNf4 - // convert_vector_; - // static constexpr Array loop_up_table{0x00010203,0x04050607,0x08090A0B,0x0C0D0E0F}; - for (int i = 0; i < N / VEC_WIDTH; ++i) { - vec_result index; - uint32_t* h = reinterpret_cast(&index); - // const int8_t* loop_up_table_int8 = reinterpret_cast(&loop_up_table); - uint32_t const i4s = reinterpret_cast(source_ptr[i]); - static constexpr uint32_t down_int4_mask = 0x0f0f0f0f; - static constexpr uint32_t up_int4_mask = 0xf0f0f0f0; - h[0]=i4s&down_int4_mask; - h[1]=i4s&up_int4_mask; - h[1]=h[1]>>4; + CUTLASS_DEVICE + static result_type convert(source_type const& source, + int32_t* shared_look_up_table) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + result_type result; + // static constexpr uint32_t + // loop_up_table[4]{0x03020100,0x07060504,0x0B0A0908,0x0F0E0D0C}; + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + // FastInterleavedAndBiasedNumericArrayConverterNf4 + // convert_vector_; + // static constexpr Array + // loop_up_table{0x00010203,0x04050607,0x08090A0B,0x0C0D0E0F}; + for (int i = 0; i < N / VEC_WIDTH; ++i) { + vec_result index; + uint32_t* h = reinterpret_cast(&index); + // const int8_t* loop_up_table_int8 = reinterpret_cast(&loop_up_table); + uint32_t const i4s = reinterpret_cast(source_ptr[i]); + static constexpr uint32_t down_int4_mask = 0x0f0f0f0f; + static constexpr uint32_t up_int4_mask = 0xf0f0f0f0; + h[0] = i4s & down_int4_mask; + h[1] = i4s & up_int4_mask; + h[1] = h[1] >> 4; - //TODO(wangbojun)!!!! do nf4 lookup table - CUTLASS_PRAGMA_UNROLL - for(int ii=0; ii(index[ii]); - // } + // TODO(wangbojun)!!!! do nf4 lookup table + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < VEC_WIDTH; ++ii) { + result_ptr[i][ii] = shared_look_up_table[index[ii]]; + } + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0; ii(index[ii]); + // } - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0; ii(index[0]), - // static_cast(index[1]), - // static_cast(index[2]), - // static_cast(index[3]), - // static_cast(index[4]), - // static_cast(index[5]), - // static_cast(index[6]), - // static_cast(index[7]), - // static_cast(result_ptr[i][0]), - // static_cast(result_ptr[i][1]), - // static_cast(result_ptr[i][2]), - // static_cast(result_ptr[i][3]), - // static_cast(result_ptr[i][4]), - // static_cast(result_ptr[i][5]), - // static_cast(result_ptr[i][6]), - // static_cast(result_ptr[i][7]) - // ); - } - return result; + // printf("### i:%d index:%d-%d-%d-%d-%d-%d-%d-%d, + // result:%d-%d-%d-%d-%d-%d-%d-%d\n", + // i, + // static_cast(index[0]), + // static_cast(index[1]), + // static_cast(index[2]), + // static_cast(index[3]), + // static_cast(index[4]), + // static_cast(index[5]), + // static_cast(index[6]), + // static_cast(index[7]), + // static_cast(result_ptr[i][0]), + // static_cast(result_ptr[i][1]), + // static_cast(result_ptr[i][2]), + // static_cast(result_ptr[i][3]), + // static_cast(result_ptr[i][4]), + // static_cast(result_ptr[i][5]), + // static_cast(result_ptr[i][6]), + // static_cast(result_ptr[i][7]) + // ); } + return result; + } + // #define NF4_LUT_DEBUG + CUTLASS_DEVICE + static result_type convert( + source_type const& source, + cutlass::Array const& reg_look_up_table) { + // static_assert(VEC_WIDTH==16, "VEC_WIDTH == 16 for int8 int8 int32") + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + result_type result; + // static constexpr uint32_t + // loop_up_table[4]{0x03020100,0x07060504,0x0B0A0908,0x0F0E0D0C}; + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + // FastInterleavedAndBiasedNumericArrayConverterNf4 + // convert_vector_; + // static constexpr Array + // loop_up_table{0x00010203,0x04050607,0x08090A0B,0x0C0D0E0F}; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + uint32_t bitwise_workspace; + uint32_t* result_reg = reinterpret_cast(&(result_ptr[i])); + uint32_t const i4s = reinterpret_cast(source_ptr[i]); + static constexpr uint32_t down_int4_mask = 0x0f0f0f0f; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### lup 0:%08x, lup 1:%08x, lup2:%08x, lup3:%08x \n", + reg_look_up_table[0], + reg_look_up_table[1], + reg_look_up_table[2], + reg_look_up_table[3]); + printf("#### i4s:%08x \n", i4s); + } +#endif + bitwise_workspace = i4s & down_int4_mask; // 0x0v0v0v0v -// #define NF4_LUT_DEBUG - CUTLASS_DEVICE - static result_type convert(source_type const& source, cutlass::Array const& reg_look_up_table) - { - // static_assert(VEC_WIDTH==16, "VEC_WIDTH == 16 for int8 int8 int32") - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - result_type result; - // static constexpr uint32_t loop_up_table[4]{0x03020100,0x07060504,0x0B0A0908,0x0F0E0D0C}; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - // FastInterleavedAndBiasedNumericArrayConverterNf4 - // convert_vector_; - // static constexpr Array loop_up_table{0x00010203,0x04050607,0x08090A0B,0x0C0D0E0F}; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) { - uint32_t bitwise_workspace; - uint32_t* result_reg = reinterpret_cast(&(result_ptr[i])); - uint32_t const i4s = reinterpret_cast(source_ptr[i]); - static constexpr uint32_t down_int4_mask = 0x0f0f0f0f; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### lup 0:%08x, lup 1:%08x, lup2:%08x, lup3:%08x \n", - reg_look_up_table[0], - reg_look_up_table[1], - reg_look_up_table[2], - reg_look_up_table[3]); - printf("#### i4s:%08x \n", i4s); - } - #endif - bitwise_workspace = i4s & down_int4_mask; // 0x0v0v0v0v +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### h0 0x0v0v0v0v:%08x \n", bitwise_workspace); + } +#endif - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### h0 0x0v0v0v0v:%08x \n", bitwise_workspace); - } - #endif + uint32_t ge_7_mask_h_0 = + (bitwise_workspace | bitwise_workspace << 4) & 0x88888888; + ge_7_mask_h_0 |= ge_7_mask_h_0 >> 1; + ge_7_mask_h_0 |= ge_7_mask_h_0 >> 2; - uint32_t ge_7_mask_h_0 = (bitwise_workspace |bitwise_workspace << 4) & 0x88888888; - ge_7_mask_h_0 |= ge_7_mask_h_0 >> 1; - ge_7_mask_h_0 |= ge_7_mask_h_0 >> 2; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### ge_7_mask_h_0:%08x \n", ge_7_mask_h_0); + } +#endif - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### ge_7_mask_h_0:%08x \n", ge_7_mask_h_0); - } - #endif + // uint32_t look_up_h_0 = h[0]; + bitwise_workspace = + (bitwise_workspace | (bitwise_workspace >> 4)) & 0x00FF00FF; + bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 8)) & + 0x00007777; // make h[i] into 0x0000vvvv - // uint32_t look_up_h_0 = h[0]; - bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 4)) & 0x00FF00FF; - bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 8)) & 0x00007777; // make h[i] into 0x0000vvvv +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### h0 0x0000vvvv:%08x \n", bitwise_workspace); + } +#endif - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### h0 0x0000vvvv:%08x \n", bitwise_workspace); - } - #endif - - uint32_t result_1=0; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" - : "=r"(result_1) - : "r"(reg_look_up_table[0]), "r"(reg_look_up_table[1]), "r"(bitwise_workspace)); - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### result_1:%08x \n", result_1); - } - #endif - result_reg[0] = result_1 & (~ge_7_mask_h_0); - result_1 = ((~result_1)) & (ge_7_mask_h_0); - // result_1 = ((~result_1) | 0x01010101) & (ge_7_mask_h_0); // half mirror - result_reg[0] = result_reg[0] | result_1; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### result[0]:%08x \n", result_reg[0]); - } - #endif - bitwise_workspace = i4s >> 4; - bitwise_workspace = bitwise_workspace & down_int4_mask; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### h1 0x0v0v0v0v:%08x \n", bitwise_workspace); - } - #endif - ge_7_mask_h_0 = (bitwise_workspace | bitwise_workspace << 4) & 0x88888888; - ge_7_mask_h_0 |= ge_7_mask_h_0 >> 1; - ge_7_mask_h_0 |= ge_7_mask_h_0 >> 2; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### ge_7_mask_h_0:%08x \n", ge_7_mask_h_0); - } - #endif - bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 4)) & 0x00FF00FF; - bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 8)) & 0x00007777; // make look_up_h_0 into 0x0000vvvv - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### h1 0x0000vvvv:%08x \n", bitwise_workspace); - } - #endif - asm volatile("prmt.b32 %0,%1,%2,%3;\n" - : "=r"(result_1) - : "r"(reg_look_up_table[0]), "r"(reg_look_up_table[1]), "r"(bitwise_workspace)); - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### result_1:%08x \n", result_1); - } - #endif - result_reg[1] = result_1 & (~ge_7_mask_h_0); - result_1 = ((~result_1)) & (ge_7_mask_h_0); - // result_1 = ((~result_1) | 0x01010101) & (ge_7_mask_h_0); // half mirror - result_reg[1] = result_reg[1] | result_1; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### result[1]:%08x \n", result_reg[1]); - } - #endif - } - return result; + uint32_t result_1 = 0; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_1) + : "r"(reg_look_up_table[0]), + "r"(reg_look_up_table[1]), + "r"(bitwise_workspace)); +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### result_1:%08x \n", result_1); + } +#endif + result_reg[0] = result_1 & (~ge_7_mask_h_0); + result_1 = ((~result_1)) & (ge_7_mask_h_0); + // result_1 = ((~result_1) | 0x01010101) & (ge_7_mask_h_0); // half mirror + result_reg[0] = result_reg[0] | result_1; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### result[0]:%08x \n", result_reg[0]); + } +#endif + bitwise_workspace = i4s >> 4; + bitwise_workspace = bitwise_workspace & down_int4_mask; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### h1 0x0v0v0v0v:%08x \n", bitwise_workspace); + } +#endif + ge_7_mask_h_0 = (bitwise_workspace | bitwise_workspace << 4) & 0x88888888; + ge_7_mask_h_0 |= ge_7_mask_h_0 >> 1; + ge_7_mask_h_0 |= ge_7_mask_h_0 >> 2; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### ge_7_mask_h_0:%08x \n", ge_7_mask_h_0); + } +#endif + bitwise_workspace = + (bitwise_workspace | (bitwise_workspace >> 4)) & 0x00FF00FF; + bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 8)) & + 0x00007777; // make look_up_h_0 into 0x0000vvvv +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### h1 0x0000vvvv:%08x \n", bitwise_workspace); + } +#endif + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_1) + : "r"(reg_look_up_table[0]), + "r"(reg_look_up_table[1]), + "r"(bitwise_workspace)); +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### result_1:%08x \n", result_1); + } +#endif + result_reg[1] = result_1 & (~ge_7_mask_h_0); + result_1 = ((~result_1)) & (ge_7_mask_h_0); + // result_1 = ((~result_1) | 0x01010101) & (ge_7_mask_h_0); // half mirror + result_reg[1] = result_reg[1] | result_1; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### result[1]:%08x \n", result_reg[1]); + } +#endif } + return result; + } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } - CUTLASS_DEVICE - result_type operator()(source_type const& s, uint32_t shared_look_up_table[32][32], int32_t warp_idx, int32_t lane_idx) - { - return convert(s,shared_look_up_table,warp_idx, lane_idx); - } - CUTLASS_DEVICE - result_type operator()(source_type const& s, uint32_t shared_look_up_table[16]) - { - return convert(s,shared_look_up_table); - } - CUTLASS_DEVICE - result_type operator()(source_type const& s, int32_t* shared_look_up_table) - { - return convert(s,shared_look_up_table); - } - CUTLASS_DEVICE - result_type operator()(source_type const& s, cutlass::Array const& reg_look_up_table) - { - return convert(s,reg_look_up_table); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } + CUTLASS_DEVICE + result_type operator()(source_type const& s, + uint32_t shared_look_up_table[32][32], + int32_t warp_idx, + int32_t lane_idx) { + return convert(s, shared_look_up_table, warp_idx, lane_idx); + } + CUTLASS_DEVICE + result_type operator()(source_type const& s, + uint32_t shared_look_up_table[16]) { + return convert(s, shared_look_up_table); + } + CUTLASS_DEVICE + result_type operator()(source_type const& s, int32_t* shared_look_up_table) { + return convert(s, shared_look_up_table); + } + CUTLASS_DEVICE + result_type operator()( + source_type const& s, + cutlass::Array const& reg_look_up_table) { + return convert(s, reg_look_up_table); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/tile_interleaved_layout.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/tile_interleaved_layout.h index 89152360f1..40265273e0 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/tile_interleaved_layout.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/tile_interleaved_layout.h @@ -40,20 +40,20 @@ limitations under the License. */ namespace cutlass { namespace layout { -template +template class ColumnMajorTileInterleave { - static constexpr int kRowsPerTile = RowsPerTile; - static constexpr int kColumnsInterleaved = ColumnsInterleaved; + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; }; -template +template struct IsColumnMajorTileInterleave { - static constexpr bool value = false; + static constexpr bool value = false; }; -template +template struct IsColumnMajorTileInterleave> { - static constexpr bool value = true; + static constexpr bool value = true; }; } // namespace layout diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_heuristic_w4a4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_heuristic_w4a4.h index 096a857de6..5fa4d7c98d 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_heuristic_w4a4.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_heuristic_w4a4.h @@ -33,7 +33,6 @@ limitations under the License. */ #include "glog/logging.h" #include "w4a4_gemm_configs.h" - static TileShape get_cta_shape_for_config_w4a4(CutlassTileConfig tile_config) { switch (tile_config) { case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: @@ -58,12 +57,12 @@ static TileShape get_cta_shape_for_config_w4a4(CutlassTileConfig tile_config) { } static bool is_valid_split_k_factor_w4a4(const int64_t m, - const int64_t n, - const int64_t k, - const TileShape tile_shape, - const int split_k_factor, - const size_t workspace_bytes, - const bool is_weight_only) { + const int64_t n, + const int64_t k, + const TileShape tile_shape, + const int split_k_factor, + const size_t workspace_bytes, + const bool is_weight_only) { // All tile sizes have a k_tile of 64. static constexpr int k_tile = 64; @@ -125,7 +124,7 @@ static std::vector get_candidate_tiles_w4a4( CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64 // CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, // CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, - }; + }; std::vector quant_B_configs; switch (sm) { case 90: @@ -148,7 +147,6 @@ static std::vector get_candidate_tiles_w4a4( return simt_configs_only ? simt_configs : allowed_configs; } - static std::vector get_candidate_configs_nf4( int sm, const bool is_weight_only, @@ -171,7 +169,6 @@ static std::vector get_candidate_configs_nf4( return candidate_configs; } - static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( const std::vector& candidate_configs, const std::vector& occupancies, @@ -189,7 +186,7 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( "candidate configs vectors must have equal length."); } - VLOG(1)<<"estimate_best_config_from_occupancies_w4a4"; + VLOG(1) << "estimate_best_config_from_occupancies_w4a4"; CutlassGemmConfig best_config; // Score will be [0, 1]. The objective is to minimize this score. // It represents the fraction of SM resources unused in the last wave. @@ -198,25 +195,25 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( int current_m_tile = 0; { - VLOG(1)<<"######## begin of cutlass gemm search"; - if (m >= 256 && - std::find_if( - candidate_configs.begin(), - candidate_configs.end(), - [](const CutlassGemmConfig& gemm_config) { - return gemm_config.tile_config == - CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64; - }) != candidate_configs.end()) { - VLOG(1) << "m >= 256, encoder config"; - best_config = CutlassGemmConfig{ - CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, - SplitKStyle::SPLIT_K_STREAM, - // SplitKStyle::NO_SPLIT_K, - 1, - 3}; + VLOG(1) << "######## begin of cutlass gemm search"; + if (m >= 256 && + std::find_if( + candidate_configs.begin(), + candidate_configs.end(), + [](const CutlassGemmConfig& gemm_config) { + return gemm_config.tile_config == + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64; + }) != candidate_configs.end()) { + VLOG(1) << "m >= 256, encoder config"; + best_config = CutlassGemmConfig{ + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, + SplitKStyle::SPLIT_K_STREAM, + // SplitKStyle::NO_SPLIT_K, + 1, + 3}; - } else { - VLOG(1) << "m <= 64 , decoder config"; + } else { + VLOG(1) << "m <= 64 , decoder config"; const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; for (int ii = 0; ii < candidate_configs.size(); ++ii) { @@ -240,14 +237,14 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; for (int split_k_factor = 1; split_k_factor <= max_split_k; - ++split_k_factor) { + ++split_k_factor) { if (is_valid_split_k_factor_w4a4(m, - n, - k, - tile_shape, - split_k_factor, - workspace_bytes, - is_weight_only)) { + n, + k, + tile_shape, + split_k_factor, + workspace_bytes, + is_weight_only)) { const int ctas_per_wave = occupancy * multi_processor_count; const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; @@ -262,7 +259,7 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( const float score_slack = 0.1f; if (current_score < config_score || ((config_waves > num_waves_total) && - (current_score < config_score + score_slack))) { + (current_score < config_score + score_slack))) { config_score = current_score; config_waves = num_waves_total; SplitKStyle split_style = split_k_factor > 1 @@ -273,9 +270,10 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( split_k_factor, candidate_config.stages}; current_m_tile = tile_shape.m; - // std::cout<<"#### split-k factor: "<::type>(candidate_config.tile_config)<::type>(candidate_config.tile_config)<::type>(best_config.tile_config); + VLOG(1) << "#### best split-k factor: " << best_config.split_k_factor + << " config: " + << static_cast::type>( + best_config.tile_config); return best_config; } diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a4_gemm_configs.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a4_gemm_configs.h index 790ec3b17b..00c338ba97 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a4_gemm_configs.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a4_gemm_configs.h @@ -33,44 +33,44 @@ limitations under the License. */ // Note: The shapes are in the format MxNxK. The K shape of the runtime config enum class CutlassTileConfig { // Signals that we should run heuristics do choose a config - Undefined, // 0 + Undefined, // 0 // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, // 1 + ChooseWithHeuristic, // 1 // SiMT config - CtaShape128x128x8_WarpShape64x64x8, // 2 + CtaShape128x128x8_WarpShape64x64x8, // 2 // TensorCore configs CTA_N = 128, CTA_K = 64 // Warp configs for M=16 - CtaShape16x128x64_WarpShape16x32x64, // 3 - CtaShape16x256x64_WarpShape16x64x64, // 4 + CtaShape16x128x64_WarpShape16x32x64, // 3 + CtaShape16x256x64_WarpShape16x64x64, // 4 // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, // 5 + CtaShape32x128x64_WarpShape32x32x64, // 5 // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, // 6 - CtaShape64x128x64_WarpShape64x32x64, // 7 + CtaShape64x128x64_WarpShape32x64x64, // 6 + CtaShape64x128x64_WarpShape64x32x64, // 7 // Warp configs for M=128 - CtaShape128x128x64_WarpShape64x32x64, // 8 - CtaShape128x128x64_WarpShape128x32x64, // 9 + CtaShape128x128x64_WarpShape64x32x64, // 8 + CtaShape128x128x64_WarpShape128x32x64, // 9 // configs for large M in encoder - CtaShape128x256x64_WarpShape64x64x64, // 10 - CtaShape256x128x64_WarpShape64x64x64, // 11 + CtaShape128x256x64_WarpShape64x64x64, // 10 + CtaShape256x128x64_WarpShape64x64x64, // 11 - CtaShape32x256x64_WarpShape32x64x64, // 12 - CtaShape64x256x64_WarpShape64x64x64, // 13 - CtaShape128x256x64_WarpShape128x64x64, // 14 - CtaShape32x512x64_WarpShape32x128x64, // 15 + CtaShape32x256x64_WarpShape32x64x64, // 12 + CtaShape64x256x64_WarpShape64x64x64, // 13 + CtaShape128x256x64_WarpShape128x64x64, // 14 + CtaShape32x512x64_WarpShape32x128x64, // 15 }; enum class SplitKStyle { - NO_SPLIT_K, //0 - SPLIT_K_SERIAL, //1 - SPLIT_K_STREAM, //2 + NO_SPLIT_K, // 0 + SPLIT_K_SERIAL, // 1 + SPLIT_K_STREAM, // 2 // SPLIT_K_PARALLEL // Not supported yet }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h index 1301cc351c..ae90627fc7 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,22 +18,23 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief The universal GEMM accommodates streamk, batched strided, and batched array variants. + \brief The universal GEMM accommodates streamk, batched strided, and batched + array variants. */ - #pragma once #include @@ -59,11 +60,9 @@ namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// - template class W4A8MoeGemmUniversalBase { -public: - + public: using GemmKernel = GemmKernel_; using ThreadblockShape = typename GemmKernel::Mma::Shape; @@ -92,8 +91,7 @@ public: /// Argument structure using Arguments = typename GemmKernel::Arguments; -protected: - + protected: // // Device properties (uniform across all instances of the current thread) // @@ -112,8 +110,7 @@ protected: /// Initialize static thread-local members for the thread's current device, /// if necessary. - static Status init_device_props() - { + static Status init_device_props() { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::init_device_props()"); cudaError_t cudart_result; @@ -122,7 +119,8 @@ protected: int current_ordinal; cudart_result = cudaGetDevice(¤t_ordinal); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } @@ -133,64 +131,80 @@ protected: } // Update SM count member - cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); + cudart_result = cudaDeviceGetAttribute( + &device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } - // Update the kernel function's shared memory configuration for the current device + // Update the kernel function's shared memory configuration for the current + // device smem_size_ = int(sizeof(typename GemmKernel::SharedStorage)); // If requires more than 48KB: configure for extended, dynamic shared memory - if (smem_size_ >= (48 << 10)) - { - cudart_result = cudaFuncSetAttribute( - Kernel2, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size_); + if (smem_size_ >= (48 << 10)) { + cudart_result = + cudaFuncSetAttribute(Kernel2, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size_); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } - cudart_result = cudaFuncSetAttribute( - Kernel2, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); // 100% shared memory + cudart_result = + cudaFuncSetAttribute(Kernel2, + cudaFuncAttributePreferredSharedMemoryCarveout, + 100); // 100% shared memory if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } } // Update SM occupancy member cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( - &sm_occupancy_, - Kernel2, - GemmKernel::kThreadCount, - smem_size_, - cudaOccupancyDisableCachingOverride); + &sm_occupancy_, + Kernel2, + GemmKernel::kThreadCount, + smem_size_, + cudaOccupancyDisableCachingOverride); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned " + "error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } // Update device ordinal member on success device_ordinal_ = current_ordinal; - CUTLASS_TRACE_HOST(" " - "device_ordinal: (" << device_ordinal_ << "), " - "device_sms: (" << device_sms_ << "), " - "sm_occupancy: (" << sm_occupancy_ << ") " - "smem_size: (" << smem_size_ << ") " - "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")"); + CUTLASS_TRACE_HOST( + " " + "device_ordinal: (" + << device_ordinal_ + << "), " + "device_sms: (" + << device_sms_ + << "), " + "sm_occupancy: (" + << sm_occupancy_ + << ") " + "smem_size: (" + << smem_size_ + << ") " + "GemmKernel::kThreadCount: (" + << GemmKernel::kThreadCount << ")"); return Status::kSuccess; } - -protected: - + protected: // // Instance data members // @@ -198,10 +212,8 @@ protected: /// Kernel parameters typename GemmKernel::Params params_; - /// Initialize params member - Status init_params(Arguments const &args) - { + Status init_params(Arguments const &args) { // Initialize static device properties, if necessary Status result = init_device_props(); if (result != Status::kSuccess) { @@ -213,15 +225,13 @@ protected: return Status::kSuccess; } -public: - + public: //--------------------------------------------------------------------------------------------- // Stateless API //--------------------------------------------------------------------------------------------- /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const &args) - { + static Status can_implement(Arguments const &args) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::can_implement()"); // Initialize static kernel and device properties, if necessary. Status result = init_device_props(); @@ -231,18 +241,15 @@ public: dim3 grid = get_grid_shape(args); // printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z); if (!(grid.y <= std::numeric_limits::max() && - grid.z <= std::numeric_limits::max())) - { + grid.z <= std::numeric_limits::max())) { return Status::kErrorInvalidProblem; } return GemmKernel::can_implement(args); } - /// Returns the workspace size (in bytes) needed for the problem /// geometry expressed by these arguments - static size_t get_workspace_size(Arguments const &args) - { + static size_t get_workspace_size(Arguments const &args) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::get_workspace_size()"); // Initialize parameters from args @@ -258,33 +265,28 @@ public: return workspace_bytes; } - /// Returns the grid extents in thread blocks to launch - static dim3 get_grid_shape(Arguments const &args) - { + static dim3 get_grid_shape(Arguments const &args) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::get_grid_shape()"); // Initialize parameters from args W4A8MoeGemmUniversalBase base; if (base.init_params(args) != Status::kSuccess) { - return dim3(0,0,0); + return dim3(0, 0, 0); } // Get dims from parameters dim3 grid_dims = base.params_.get_grid_dims(); - CUTLASS_TRACE_HOST( - " tiled_shape: " << base.params_.get_tiled_shape() << "\n" - << " grid_dims: {" << grid_dims << "}"); + CUTLASS_TRACE_HOST(" tiled_shape: " + << base.params_.get_tiled_shape() << "\n" + << " grid_dims: {" << grid_dims << "}"); return grid_dims; } - - /// Returns the maximum number of active thread blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { + static int maximum_active_blocks(int smem_capacity = -1) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::maximum_active_blocks()"); int smem_size = int(sizeof(typename GemmKernel_::SharedStorage)); @@ -300,26 +302,25 @@ public: if (result != cudaSuccess) { // Call cudaGetLastError() to clear the error bit result = cudaGetLastError(); - CUTLASS_TRACE_HOST( - " cudaFuncSetAttribute() returned error " - << cudaGetErrorString(result)); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(result)); return -1; } } int max_active_blocks = -1; - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, - Kernel2, - GemmKernel_::kThreadCount, - smem_size); + result = + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + Kernel2, + GemmKernel_::kThreadCount, + smem_size); if (result != cudaSuccess) { // Call cudaGetLastError() to clear the error bit result = cudaGetLastError(); CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " - << cudaGetErrorString(result)); + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); return -1; } @@ -327,19 +328,17 @@ public: return max_active_blocks; } - //--------------------------------------------------------------------------------------------- // Stateful API //--------------------------------------------------------------------------------------------- /// Initializes GEMM state from arguments and workspace memory - Status initialize( - Arguments const &args, - void *workspace, - cudaStream_t stream = nullptr) - { + Status initialize(Arguments const &args, + void *workspace, + cudaStream_t stream = nullptr) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); + << workspace + << ", stream: " << (stream ? "non-null" : "null")); // Initialize parameters from args Status result = init_params(args); @@ -351,20 +350,16 @@ public: return params_.init_workspace(workspace, stream); } - - /// Lightweight update given a subset of arguments. Problem geometry is assumed to - /// remain the same. - Status update(Arguments const &args) - { + /// Lightweight update given a subset of arguments. Problem geometry is + /// assumed to remain the same. + Status update(Arguments const &args) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase()::update()"); params_.update(args); return Status::kSuccess; } - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { + Status run(cudaStream_t stream = nullptr) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::run()"); // Configure grid and block dimensions @@ -372,37 +367,37 @@ public: dim3 grid(params_.threadblock_count, 1, 1); // Launch kernel - CUTLASS_TRACE_HOST(" " - "grid: (" << grid << "), " - "block: (" << block << "), " - "SMEM: (" << smem_size_ << ")"); + CUTLASS_TRACE_HOST( + " " + "grid: (" + << grid + << "), " + "block: (" + << block + << "), " + "SMEM: (" + << smem_size_ << ")"); Kernel2<<>>(params_); // Query for errors cudaError_t result = cudaGetLastError(); if (result != cudaSuccess) { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); return Status::kErrorInternal; } return Status::kSuccess; } + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); - } - - - /// Runs the kernel using initialized state. - Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) - { + Status operator()(Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { @@ -413,7 +408,6 @@ public: } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Static initializers ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -434,12 +428,10 @@ thread_local int W4A8MoeGemmUniversalBase::sm_occupancy_ = -1; template thread_local int W4A8MoeGemmUniversalBase::smem_size_ = -1; - - ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel.h index dbcc8912f2..18cdeca89c 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel.h @@ -37,60 +37,65 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct MoeW4A8Gemm { - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; static bool const kSplitKSerial = SplitKSerial; static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; static bool const kTransposed = false; - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; using TensorRefA = TensorRef; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; using TensorRefB = TensorRef; using LayoutAlphaCol = cutlass::layout::RowMajor; using LayoutAlphaRow = cutlass::layout::ColumnMajor; using TensorRefNf4LookUpTable = TensorRef; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Mma::LayoutC; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; using TensorRefC = TensorRef; static ComplexTransform const kTransformA = Mma::kTransformA; static ComplexTransform const kTransformB = Mma::kTransformA; // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; + using WarpShape = typename Mma::Operator::Shape; using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; + using ArchTag = typename Mma::ArchTag; - static int const kStages = Mma::kStages; + static int const kStages = Mma::kStages; static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentD = Epilogue::OutputTileIterator::kElementsPerAccess; + static int const kAlignmentD = + Epilogue::OutputTileIterator::kElementsPerAccess; /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; + using WarpCount = typename Mma::WarpCount; static int const kThreadCount = 32 * WarpCount::kCount; - static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static constexpr int kInterleave = + Mma::IteratorB::Shape::kRow / Mma::Shape::kK; using ProblemVisitor = GemmMoeProblemVisitor; - /// Argument structure struct Arguments : UniversalArgumentsBase { // @@ -133,8 +137,7 @@ struct MoeW4A8Gemm { Arguments() {} /// constructs an arguments structure - Arguments( - cutlass::gemm::GemmUniversalMode mode_, + Arguments(cutlass::gemm::GemmUniversalMode mode_, GemmCoord problem_size_, int problem_count, int batch_count_, @@ -170,7 +173,7 @@ struct MoeW4A8Gemm { host_problem_sizes(nullptr) {} }; - /// Parameters structure + /// Parameters structure struct Params : UniversalParamsBase::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } else if (platform::is_same::value) { - isAMisaligned = problem_size.m() % kAlignmentA; - } else if (platform::is_same>::value || - platform::is_same>::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } - - if (platform::is_same::value) { - isBMisaligned = problem_size.n() % kAlignmentB; - } else if (platform::is_same::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } else if (platform::is_same>::value || - platform::is_same>::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } - - if (platform::is_same::value) { - // isCMisaligned = problem_size.n() % kAlignmentC; - } else if (platform::is_same::value) { - // isCMisaligned = problem_size.m() % kAlignmentC; - } else if (platform::is_same>::value || - platform::is_same>::value) { - // isCMisaligned = problem_size.n() % kAlignmentC; - } - - if (isAMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); - return Status::kErrorMisalignedOperand; - } - - if (isBMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); - return Status::kErrorMisalignedOperand; - } - - if (isCMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); - return Status::kErrorMisalignedOperand; - } - - CUTLASS_TRACE_HOST(" returning kSuccess"); - - return Status::kSuccess; + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; } + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + // isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + // isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || + platform::is_same>::value) { + // isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + static Status can_implement(Arguments const& args) { return can_implement(args.problem_size); } - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { + static size_t get_extra_workspace_size( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } - return 0; - } - - - CUTLASS_DEVICE - static void invoke( - Params const ¶ms, - SharedStorage &shared_storage) - { - MoeW4A8Gemm op; - op(params, shared_storage); - } + CUTLASS_DEVICE + static void invoke(Params const& params, SharedStorage& shared_storage) { + MoeW4A8Gemm op; + op(params, shared_storage); + } #define SPLIT_K_ENABLED 1 /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const& params, SharedStorage& shared_storage) { - using ElementA = typename Mma::IteratorA::Element; using LayoutA = typename Mma::IteratorA::Layout; using ElementB = typename Mma::IteratorB::Element; @@ -340,12 +334,11 @@ struct MoeW4A8Gemm { static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static_assert( - platform::is_same::value && - kInterleave == 1 || - platform::is_same::value && - kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); + static_assert(platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); // // Problem visitor. @@ -357,190 +350,187 @@ struct MoeW4A8Gemm { int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) { + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { // // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - cutlass::gemm::GemmCoord threadblock_offset( - int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT - int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT - 0); + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT + 0); - // Load element pointers. Exchange pointers and strides if working on - // the transpose - const int64_t rows_to_jump = - problem_idx == 0 - ? 0 - : params.problem_visitor.last_row_for_problem[problem_idx - 1]; - ElementA* ptr_A = - reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; + // Load element pointers. Exchange pointers and strides if working on + // the transpose + const int64_t rows_to_jump = + problem_idx == 0 + ? 0 + : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = + reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; - char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT - problem_idx * bytes_per_expert_matrix; // NOLINT - ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B = - platform::is_same::value - ? gemm_n - : gemm_k * kInterleave; + char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT + problem_idx * bytes_per_expert_matrix; // NOLINT + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value + ? gemm_n + : gemm_k * kInterleave; int offset_k = 0; int problem_size_k = params.problem_size.k(); - // Compute initial location in logical coordinates + // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ threadblock_offset.m(), 0, }; - cutlass::MatrixCoord tb_offset_B{ - 0, - threadblock_offset.n() / kInterleave}; + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; - // Compute position within threadblock - int thread_idx = threadIdx.x; + // Compute position within threadblock + int thread_idx = threadIdx.x; - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, - ptr_A, - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A); + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, + thread_idx, + tb_offset_B); + typename Mma::IteratorNF4LookUpTable iterator_nf4_look_up_table = + Mma::IteratorNF4LookUpTable(params.params_nf4_look_up_table, + params.ref_nf4_look_up_table.data(), + {0, 16}, + threadIdx.x, + {0, 0}); - typename Mma::IteratorB iterator_B( - params.params_B, - ptr_B, - {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, - thread_idx, - tb_offset_B); - typename Mma::IteratorNF4LookUpTable iterator_nf4_look_up_table = - Mma::IteratorNF4LookUpTable( - params.params_nf4_look_up_table, - params.ref_nf4_look_up_table.data(), - {0,16}, - threadIdx.x, - {0,0} - ); + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; - int lane_idx = threadIdx.x % 32; + // + // Main loop + // - // - // Main loop - // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + typename Mma::FragmentC accumulators; - typename Mma::FragmentC accumulators; + accumulators.clear(); - accumulators.clear(); + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + // printf("#### gemm_k_iterations: %d \n", gemm_k_iterations); + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_nf4_look_up_table, + accumulators); + // if(threadIdx.x==0){ + // printf("##### block: %d-%d-%d, offset-m-n-k:%d-%d-%d \n", + // blockIdx.x, blockIdx.y, blockIdx.z, + // threadblock_tile_offset.m(), + // threadblock_tile_offset.n(), + // threadblock_tile_offset.k() + // ); + // } + // + // Masked tile iterators constructed from members + // + EpilogueOutputOp output_op(params.output_op); - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = - (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - // printf("#### gemm_k_iterations: %d \n", gemm_k_iterations); - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_nf4_look_up_table, accumulators); - // if(threadIdx.x==0){ - // printf("##### block: %d-%d-%d, offset-m-n-k:%d-%d-%d \n", - // blockIdx.x, blockIdx.y, blockIdx.z, - // threadblock_tile_offset.m(), - // threadblock_tile_offset.n(), - // threadblock_tile_offset.k() - // ); - // } - // - // Masked tile iterators constructed from members - // - EpilogueOutputOp output_op(params.output_op); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + ElementC* ptr_D = + reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - ElementC* ptr_D = - reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); - int block_idx = threadblock_tile_offset.m() + - threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + // If performing a reduction via split-K, fetch the initial + // synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); - - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D(params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.scatter_D_indices); - - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - - semaphore.wait(threadblock_tile_offset.k()); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_D); - - // - // Release the semaphore - // - - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } - - // Next tile - shared_storage.problem_visitor.advance(gridDim.x); + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); } + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + Epilogue epilogue( + shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_D); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + + // Next tile + shared_storage.problem_visitor.advance(gridDim.x); + } } }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel_template.cu b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel_template.cu index ede2178d27..cf02b86af1 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel_template.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel_template.cu @@ -38,10 +38,9 @@ #include "cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h" #include "w4a8_moe_gemm_with_epilogue_visitor.h" - template class IntegerType { - public: + public: static constexpr int value = val; }; @@ -76,17 +75,21 @@ void generic_w4a8_moe_gemm_kernelLauncher( int multi_processor_count, cudaStream_t stream, int* occupancy) { - if (gemm_config.split_k_style == SplitKStyle::NO_SPLIT_K){ + if (gemm_config.split_k_style == SplitKStyle::NO_SPLIT_K) { static_assert(cutlass::platform::is_same::value, "input type must be int8_t"); - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. - // using OutputElementType_ = OutputType; - using OutputElementType_ = typename cutlass::platform::conditional::value, - cutlass::bfloat16_t, OutputType>::type; + // The cutlass type for the input elements. This is needed to convert to + // cutlass::half_t if necessary. using OutputElementType_ = OutputType; + using OutputElementType_ = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::bfloat16_t, + OutputType>::type; - using OutputElementType = typename cutlass::platform::conditional::value, - cutlass::half_t, OutputElementType_>::type; + using OutputElementType = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::half_t, + OutputElementType_>::type; using CutlassIntAType_ = IntAType; using CutlassIntAType = CutlassIntAType_; @@ -94,47 +97,55 @@ void generic_w4a8_moe_gemm_kernelLauncher( using CutlassIntBType_ = IntBType; using CutlassIntBType = CutlassIntBType_; - // We need separate config for each architecture since we will target different tensorcore instructions. For float, - // we do not target TCs. + // We need separate config for each architecture since we will target + // different tensorcore instructions. For float, we do not target TCs. - using MixedGemmArchTraits = cutlass::gemm::kernel:: - Int8Nf4GemmArchTraits; - - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - using ElementCompute = float; + using MixedGemmArchTraits = + cutlass::gemm::kernel::Int8Nf4GemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + using ElementCompute = float; // ============== using EpilogueOp = - typename Epilogue::Op; + typename Epilogue::Op; - using ThreadBlockSwizzle = typename cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultInt8InterleavedGemm< - CutlassIntAType, - cutlass::layout::RowMajor, - MixedGemmArchTraits::ElementsPerAccessA, - CutlassIntBType, - typename MixedGemmArchTraits::LayoutB, - MixedGemmArchTraits::ElementsPerAccessB, - OutputElementType, - cutlass::layout::RowMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - arch, - ThreadblockShape, - WarpShape, - typename MixedGemmArchTraits::InstructionShape, - EpilogueOp, - ThreadBlockSwizzle, - Stages, - true, - typename MixedGemmArchTraits::Operator>::GemmKernel; - using GemmKernel = cutlass::gemm::kernel::MoeW4A8Gemm; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmBatchedIdentityThreadblockSwizzle; + using GemmKernel_ = + typename cutlass::gemm::kernel::DefaultInt8InterleavedGemm< + CutlassIntAType, + cutlass::layout::RowMajor, + MixedGemmArchTraits::ElementsPerAccessA, + CutlassIntBType, + typename MixedGemmArchTraits::LayoutB, + MixedGemmArchTraits::ElementsPerAccessB, + OutputElementType, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + arch, + ThreadblockShape, + WarpShape, + typename MixedGemmArchTraits::InstructionShape, + EpilogueOp, + ThreadBlockSwizzle, + Stages, + true, + typename MixedGemmArchTraits::Operator>::GemmKernel; + using GemmKernel = cutlass::gemm::kernel::MoeW4A8Gemm< + typename GemmKernel_::Mma, + typename GemmKernel_::Epilogue, + typename GemmKernel_::ThreadblockSwizzle, + arch, // Ensure top level arch is used for dispatch + GemmKernel_::kSplitKSerial, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>; using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< @@ -147,38 +158,44 @@ void generic_w4a8_moe_gemm_kernelLauncher( cutlass::sizeof_bits::value>, OutputElementType>; - using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerColNf4< - ThreadblockShape, - GemmKernel::kThreadCount, - AlphaColTileIterator, - typename GemmKernel::Epilogue::OutputTileIterator, - ElementAccumulator, - ElementCompute, - EpilogueOp>; + using EpilogueVisitor = + typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerColNf4< + ThreadblockShape, + GemmKernel::kThreadCount, + AlphaColTileIterator, + typename GemmKernel::Epilogue::OutputTileIterator, + ElementAccumulator, + ElementCompute, + EpilogueOp>; /// Epilogue using Epilogue = typename cutlass::epilogue::threadblock:: - EpilogueWithVisitorFromExistingEpilogue::Epilogue; + EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename GemmKernel::Epilogue>::Epilogue; // GEMM using GemmWithEpilogueVisitorKernel = - cutlass::gemm::kernel::MoeW4A8GemmWithEpilogueVisitorInterleavedNf4; - + cutlass::gemm::kernel::MoeW4A8GemmWithEpilogueVisitorInterleavedNf4< + typename GemmKernel::Mma, + Epilogue, + ThreadBlockSwizzle, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>; if (occupancy != nullptr) { - *occupancy = compute_occupancy_for_kernel(); - return; + *occupancy = + compute_occupancy_for_kernel(); + return; } - using Gemm = cutlass::gemm::device::W4A8MoeGemmUniversalBase; + using Gemm = cutlass::gemm::device::W4A8MoeGemmUniversalBase< + GemmWithEpilogueVisitorKernel>; const int ldb = - cutlass::platform::is_same::value ? - n : - k * GemmKernel::kInterleave; + cutlass::platform::is_same::value + ? n + : k * GemmKernel::kInterleave; typename EpilogueOp::Params linear_scaling_params; @@ -192,64 +209,79 @@ void generic_w4a8_moe_gemm_kernelLauncher( const int threadblock_count = multi_processor_count * occupancy_; - typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kBatched, - num_experts, - threadblock_count, - {total_rows, n, k}, - 1, - {reinterpret_cast(const_cast(A)), k}, - {reinterpret_cast(const_cast(B)), ldb}, - quant_mode, - {reinterpret_cast(const_cast(col_scale)), 0}, - {reinterpret_cast(const_cast(row_scale)), 0}, - {const_cast(nf4_look_up_table), 0}, - {reinterpret_cast(C), n}, - {reinterpret_cast(C), n}, - total_rows_before_expert, - total_rows_in_ll_else_minus1, - n, - k, - (int64_t)0, - (int64_t)0, - typename EpilogueVisitor::Arguments(linear_scaling_params, 0, 0, 0)}; + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kBatched, + num_experts, + threadblock_count, + {total_rows, n, k}, + 1, + {reinterpret_cast(const_cast(A)), k}, + {reinterpret_cast(const_cast(B)), ldb}, + quant_mode, + {reinterpret_cast( + const_cast(col_scale)), + 0}, + {reinterpret_cast( + const_cast(row_scale)), + 0}, + {const_cast(nf4_look_up_table), 0}, + {reinterpret_cast(C), n}, + {reinterpret_cast(C), n}, + total_rows_before_expert, + total_rows_in_ll_else_minus1, + n, + k, + (int64_t)0, + (int64_t)0, + typename EpilogueVisitor::Arguments(linear_scaling_params, 0, 0, 0)}; - // This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of - // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the - // interleaved matrix. The way masking in handled in these do not map to the interleaved layout. We need to write - // our own predicated iterator in order to relax this limitation. - if (GemmKernel::kInterleave > 1 - && ((k % MixedGemmArchTraits::ThreadblockK) - || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) { - throw std::runtime_error("Temp assertion: k must be multiple of threadblockK"); + // This assertion is enabled because because for the column interleaved + // layout, K MUST be a multiple of threadblockK. The reason for this is that + // the default pitchlinear iterators are used to handle walking over the + // interleaved matrix. The way masking in handled in these do not map to the + // interleaved layout. We need to write our own predicated iterator in order + // to relax this limitation. + if (GemmKernel::kInterleave > 1 && + ((k % MixedGemmArchTraits::ThreadblockK) || + ((k / gemm_config.split_k_factor) % + MixedGemmArchTraits::ThreadblockK))) { + throw std::runtime_error( + "Temp assertion: k must be multiple of threadblockK"); } Gemm gemm; if (gemm.get_workspace_size(args) > workspace_bytes) { - std::cout<< - "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."< -void dispatch_moe_gemm_to_cutlass( - const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int64_t total_rows_in_ll_else_minus1, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - char* workspace_ptr, - const size_t workspace_bytes, - // int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { +void dispatch_moe_gemm_to_cutlass(const IntAType* A, + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const int32_t* nf4_look_up_table, + OutputType* C, + int64_t* total_rows_before_expert, + int64_t total_rows_in_ll_else_minus1, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + char* workspace_ptr, + const size_t workspace_bytes, + // int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { // VLOG(1)<<__PRETTY_FUNCTION__; auto dispatch_by_tile = [&](auto ThreadblockShapeM, @@ -371,67 +400,67 @@ void dispatch_moe_gemm_to_cutlass( auto WarpShapeM, auto WarpShapeN, auto WarpShapeK) { - dispatch_gemm_config< - OutputType, - IntAType, - IntBType, - arch, - EpilogueTag, - cutlass::gemm::GemmShape, - cutlass::gemm::GemmShape> - (A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows_in_ll_else_minus1, - total_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - workspace_ptr, - workspace_bytes, - multi_processor_count, - stream, - occupancy); + dispatch_gemm_config< + OutputType, + IntAType, + IntBType, + arch, + EpilogueTag, + cutlass::gemm::GemmShape, + cutlass::gemm::GemmShape>( + A, + B, + quant_mode, + col_scale, + row_scale, + nf4_look_up_table, + C, + total_rows_before_expert, + total_rows_in_ll_else_minus1, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + workspace_ptr, + workspace_bytes, + multi_processor_count, + stream, + occupancy); }; switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: - dispatch_by_tile(Int<16>(), Int<64>(), Int<64>(), - Int<16>(), Int<32>(), Int<64>()); + dispatch_by_tile( + Int<16>(), Int<64>(), Int<64>(), Int<16>(), Int<32>(), Int<64>()); break; case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_by_tile(Int<32>(), Int<128>(), Int<64>(), - Int<32>(), Int<32>(), Int<64>()); + dispatch_by_tile( + Int<32>(), Int<128>(), Int<64>(), Int<32>(), Int<32>(), Int<64>()); break; case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_by_tile(Int<64>(), Int<128>(), Int<64>(), - Int<64>(), Int<32>(), Int<64>()); + dispatch_by_tile( + Int<64>(), Int<128>(), Int<64>(), Int<64>(), Int<32>(), Int<64>()); break; case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_by_tile(Int<128>(), Int<128>(), Int<64>(), - Int<128>(), Int<32>(), Int<64>()); + dispatch_by_tile( + Int<128>(), Int<128>(), Int<64>(), Int<128>(), Int<32>(), Int<64>()); break; case CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64: - dispatch_by_tile(Int<32>(), Int<512>(), Int<64>(), - Int<32>(), Int<128>(), Int<64>()); + dispatch_by_tile( + Int<32>(), Int<512>(), Int<64>(), Int<32>(), Int<128>(), Int<64>()); break; case CutlassTileConfig::CtaShape32x256x64_WarpShape32x64x64: - dispatch_by_tile(Int<32>(), Int<256>(), Int<64>(), - Int<32>(), Int<64>(), Int<64>()); + dispatch_by_tile( + Int<32>(), Int<256>(), Int<64>(), Int<32>(), Int<64>(), Int<64>()); break; case CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64: - dispatch_by_tile(Int<64>(), Int<256>(), Int<64>(), - Int<64>(), Int<64>(), Int<64>()); + dispatch_by_tile( + Int<64>(), Int<256>(), Int<64>(), Int<64>(), Int<64>(), Int<64>()); break; // case CutlassTileConfig::CtaShape128x256x64_WarpShape128x64x64: // dispatch_by_tile(Int<128>(), Int<256>(), Int<64>(), @@ -463,7 +492,6 @@ void dispatch_moe_gemm_to_cutlass( } } - template W4A8MoeGemmRunner::W4A8MoeGemmRunner() { int device{-1}; @@ -472,137 +500,155 @@ W4A8MoeGemmRunner::W4A8MoeGemmRunner() { // sm_ = 80; check_cuda_error(cudaDeviceGetAttribute( &multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); - std::string FLAGS_cutlass_w4a8_moe_best_config=""; + std::string FLAGS_cutlass_w4a8_moe_best_config = ""; if (getenv("FLAGS_cutlass_w4a8_moe_best_config")) { - FLAGS_cutlass_w4a8_moe_best_config = getenv("FLAGS_cutlass_w4a8_moe_best_config"); + FLAGS_cutlass_w4a8_moe_best_config = + getenv("FLAGS_cutlass_w4a8_moe_best_config"); } - if(tuned_configs_from_file.empty() && FLAGS_cutlass_w4a8_moe_best_config!="") { + if (tuned_configs_from_file.empty() && + FLAGS_cutlass_w4a8_moe_best_config != "") { std::string config_file_path = FLAGS_cutlass_w4a8_moe_best_config; - if (config_file_path.find(".config")!=std::string::npos) { + if (config_file_path.find(".config") != std::string::npos) { std::ifstream config_file(FLAGS_cutlass_w4a8_moe_best_config); - if (config_file.is_open()) { - VLOG(1)<<"Get tuned w4a8 moe gemm config from: "< vec_configs; - while(std::getline(ss, item, ',')) { - try { - int value = std::stoi(item); - vec_configs.push_back(value); - } catch (const std::invalid_argument& e) { - std::cerr << "Invalid argument: " << item << " is not an integer." << std::endl; - return; - } catch (const std::out_of_range& e) { - std::cerr << "Out of range: " << item << " is out of the range of representable values." << std::endl; - return; - } + if (config_file.is_open()) { + VLOG(1) << "Get tuned w4a8 moe gemm config from: " << config_file_path; + std::string config_string; + while (std::getline(config_file, config_string)) { + // decode one line of base64 string + config_string = base64_decode(config_string); + VLOG(1) << "decode config_string: " << config_string; + std::stringstream ss(config_string); + std::string item; + std::vector vec_configs; + while (std::getline(ss, item, ',')) { + try { + int value = std::stoi(item); + vec_configs.push_back(value); + } catch (const std::invalid_argument& e) { + std::cerr << "Invalid argument: " << item << " is not an integer." + << std::endl; + return; + } catch (const std::out_of_range& e) { + std::cerr << "Out of range: " << item + << " is out of the range of representable values." + << std::endl; + return; } - W4A8MoeGEMMConfig search_config; - search_config.total_rows = vec_configs[0]; - search_config.n = vec_configs[1]; - search_config.k = vec_configs[2]; - search_config.num_experts = vec_configs[3]; - search_config.tile_config = static_cast(vec_configs[4]); - search_config.split_k_style = static_cast(vec_configs[5]); - search_config.split_k_factor = vec_configs[6]; - search_config.stages = vec_configs[7]; - tuned_configs_from_file.push_back(search_config); - VLOG(1)<<"tuned_configs_from_file: "<(search_config.tile_config)<<"," << static_cast(search_config.split_k_style)<<","<(vec_configs[4]); + search_config.split_k_style = + static_cast(vec_configs[5]); + search_config.split_k_factor = vec_configs[6]; + search_config.stages = vec_configs[7]; + tuned_configs_from_file.push_back(search_config); + VLOG(1) << "tuned_configs_from_file: " << search_config.total_rows + << "," << search_config.n << "," << search_config.k << "," + << search_config.num_experts << "," + << static_cast(search_config.tile_config) << "," + << static_cast(search_config.split_k_style) << "," + << search_config.split_k_factor << "," + << search_config.stages; } + } else { + VLOG(1) << "No tuned w4a8 gemm config."; + } } else { - FILE * fp; + FILE* fp; fp = fopen(config_file_path.c_str(), "r"); - if(fp) { - VLOG(1)<<"Get tuned w4a8 moe gemm config from: "<(tile_config); - search_config.split_k_style = static_cast(split_k_style); - search_config.split_k_factor = split_k_factor; - search_config.stages = stages; - tuned_configs_from_file.push_back(search_config); - VLOG(1)<<"tuned_configs_from_file: "<(tile_config); + search_config.split_k_style = static_cast(split_k_style); + search_config.split_k_factor = split_k_factor; + search_config.stages = stages; + tuned_configs_from_file.push_back(search_config); + VLOG(1) << "tuned_configs_from_file: " << total_rows_tmp << "," + << n_tmp << "," << k_tmp << "," << num_experts_tmp << "," + << tile_config << "," << split_k_style << "," + << split_k_factor << "," << stages; + if (feof(fp)) break; } - } else if(FLAGS_cutlass_w4a8_moe_best_config=="") { - VLOG(1)<<"No tuned w4a8 gemm config."; + } else if (FLAGS_cutlass_w4a8_moe_best_config == "") { + VLOG(1) << "No tuned w4a8 gemm config."; } } } } template -W4A8MoeGemmRunner::~W4A8MoeGemmRunner() { -} - - +W4A8MoeGemmRunner::~W4A8MoeGemmRunner() {} template template -void W4A8MoeGemmRunner::dispatch_to_arch( - const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int64_t total_rows_in_ll_else_minus1, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream, - int* occupancy) { - +void W4A8MoeGemmRunner::dispatch_to_arch< + EpilogueTag>(const IntAType* A, + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const int32_t* nf4_look_up_table, + OutputType* C, + int64_t* total_rows_before_expert, + int64_t total_rows_in_ll_else_minus1, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream, + int* occupancy) { // only sm80 here dispatch_moe_gemm_to_cutlass(A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows_in_ll_else_minus1, - total_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - workspace_ptr, - workspace_bytes, - multi_processor_count_, - stream, - occupancy); - - + IntAType, + IntBType, + cutlass::arch::Sm80, + EpilogueTag>(A, + B, + quant_mode, + col_scale, + row_scale, + nf4_look_up_table, + C, + total_rows_before_expert, + total_rows_in_ll_else_minus1, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + workspace_ptr, + workspace_bytes, + multi_processor_count_, + stream, + occupancy); } template @@ -625,8 +671,8 @@ void W4A8MoeGemmRunner::run_gemm( int num_experts, cudaStream_t stream, CutlassGemmConfig gemm_config) { - VLOG(1)<<__PRETTY_FUNCTION__; - static constexpr bool is_weight_only = true; //todo(yuanxiaolan) + VLOG(1) << __PRETTY_FUNCTION__; + static constexpr bool is_weight_only = true; // todo(yuanxiaolan) bool is_weight_only_encoder = total_rows >= 512 ? true : false; VLOG(1) << "gemm_config tile_config" @@ -636,29 +682,29 @@ void W4A8MoeGemmRunner::run_gemm( VLOG(1) << "gemm_config split_k_factor " << gemm_config.split_k_factor; VLOG(1) << "gemm_config stages " << gemm_config.stages; - if(gemm_config.tile_config != CutlassTileConfig::Undefined) { + if (gemm_config.tile_config != CutlassTileConfig::Undefined) { dispatch_to_arch(A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows_in_ll_else_minus1, - total_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - workspace_ptr, - workspace_bytes, - stream); + B, + quant_mode, + col_scale, + row_scale, + nf4_look_up_table, + C, + total_rows_before_expert, + total_rows_in_ll_else_minus1, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + workspace_ptr, + workspace_bytes, + stream); return; } - std::vector candidate_configs = - get_candidate_configs_nf4(80, is_weight_only, is_weight_only_encoder, false); + std::vector candidate_configs = get_candidate_configs_nf4( + 80, is_weight_only, is_weight_only_encoder, false); std::vector occupancies(candidate_configs.size()); for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { @@ -686,20 +732,21 @@ void W4A8MoeGemmRunner::run_gemm( int local_multi_processor_count{0}; check_cuda_error(cudaGetDevice(&local_device)); // sm_ = getSMVersion(); - check_cuda_error(cudaDeviceGetAttribute( - &local_multi_processor_count, cudaDevAttrMultiProcessorCount, local_device)); + check_cuda_error(cudaDeviceGetAttribute(&local_multi_processor_count, + cudaDevAttrMultiProcessorCount, + local_device)); CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies_w4a4(candidate_configs, - occupancies, - total_rows, - gemm_n, - gemm_k, - num_experts, - split_k_limit, - workspace_bytes, - local_multi_processor_count, - is_weight_only); + occupancies, + total_rows, + gemm_n, + gemm_k, + num_experts, + split_k_limit, + workspace_bytes, + local_multi_processor_count, + is_weight_only); VLOG(1) << "chosen_config tile_config " << static_cast(chosen_config.tile_config); @@ -711,7 +758,6 @@ void W4A8MoeGemmRunner::run_gemm( VLOG(1) << "total_rows " << total_rows << "gemm_n " << gemm_n << "gemm_k " << gemm_k; - dispatch_to_arch(A, B, quant_mode, @@ -732,7 +778,8 @@ void W4A8MoeGemmRunner::run_gemm( } // template -// void W4A8MoeGemmRunner::moe_gemm_bias_act( const IntAType* A, +// void W4A8MoeGemmRunner::moe_gemm_bias_act( +// const IntAType* A, // const IntBType* B, // QuantMode quant_mode, // const OutputType* col_scale, @@ -769,65 +816,82 @@ void W4A8MoeGemmRunner::run_gemm( template void W4A8MoeGemmRunner::moe_gemm( - const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int64_t total_rows_in_ll_else_minus1, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - char* workspace_ptr, - const size_t workspace_bytes, - int num_experts, - cudaStream_t stream, - CutlassGemmConfig gemm_config) { + const IntAType* A, + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const int32_t* nf4_look_up_table, + OutputType* C, + int64_t* total_rows_before_expert, + int64_t total_rows_in_ll_else_minus1, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + char* workspace_ptr, + const size_t workspace_bytes, + int num_experts, + cudaStream_t stream, + CutlassGemmConfig gemm_config) { CutlassGemmConfig gemm_config_from_file_and_param = gemm_config; - if(!tuned_configs_from_file.empty()){ - bool match=false; + if (!tuned_configs_from_file.empty()) { + bool match = false; int best_total_rows, best_n, best_k, best_num_experts; - int max_config_total_rows_in_file=0; + int max_config_total_rows_in_file = 0; W4A8MoeGEMMConfig max_total_rows_config; - for(const auto& tuned_config:tuned_configs_from_file) { - // choose the smallest config_m with config_m >=m - if(tuned_config.total_rows <= total_rows && tuned_config.n==gemm_n && tuned_config.k==gemm_k && tuned_config.num_experts==num_experts) { - best_total_rows=tuned_config.total_rows; - best_n=tuned_config.n; - best_k=tuned_config.k; - best_num_experts=tuned_config.num_experts; - gemm_config_from_file_and_param.tile_config = tuned_config.tile_config; - gemm_config_from_file_and_param.split_k_style = tuned_config.split_k_style; - gemm_config_from_file_and_param.split_k_factor = tuned_config.split_k_factor; - gemm_config_from_file_and_param.stages = tuned_config.stages; - match=true; - } - if(tuned_config.total_rows > max_config_total_rows_in_file && tuned_config.n==gemm_n && tuned_config.k==gemm_k && tuned_config.num_experts==num_experts){ - max_config_total_rows_in_file = tuned_config.total_rows; - max_total_rows_config = tuned_config; - } + for (const auto& tuned_config : tuned_configs_from_file) { + // choose the smallest config_m with config_m >=m + if (tuned_config.total_rows <= total_rows && tuned_config.n == gemm_n && + tuned_config.k == gemm_k && tuned_config.num_experts == num_experts) { + best_total_rows = tuned_config.total_rows; + best_n = tuned_config.n; + best_k = tuned_config.k; + best_num_experts = tuned_config.num_experts; + gemm_config_from_file_and_param.tile_config = tuned_config.tile_config; + gemm_config_from_file_and_param.split_k_style = + tuned_config.split_k_style; + gemm_config_from_file_and_param.split_k_factor = + tuned_config.split_k_factor; + gemm_config_from_file_and_param.stages = tuned_config.stages; + match = true; + } + if (tuned_config.total_rows > max_config_total_rows_in_file && + tuned_config.n == gemm_n && tuned_config.k == gemm_k && + tuned_config.num_experts == num_experts) { + max_config_total_rows_in_file = tuned_config.total_rows; + max_total_rows_config = tuned_config; + } } - if(!match){ - if (max_total_rows_config.n==gemm_n && max_total_rows_config.k==gemm_k && max_total_rows_config.num_experts==num_experts) { + if (!match) { + if (max_total_rows_config.n == gemm_n && + max_total_rows_config.k == gemm_k && + max_total_rows_config.num_experts == num_experts) { best_total_rows = max_config_total_rows_in_file; - gemm_config_from_file_and_param.tile_config = max_total_rows_config.tile_config; - gemm_config_from_file_and_param.split_k_style = max_total_rows_config.split_k_style; - gemm_config_from_file_and_param.split_k_factor = max_total_rows_config.split_k_factor; + gemm_config_from_file_and_param.tile_config = + max_total_rows_config.tile_config; + gemm_config_from_file_and_param.split_k_style = + max_total_rows_config.split_k_style; + gemm_config_from_file_and_param.split_k_factor = + max_total_rows_config.split_k_factor; gemm_config_from_file_and_param.stages = max_total_rows_config.stages; } } - VLOG(1) <<"W4A8 moe gemm " - <<"total_rows: "<(gemm_config_from_file_and_param.tile_config) - <<"split_k_style: "<(gemm_config_from_file_and_param.split_k_style) - <<"split_k_factor: "<(gemm_config_from_file_and_param.split_k_factor) - <<"stages: "<(gemm_config_from_file_and_param.stages); + VLOG(1) << "W4A8 moe gemm " + << "total_rows: " << total_rows << " n: " << gemm_n + << " k: " << gemm_k + << "Using gemm config from config file: config_total_rows: " + << best_total_rows << " config_n: " << best_n + << " config_k: " << best_k << "tile_config: " + << static_cast(gemm_config_from_file_and_param.tile_config) + << "split_k_style: " + << static_cast(gemm_config_from_file_and_param.split_k_style) + << "split_k_factor: " + << static_cast(gemm_config_from_file_and_param.split_k_factor) + << "stages: " + << static_cast(gemm_config_from_file_and_param.stages); } else { - VLOG(1) << "tuned_configs_from_file is empty, use W4A8 gemm config in param"; + VLOG(1) + << "tuned_configs_from_file is empty, use W4A8 gemm config in param"; } run_gemm(A, B, @@ -849,7 +913,10 @@ void W4A8MoeGemmRunner::moe_gemm( } template -std::vector::W4A8MoeGEMMConfig> W4A8MoeGemmRunner::tuned_configs_from_file = {}; +std::vector:: + W4A8MoeGEMMConfig> + W4A8MoeGemmRunner::tuned_configs_from_file = + {}; template int W4A8MoeGemmRunner::getWorkspaceSize( @@ -863,6 +930,5 @@ int W4A8MoeGemmRunner::getWorkspaceSize( return max_grid_m * max_grid_n * split_k_limit * 4; } - template class W4A8MoeGemmRunner; template class W4A8MoeGemmRunner<__nv_bfloat16, int8_t, cutlass::uint4b_t>; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h index 7d06b59d65..6ec9121634 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h @@ -30,45 +30,46 @@ class W4A8MoeGemmRunner { ~W4A8MoeGemmRunner(); void moe_gemm_bias_act(const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const OutputType* biases, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int m, - int n, - int k, - int num_experts, - std::string activation_type, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream); + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const OutputType* biases, + const int32_t* nf4_look_up_table, + OutputType* C, + int64_t* total_rows_before_expert, + int m, + int n, + int k, + int num_experts, + std::string activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); void moe_gemm(const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int64_t total_rows_in_ll_else_minus1, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - char* workspace_ptr, - const size_t workspace_bytes, - int num_experts, - cudaStream_t stream, - CutlassGemmConfig gemm_config = CutlassGemmConfig{CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - SplitKStyle::NO_SPLIT_K, - 1, - 5}); - private: + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const int32_t* nf4_look_up_table, + OutputType* C, + int64_t* total_rows_before_expert, + int64_t total_rows_in_ll_else_minus1, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + char* workspace_ptr, + const size_t workspace_bytes, + int num_experts, + cudaStream_t stream, + CutlassGemmConfig gemm_config = CutlassGemmConfig{ + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + SplitKStyle::NO_SPLIT_K, + 1, + 5}); + private: template void dispatch_to_arch(const IntAType* A, const IntBType* B, @@ -108,8 +109,7 @@ class W4A8MoeGemmRunner { cudaStream_t stream, CutlassGemmConfig gemm_config); - int getWorkspaceSize( - const int m, const int n, const int k); + int getWorkspaceSize(const int m, const int n, const int k); private: static constexpr int split_k_limit = 4; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel_template.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel_template.h index 80ffb06de4..3ab51307d6 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel_template.h @@ -29,7 +29,8 @@ template -void generic_w4a8_moe_gemm_kernelLauncher(const IntAType* A, +void generic_w4a8_moe_gemm_kernelLauncher( + const IntAType* A, const IntBType* B, cutlass::epilogue::QuantMode quant_mode, const OutputType* col_scale, @@ -48,7 +49,6 @@ void generic_w4a8_moe_gemm_kernelLauncher(const IntAType* A, cudaStream_t stream, int* occupancy); - template struct dispatch_stages { + IntAType, + IntBType, + arch, + EpilogueTag, + ThreadblockShape, + WarpShape, + 2> { static void dispatch(const IntAType* A, const IntBType* B, cutlass::epilogue::QuantMode quant_mode, @@ -119,31 +118,31 @@ struct dispatch_stages(A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows, - total_rows_in_ll_else_minus1, - n, - k, - num_experts, - gemm_config, - workspace, - workspace_bytes, - multi_processor_count, - stream, - occupancy); + IntAType, + IntBType, + arch, + EpilogueTag, + ThreadblockShape, + WarpShape, + 2>(A, + B, + quant_mode, + col_scale, + row_scale, + nf4_look_up_table, + C, + total_rows_before_expert, + total_rows, + total_rows_in_ll_else_minus1, + n, + k, + num_experts, + gemm_config, + workspace, + workspace_bytes, + multi_processor_count, + stream, + occupancy); } }; @@ -155,12 +154,12 @@ template struct dispatch_stages 2)>::type> { static void dispatch(const IntAType* A, @@ -183,35 +182,34 @@ struct dispatch_stages(A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows, - total_rows_in_ll_else_minus1, - n, - k, - num_experts, - gemm_config, - workspace, - workspace_bytes, - multi_processor_count, - stream, - occupancy); + IntAType, + IntBType, + cutlass::arch::Sm80, + EpilogueTag, + ThreadblockShape, + WarpShape, + Stages>(A, + B, + quant_mode, + col_scale, + row_scale, + nf4_look_up_table, + C, + total_rows_before_expert, + total_rows, + total_rows_in_ll_else_minus1, + n, + k, + num_experts, + gemm_config, + workspace, + workspace_bytes, + multi_processor_count, + stream, + occupancy); } }; - template -static void PrintMatrix(const T *mat_d, int num, std::string name, +static void PrintMatrix(const T *mat_d, + int num, + std::string name, int numOfCols) { std::vector tmp(num); cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); @@ -104,17 +111,17 @@ static void PrintMatrix(const T *mat_d, int num, std::string name, uint as_uint(const float x) { return *(uint *)&x; } uint16_t ConvertFloat2Half(const float x) { - const uint b = as_uint(x) + 0x00001000; // round-to-nearest-even: add last - // bit after truncated mantissa - const uint e = (b & 0x7F800000) >> 23; // exponent - const uint m = b & 0x007FFFFF; // mantissa; in line below: 0x007FF000 = - // 0x00800000-0x00001000 = decimal indicator - // flag - initial rounding + const uint b = as_uint(x) + 0x00001000; // round-to-nearest-even: add last + // bit after truncated mantissa + const uint e = (b & 0x7F800000) >> 23; // exponent + const uint m = b & 0x007FFFFF; // mantissa; in line below: 0x007FF000 = + // 0x00800000-0x00001000 = decimal indicator + // flag - initial rounding return (b & 0x80000000) >> 16 | (e > 112) * ((((e - 112) << 10) & 0x7C00) | m >> 13) | ((e < 113) & (e > 101)) * ((((0x007FF000 + m) >> (125 - e)) + 1) >> 1) | - (e > 143) * 0x7FFF; // sign : normalized : denormalized : saturate + (e > 143) * 0x7FFF; // sign : normalized : denormalized : saturate } inline float fp32_from_bits(uint32_t w) { @@ -274,7 +281,9 @@ float CPUHalfConvert2Float(const uint16_t h) { return fp32_from_bits(result); } -static void PrintHalfMatrix(const int16_t *mat_d, int num, std::string name, +static void PrintHalfMatrix(const int16_t *mat_d, + int num, + std::string name, int numOfCols) { std::vector tmp(num); cudaMemcpy(tmp.data(), mat_d, sizeof(int16_t) * num, cudaMemcpyDeviceToHost); @@ -296,7 +305,9 @@ static void PrintHalfMatrix(const int16_t *mat_d, int num, std::string name, } template -static void PrintMatrixCPU(const T *mat, int num, std::string name, +static void PrintMatrixCPU(const T *mat, + int num, + std::string name, int numOfCols) { std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); @@ -315,7 +326,9 @@ static void PrintMatrixCPU(const T *mat, int num, std::string name, outfile.close(); } -static void PrintMatrixCPU_int4(const int8_t *mat, int num, std::string name, +static void PrintMatrixCPU_int4(const int8_t *mat, + int num, + std::string name, int numOfCols) { std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); @@ -333,7 +346,9 @@ static void PrintMatrixCPU_int4(const int8_t *mat, int num, std::string name, outfile.close(); } template -static void PrintHalfMatrixCPU(const T *mat, int num, std::string name, +static void PrintHalfMatrixCPU(const T *mat, + int num, + std::string name, int numOfCols) { std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); @@ -349,8 +364,8 @@ static void PrintHalfMatrixCPU(const T *mat, int num, std::string name, } template -void naive_matmul(const T *a, const T *b, outputT *c, size_t m, size_t n, - size_t k) { +void naive_matmul( + const T *a, const T *b, outputT *c, size_t m, size_t n, size_t k) { for (int ik = 0; ik < k; ik++) { for (int im = 0; im < m; im++) { for (int in = 0; in < n; in++) { @@ -361,13 +376,17 @@ void naive_matmul(const T *a, const T *b, outputT *c, size_t m, size_t n, } template -void naive_matmul_fused_dequantize_nf4(const T *a, const T *b, +void naive_matmul_fused_dequantize_nf4(const T *a, + const T *b, const ScaleType *col_scale, const ScaleType *row_scale, const int32_t *nf4_look_up_table, - outputT *c, size_t num_experts, + outputT *c, + size_t num_experts, int64_t *total_rows_before_experts, - size_t total_rows, size_t n, size_t k) { + size_t total_rows, + size_t n, + size_t k) { // PrintMatrixCPU( // a, total_rows * k, "naive_matmul_a", k); // PrintMatrixCPU( @@ -442,10 +461,15 @@ void naive_matmul_fused_dequantize_nf4(const T *a, const T *b, } // Author (zhengzekang): we use float to monitor half matmul in CPU. -void CheckHalfDiff(int16_t *device_res, float *host_result, size_t elem_cnt, - float atol, float rtol) { +void CheckHalfDiff(int16_t *device_res, + float *host_result, + size_t elem_cnt, + float atol, + float rtol) { std::vector device_data(elem_cnt); - cudaMemcpy(device_data.data(), device_res, sizeof(int16_t) * elem_cnt, + cudaMemcpy(device_data.data(), + device_res, + sizeof(int16_t) * elem_cnt, cudaMemcpyDeviceToHost); for (size_t i = 0; i < elem_cnt; i++) { @@ -459,7 +483,10 @@ void CheckHalfDiff(int16_t *device_res, float *host_result, size_t elem_cnt, printf( "Here in Idx: %d, CUDA result is: %f, Host result is: %f, absolute " "diff val is: %f \n", - i, device_res_val, host_res_val, absolute_diff); + i, + device_res_val, + host_res_val, + absolute_diff); return; } } @@ -508,11 +535,11 @@ CutlassGemmConfig GetGemmConfig(int token_nums, // gemm_config_tuple:[m,n,k,tile_config,split_k_style,split_k_factor,stages] for (int i = 0; i < len_of_gemm_config_tuple; i += 7) { gemm_config.tile_config = - CutlassTileConfig(gemm_config_tuple[i + 3]); // tile_config + CutlassTileConfig(gemm_config_tuple[i + 3]); // tile_config gemm_config.split_k_style = - SplitKStyle(gemm_config_tuple[i + 4]); // split_k_style - gemm_config.split_k_factor = gemm_config_tuple[i + 5]; // split_k_factor - gemm_config.stages = gemm_config_tuple[i + 6]; // stages + SplitKStyle(gemm_config_tuple[i + 4]); // split_k_style + gemm_config.split_k_factor = gemm_config_tuple[i + 5]; // split_k_factor + gemm_config.stages = gemm_config_tuple[i + 6]; // stages // make sure we have at least one tuned config if (token_nums <= gemm_config_tuple[i + 0]) { break; @@ -522,7 +549,8 @@ CutlassGemmConfig GetGemmConfig(int token_nums, } template -void get_tensor_from_file(const std::string file_path, int64_t numel, +void get_tensor_from_file(const std::string file_path, + int64_t numel, T *tensor_ptr) { std::fstream datafile; datafile.open(file_path, std::ios_base::in | std::ios_base::out); @@ -641,7 +669,7 @@ int main(int argc, char *argv[]) { auto mixed_gemm_runner = W4A8MoeGemmRunner(); // int mixgemm_max_size = std::max(m, k); - int mixgemm_workspace_size_bytes = 1 * 1024 * 1024 * 1024; // 1G workspace + int mixgemm_workspace_size_bytes = 1 * 1024 * 1024 * 1024; // 1G workspace std::cout << "mixgemm_workspace_size_bytes: " << mixgemm_workspace_size_bytes << std::endl; char *mixgemm_workspace_data; @@ -663,8 +691,8 @@ int main(int argc, char *argv[]) { } } else { std::cout << "get a data from: " << a_data_file << std::endl; - get_tensor_from_file(a_data_file, total_rows * k, - a_int.data()); + get_tensor_from_file( + a_data_file, total_rows * k, a_int.data()); } // PrintMatrixCPU(a_int.data(),total_rows*k,"a_int8_cpu",n); } @@ -752,7 +780,8 @@ int main(int argc, char *argv[]) { // PrintMatrixCPU_int4(packed_b_int.data(),num_experts*k*n,"w4a8_packed_b_int4",n); permute_B_rows_for_mixed_gemm_int4<4>( b_int_processed.data() + ie * k * n / 2, - packed_b_int.data() + ie * k * n / 2, std::vector{k, n}, + packed_b_int.data() + ie * k * n / 2, + std::vector{k, n}, (int64_t)80); // PrintMatrixCPU_int4(b_int_processed.data(),num_experts*k*n,"w4a8_permuted_int4",n); @@ -820,8 +849,8 @@ int main(int argc, char *argv[]) { // } } } else { - get_tensor_from_file(row_scale_data_file, total_rows, - row_scale_float.data()); + get_tensor_from_file( + row_scale_data_file, total_rows, row_scale_float.data()); } // PrintMatrixCPU(row_scale_float.data(),total_rows,"row_scale_float_cpu",total_rows); } @@ -839,8 +868,8 @@ int main(int argc, char *argv[]) { // } } } else { - get_tensor_from_file(col_scale_data_file, num_experts * n, - col_scale_float.data()); + get_tensor_from_file( + col_scale_data_file, num_experts * n, col_scale_float.data()); } // PrintMatrixCPU(col_scale_float.data(),num_experts*n,"col_scale_float_cpu",n); } @@ -878,21 +907,35 @@ int main(int argc, char *argv[]) { cudaMalloc(&d_col_scale_half, num_experts * n * sizeof(uint16_t)); cudaMalloc(&d_nf4_look_up_table, 4 * sizeof(uint32_t)); cudaMalloc(&d_total_rows_before_experts, num_experts * sizeof(int64_t)); - cudaMemcpy(d_a_int, a_int.data(), total_rows * k * sizeof(int8_t), + cudaMemcpy(d_a_int, + a_int.data(), + total_rows * k * sizeof(int8_t), + cudaMemcpyHostToDevice); + cudaMemcpy(d_b_int, + b_int_processed_3.data(), + num_experts * k * n / 2 * sizeof(int8_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_b_int, b_int_processed_3.data(), - num_experts * k * n / 2 * sizeof(int8_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_row_scale_half, row_scale_half.data(), - total_rows * sizeof(uint16_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_col_scale_half, col_scale_half.data(), - num_experts * n * sizeof(uint16_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_nf4_look_up_table, nf4_look_up_table_compress.data(), - 4 * sizeof(uint32_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_c_int, c_half.data(), total_rows * n * sizeof(uint16_t), + cudaMemcpy(d_row_scale_half, + row_scale_half.data(), + total_rows * sizeof(uint16_t), + cudaMemcpyHostToDevice); + cudaMemcpy(d_col_scale_half, + col_scale_half.data(), + num_experts * n * sizeof(uint16_t), + cudaMemcpyHostToDevice); + cudaMemcpy(d_nf4_look_up_table, + nf4_look_up_table_compress.data(), + 4 * sizeof(uint32_t), + cudaMemcpyHostToDevice); + cudaMemcpy(d_c_int, + c_half.data(), + total_rows * n * sizeof(uint16_t), + cudaMemcpyHostToDevice); + cudaMemcpy(d_total_rows_before_experts, + total_rows_before_experts.data(), + num_experts * sizeof(int64_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_total_rows_before_experts, total_rows_before_experts.data(), - num_experts * sizeof(int64_t), cudaMemcpyHostToDevice); cudaDeviceSynchronize(); cudaError_t result = cudaGetLastError(); @@ -908,7 +951,9 @@ int main(int argc, char *argv[]) { std::cout << "=== do warm up for " << kWarmTime << " times" << std::endl; auto test_config = CutlassGemmConfig{CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - SplitKStyle::NO_SPLIT_K, 1, 5}; + SplitKStyle::NO_SPLIT_K, + 1, + 5}; std::cout << "=== do warm up end" << std::endl; for (int i = 0; i < kWarmTime; i++) { printf("warm up %d\n", i); @@ -920,9 +965,16 @@ int main(int argc, char *argv[]) { reinterpret_cast(d_row_scale_half), reinterpret_cast(d_nf4_look_up_table), reinterpret_cast(d_c_int), - reinterpret_cast(d_total_rows_before_experts), -1, - total_rows, n, k, mixgemm_workspace_data, mixgemm_workspace_size_bytes, - num_experts, 0, test_config); + reinterpret_cast(d_total_rows_before_experts), + -1, + total_rows, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + num_experts, + 0, + test_config); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { std::cout << "error: " << cudaGetErrorString(err) << std::endl; @@ -1012,7 +1064,6 @@ int main(int argc, char *argv[]) { cudaEventCreate(&end); cudaEventRecord(begin, 0); for (int i = 0; i < kTestTime; ++i) { - mixed_gemm_runner.moe_gemm( reinterpret_cast(d_a_int), reinterpret_cast((void *)d_b_int), @@ -1021,9 +1072,15 @@ int main(int argc, char *argv[]) { reinterpret_cast(d_row_scale_half), reinterpret_cast(d_nf4_look_up_table), reinterpret_cast(d_c_int), - reinterpret_cast(d_total_rows_before_experts), -1, - total_rows, n, k, mixgemm_workspace_data, - mixgemm_workspace_size_bytes, num_experts, 0, + reinterpret_cast(d_total_rows_before_experts), + -1, + total_rows, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + num_experts, + 0, test_gemm_config); } cudaEventRecord(end, 0); @@ -1149,37 +1206,58 @@ int main(int argc, char *argv[]) { if (do_check) { std::cout << "=== do accuracy check " << std::endl; cudaMemset(d_c_int, 0, total_rows * n * sizeof(uint16_t)); - PrintHalfMatrix(static_cast(d_c_int), total_rows * n, - "CUDA_c_dequantize_fp16_output_before_gemm", n); + PrintHalfMatrix(static_cast(d_c_int), + total_rows * n, + "CUDA_c_dequantize_fp16_output_before_gemm", + n); mixed_gemm_runner.moe_gemm( reinterpret_cast(d_a_int), reinterpret_cast((void *)d_b_int), cutlass::epilogue::QuantMode::PerChannelQuant, reinterpret_cast(d_col_scale_half), - nullptr, // reinterpret_cast(d_row_scale_half), - nullptr, // reinterpret_cast(d_nf4_look_up_table), + nullptr, // reinterpret_cast(d_row_scale_half), + nullptr, // reinterpret_cast(d_nf4_look_up_table), reinterpret_cast(d_c_int), - reinterpret_cast(d_total_rows_before_experts), -1, - total_rows, n, k, mixgemm_workspace_data, mixgemm_workspace_size_bytes, - num_experts, 0); + reinterpret_cast(d_total_rows_before_experts), + -1, + total_rows, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + num_experts, + 0); cudaDeviceSynchronize(); // PrintMatrix(reinterpret_cast(d_nf4_look_up_table),4,"d_nf4_look_up_table",1); printf("##### d_nf4_look_up_table address: %p \n", d_nf4_look_up_table); naive_matmul_fused_dequantize_nf4( - a_int.data(), b_int.data(), col_scale_float.data(), - nullptr, // row_scale_float.data(), - nullptr, // nf4_look_up_table.data(), - c_float.data(), num_experts, total_rows_before_experts.data(), - total_rows, n, k); - PrintMatrixCPU(c_float.data(), total_rows * n, - "CPU_c_fake_fp16_dequantize_output_base", n); - PrintHalfMatrix(static_cast(d_c_int), total_rows * n, - "CUDA_c_dequantize_fp16_output", n); - CheckHalfDiff(static_cast(d_c_int), c_float.data(), - total_rows * n, 1e-4, 1e-2); + a_int.data(), + b_int.data(), + col_scale_float.data(), + nullptr, // row_scale_float.data(), + nullptr, // nf4_look_up_table.data(), + c_float.data(), + num_experts, + total_rows_before_experts.data(), + total_rows, + n, + k); + PrintMatrixCPU(c_float.data(), + total_rows * n, + "CPU_c_fake_fp16_dequantize_output_base", + n); + PrintHalfMatrix(static_cast(d_c_int), + total_rows * n, + "CUDA_c_dequantize_fp16_output", + n); + CheckHalfDiff(static_cast(d_c_int), + c_float.data(), + total_rows * n, + 1e-4, + 1e-2); } // if(kTestTime > 0){ diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_with_epilogue_visitor.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_with_epilogue_visitor.h index 648a21a353..804dcfdf91 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_with_epilogue_visitor.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_with_epilogue_visitor.h @@ -33,10 +33,9 @@ namespace gemm { namespace kernel { template + GroupScheduleMode GroupScheduleMode_> struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { public: using Mma = Mma_; @@ -98,10 +97,10 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { static bool const kTransposed = false; using ProblemVisitor = GemmMoeProblemVisitor; + kGroupScheduleMode, + kThreadCount, + kThreadCount, + kTransposed>; // // Structures @@ -206,7 +205,6 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { ElementC, LayoutA, LayoutB> { - using ParamsBase = UniversalParamsBase::value) { - isCMisaligned = problem_size.n() % kAlignmentC; + isCMisaligned = problem_size.n() % kAlignmentC; } else if (platform::is_same::value) { isCMisaligned = problem_size.m() % kAlignmentC; } else if (platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); - static constexpr int kInterleave = - Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static_assert( - platform::is_same::value && - kInterleave == 1 || - platform::is_same::value && - kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); + // + // Problem visitor. + // + ProblemVisitor problem_visitor( + params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = + (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - // - // Problem visitor. - // - ProblemVisitor problem_visitor( - params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - const int64_t gemm_k = params.problem_visitor.gemm_k; - const int64_t gemm_n = params.problem_visitor.gemm_n; - int64_t bytes_per_expert_matrix = - (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) { + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { // // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - cutlass::MatrixCoord threadblock_offset( - int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT - int(cta_idx % grid_shape.n()) * Mma::Shape::kN // NOLINT - ); + cutlass::MatrixCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT + int(cta_idx % grid_shape.n()) * Mma::Shape::kN // NOLINT + ); - // if (threadIdx.x == 0) { - // printf("%d-%d-%d problem_size: %d, %d problem_idx: %d, cta_idx: %d\n", blockIdx.x,blockIdx.y,blockIdx.z, problem_size.m(), problem_size.n(), problem_idx, cta_idx); - // } + // if (threadIdx.x == 0) { + // printf("%d-%d-%d problem_size: %d, %d problem_idx: %d, cta_idx: + // %d\n", blockIdx.x,blockIdx.y,blockIdx.z, problem_size.m(), + // problem_size.n(), problem_idx, cta_idx); + // } - // Load element pointers. Exchange pointers and strides if working on - // the transpose - int64_t rows_to_jump = 0; - if (params.problem_visitor.total_rows < 0) { - rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; - } else { - rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count); - } + // Load element pointers. Exchange pointers and strides if working on + // the transpose + int64_t rows_to_jump = 0; + if (params.problem_visitor.total_rows < 0) { + rows_to_jump = + problem_idx == 0 + ? 0 + : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + } else { + rows_to_jump = problem_idx * (params.problem_visitor.total_rows / + params.problem_visitor.problem_count); + } - ElementA* ptr_A = - reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; - - char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT - problem_idx * bytes_per_expert_matrix; // NOLINT - ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B = - platform::is_same::value - ? gemm_n - : gemm_k * kInterleave; + ElementA* ptr_A = + reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT + problem_idx * bytes_per_expert_matrix; // NOLINT + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value + ? gemm_n + : gemm_k * kInterleave; int offset_k = 0; int problem_size_k = params.problem_size.k(); - - // Maybe need to modify? Author zhengzekang. - #if SPLIT_K_ENABLED +// Maybe need to modify? Author zhengzekang. +#if SPLIT_K_ENABLED // // Fetch pointers based on mode. // if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + problem_size_k = + (threadblock_tile_offset.k() + 1) * params.gemm_k_size; } offset_k = threadblock_tile_offset.k() * params.gemm_k_size; } else if (params.mode == GemmUniversalMode::kBatched) { @@ -475,17 +480,19 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { ptr_B = static_cast( params.ptr_B)[threadblock_tile_offset.k()]; } - #endif - // if(threadIdx.x==0){ - // printf("##### block: %d-%d-%d, offset_k:%d, threadblock_tile_offset.m-n-k():%d-%d-%d, params.gemm_k_size:%d \n", - // blockIdx.x, blockIdx.y, blockIdx.z, - // offset_k, - // threadblock_tile_offset.m(), - // threadblock_tile_offset.n(), - // threadblock_tile_offset.k(), - // params.gemm_k_size - // ); - // } +#endif + // if(threadIdx.x==0){ + // printf("##### block: %d-%d-%d, offset_k:%d, + // threadblock_tile_offset.m-n-k():%d-%d-%d, params.gemm_k_size:%d + // \n", + // blockIdx.x, blockIdx.y, blockIdx.z, + // offset_k, + // threadblock_tile_offset.m(), + // threadblock_tile_offset.n(), + // threadblock_tile_offset.k(), + // params.gemm_k_size + // ); + // } // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ @@ -493,21 +500,18 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { 0, }; - cutlass::MatrixCoord tb_offset_B{ - 0, - threadblock_offset.column() / kInterleave}; + 0, threadblock_offset.column() / kInterleave}; // Compute position within threadblock int thread_idx = threadIdx.x; // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, - ptr_A, - {problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A); + typename Mma::IteratorA iterator_A(params.params_A, + ptr_A, + {problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); typename Mma::IteratorB iterator_B( params.params_B, @@ -516,13 +520,11 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { thread_idx, tb_offset_B); typename Mma::IteratorNF4LookUpTable iterator_nf4_look_up_table = - Mma::IteratorNF4LookUpTable( - params.params_nf4_look_up_table, - params.ref_nf4_look_up_table.data(), - {0,16}, - threadIdx.x, - {0,0} - ); + Mma::IteratorNF4LookUpTable(params.params_nf4_look_up_table, + params.ref_nf4_look_up_table.data(), + {0, 16}, + threadIdx.x, + {0, 0}); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -542,9 +544,14 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations = - (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_nf4_look_up_table, accumulators); + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_nf4_look_up_table, + accumulators); // if(threadIdx.x==0){ // printf("##### block: %d-%d-%d, offset-m-n-k:%d-%d-%d \n", // blockIdx.x, blockIdx.y, blockIdx.z, @@ -560,44 +567,54 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - ElementC* ptr_C = - reinterpret_cast(params.ptr_C) + rows_to_jump * gemm_n; - ElementC* ptr_D = - reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + ElementC* ptr_C = + reinterpret_cast(params.ptr_C) + rows_to_jump * gemm_n; + ElementC* ptr_D = + reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - using Element_scale = typename EpilogueVisitor::ScaleTileIterator::Element; - Element_scale* ptr_alpha_row = params.ptr_alpha_row == nullptr ? params.ptr_alpha_row : reinterpret_cast(params.ptr_alpha_row) + rows_to_jump; - Element_scale* ptr_alpha_col = reinterpret_cast(params.ptr_alpha_col) + problem_idx * params.problem_size.n(); + using Element_scale = + typename EpilogueVisitor::ScaleTileIterator::Element; + Element_scale* ptr_alpha_row = + params.ptr_alpha_row == nullptr + ? params.ptr_alpha_row + : reinterpret_cast(params.ptr_alpha_row) + + rows_to_jump; + Element_scale* ptr_alpha_col = + reinterpret_cast(params.ptr_alpha_col) + + problem_idx * params.problem_size.n(); // if (threadIdx.x == 0) - // printf("##### block: %d-%d-%d, ptr_alpha_row:%p,(%f) ptr_alpha_col:%p,(%f)\n", blockIdx.x, blockIdx.y, blockIdx.z, ptr_alpha_row, static_cast(*ptr_alpha_row), ptr_alpha_col, static_cast(*ptr_alpha_col)); + // printf("##### block: %d-%d-%d, ptr_alpha_row:%p,(%f) + // ptr_alpha_col:%p,(%f)\n", blockIdx.x, blockIdx.y, blockIdx.z, + // ptr_alpha_row, static_cast(*ptr_alpha_row), ptr_alpha_col, + // static_cast(*ptr_alpha_col)); // // Construct the epilogue visitor // EpilogueVisitor epilogue_visitor(params.epilogue_visitor, - shared_storage.epilogue.visitor, - problem_size.mn(), - thread_idx, - warp_idx, - lane_idx, - params.params_alpha_col, - params.params_C, - params.params_D, - params.quant_mode, - ptr_alpha_row, - ptr_alpha_col, - ptr_C, - ptr_D, - threadblock_offset, - blockIdx.y * params.problem_size.m()); + shared_storage.epilogue.visitor, + problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params.params_alpha_col, + params.params_C, + params.params_D, + params.quant_mode, + ptr_alpha_row, + ptr_alpha_col, + ptr_C, + ptr_D, + threadblock_offset, + blockIdx.y * params.problem_size.m()); if (params.mode == GemmUniversalMode::kGemm) { // Indicate which position in a serial reduction the output operator is // currently updating epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), - params.grid_tiled_shape.k()); + params.grid_tiled_shape.k()); } else if (params.mode == GemmUniversalMode::kBatched || - params.mode == GemmUniversalMode::kArray) { + params.mode == GemmUniversalMode::kArray) { epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); } @@ -608,9 +625,9 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { // Execute the epilogue operator to update the destination tensor. epilogue(epilogue_visitor, accumulators); - // Next tile - problem_visitor.advance(gridDim.x); - } + // Next tile + problem_visitor.advance(gridDim.x); + } } }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/weight_process_utils.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/weight_process_utils.h index 68779ba28c..dd6536c762 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/weight_process_utils.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/weight_process_utils.h @@ -32,482 +32,496 @@ limitations under the License. */ void row_major_to_column_major(int8_t* col_major_tensor, const int8_t* row_major_tensor, - const std::vector& shape){ - size_t m = shape[0]; - size_t n = shape[1]; - for(auto i=0;i& shape) { + size_t m = shape[0]; + size_t n = shape[1]; + for (auto i = 0; i < m * n; i++) { + size_t im = i / n; + size_t in = i % n; + col_major_tensor[in * m + im] = row_major_tensor[im * n + in]; + } } void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor_ptr, - int64_t num_elts) -{ - int8_t* int8_tensor = reinterpret_cast(int8_tensor_ptr); - for (int ii = 0; ii < num_elts; ++ii) { - int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); - // int8_tensor[ii] = int8_t(int(int8_tensor[ii])); - } + int64_t num_elts) { + int8_t* int8_tensor = reinterpret_cast(int8_tensor_ptr); + for (int ii = 0; ii < num_elts; ++ii) { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + // int8_tensor[ii] = int8_t(int(int8_tensor[ii])); + } - // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no - // performance benefit and is purely so that int4 and int8 have the same layout. - // Pictorially, this does the following: - // bit 32 0 - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // match the int4 layout. This has no performance benefit and is purely so + // that int4 and int8 have the same layout. Pictorially, this does the + // following: bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - for (int64_t base = 0; base < num_elts; base += 4) { - std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); - } + for (int64_t base = 0; base < num_elts; base += 4) { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } } +void subbyte_transpose_impl_int4(int8_t* transposed_quantized_tensor, + const int8_t* quantized_tensor, + const std::vector& shape) { + const int bits_per_elt = 4; -void subbyte_transpose_impl_int4(int8_t* transposed_quantized_tensor, - const int8_t* quantized_tensor, - const std::vector& shape) -{ - const int bits_per_elt = 4; + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; - const size_t col_bytes = num_cols * bits_per_elt / 8; - const size_t col_bytes_trans = num_rows * bits_per_elt / 8; - const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; + const uint8_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = + reinterpret_cast(transposed_quantized_tensor); - const uint8_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); + // static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == + // QuantType::PACKED_INT4_WEIGHT_ONLY, ""); + static constexpr int ELTS_PER_BYTE = 2; - // static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, ""); - static constexpr int ELTS_PER_BYTE = 2; + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; - static constexpr int M_TILE_L1 = 64; - static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; - uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); - static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + // We assume the dims are a multiple of vector width. Our kernels only handle + // dims which are multiples of 64 for weight-only quantization. As a result, + // this seemed like a reasonable tradeoff because it allows GCC to emit vector + // instructions. - // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples - // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it - // allows GCC to emit vector instructions. + const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; + const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; - const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; - const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; + for (size_t expert = 0; expert < num_experts; ++expert) { + const size_t matrix_offset = expert * num_rows * col_bytes; + for (size_t row_tile_start = 0; row_tile_start < num_rows; + row_tile_start += M_TILE_L1) { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; + col_tile_start_byte += N_TILE_L1) { + const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + const int col_limit = + std::min(col_tile_start_byte + N_TILE_L1, col_bytes); - for (size_t expert = 0; expert < num_experts; ++expert) { - const size_t matrix_offset = expert * num_rows * col_bytes; - for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) { - for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + const int row = row_tile_start + ii; - const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); - const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + const int col = col_tile_start_byte + jj; - for (int ii = 0; ii < M_TILE_L1; ++ii) { - const int row = row_tile_start + ii; + const size_t logical_src_offset = + matrix_offset + row * col_bytes + col; - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { - const int col = col_tile_start_byte + jj; - - const size_t logical_src_offset = matrix_offset + row * col_bytes + col; - - if (row < row_limit && col < col_limit) { - for (int v = 0; v < VECTOR_WIDTH; ++v) { - cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; - } - } - } - } - - - for (int ii = 0; ii < M_TILE_L1; ++ii) { - // Using M_TILE_L1 here is deliberate since we assume that the cache tile - // is square in the number of elements (not necessarily the number of bytes). - for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { - const int ii_byte = ii / ELTS_PER_BYTE; - const int ii_bit_offset = ii % ELTS_PER_BYTE; - - const int jj_byte = jj / ELTS_PER_BYTE; - const int jj_bit_offset = jj % ELTS_PER_BYTE; - - uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); - uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); - } - } - - - const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; - const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; - - const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); - const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); - - for (int ii = 0; ii < M_TILE_L1; ++ii) { - const int row = row_tile_start_trans + ii; - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { - const int col = col_tile_start_byte_trans + jj; - - const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; - - if (row < row_limit_trans && col < col_limit_trans) { - for (int v = 0; v < VECTOR_WIDTH; ++v) { - output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; - } - } - } - } + if (row < row_limit && col < col_limit) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } } + } } + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + // Using M_TILE_L1 here is deliberate since we assume that the cache + // tile is square in the number of elements (not necessarily the + // number of bytes). + for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { + const int ii_byte = ii / ELTS_PER_BYTE; + const int ii_bit_offset = ii % ELTS_PER_BYTE; + + const int jj_byte = jj / ELTS_PER_BYTE; + const int jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = + 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = + 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + + const int row_limit_trans = + std::min(row_tile_start_trans + M_TILE_L1, num_cols); + const int col_limit_trans = + std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + const int row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + const int col = col_tile_start_byte_trans + jj; + + const size_t logical_tgt_offset = + matrix_offset + row * col_bytes_trans + col; + + if (row < row_limit_trans && col < col_limit_trans) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } + } + } + } + } } + } } +void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, + const size_t num_elts) { + const int num_bytes = num_elts / 2; -void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) -{ - const int num_bytes = num_elts / 2; + // Step 1 will be to transform all the int4s to unsigned in order to make the + // dequantize take as little instructions as possible in the CUDA code. + for (size_t ii = 0; ii < num_bytes; ++ii) { + int8_t transformed_packed_int4s = 0; + // We don't need to mask in these ops since everything should be in the + // range 0-15 + int8_t transformed_first_elt = (packed_int4_tensor[ii] & 0x0F); + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4); - // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little - // instructions as possible in the CUDA code. - for (size_t ii = 0; ii < num_bytes; ++ii) { - int8_t transformed_packed_int4s = 0; - // We don't need to mask in these ops since everything should be in the range 0-15 - int8_t transformed_first_elt = (packed_int4_tensor[ii] & 0x0F); - int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4); + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } - transformed_packed_int4s |= transformed_first_elt; - transformed_packed_int4s |= (transformed_second_elt << 4); - packed_int4_tensor[ii] = transformed_packed_int4s; - } - - // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical - // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the - // following: Take as input a 32 bit register with layout: bit 32 0 - // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) - - // FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); - const size_t num_registers = num_bytes / 4; - - uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); - for (size_t ii = 0; ii < num_registers; ++ii) { - const uint32_t current_register = register_ptr[ii]; - uint32_t transformed_register = 0; - - for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { - const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; - const int src_shift = 4 * src_idx; - const int dest_shift = 4 * dest_idx; - - const uint32_t src_bits = (current_register >> src_shift) & 0xF; - transformed_register |= (src_bits << dest_shift); - - } - register_ptr[ii] = transformed_register; + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // minimize the number of shift & logical instructions That are needed to + // extract the int4s in the GEMM main loop. Pictorially, the loop below will + // do the following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt + // occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt + // occupies 4 bits) + + // FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a + // multiple of 8 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { + const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; + const int src_shift = 4 * src_idx; + const int dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); } + register_ptr[ii] = transformed_register; + } } -void permute_B_rows_for_mixed_and_int8_gemm(int8_t* permuted_quantized_tensor, - const int8_t* quantized_tensor, +void permute_B_rows_for_mixed_and_int8_gemm(int8_t* permuted_quantized_tensor, + const int8_t* quantized_tensor, const std::vector& shape, - const int64_t arch_version) -{ + const int64_t arch_version) { + // We only want to run this step for weight only quant. + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - // We only want to run this step for weight only quant. - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + const int BITS_PER_ELT = 8; + const int K = 16 / BITS_PER_ELT; + const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; + const int ELTS_PER_REG = 32 / BITS_PER_ELT; - const int BITS_PER_ELT = 8; - const int K = 16 / BITS_PER_ELT; - const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; - const int ELTS_PER_REG = 32 / BITS_PER_ELT; + const uint32_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + const int elts_in_int32 = 32 / BITS_PER_ELT; - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; - const int elts_in_int32 = 32 / BITS_PER_ELT; + const int num_vec_cols = num_cols / elts_in_int32; - const int num_vec_cols = num_cols / elts_in_int32; + // The code is written as below so it works for both int8 and packed int4. + for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + const int write_row = base_row + tile_row; + const int tile_read_row = 4 * (((tile_row % ELTS_PER_REG) / 2)) + + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); - // The code is written as below so it works for both int8 and packed int4. - for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + const int read_row = base_row + tile_read_row; + const int read_col = write_col; - for (int write_col = 0; write_col < num_vec_cols; ++write_col) { - const int write_row = base_row + tile_row; - const int tile_read_row = - 4 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); + const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = + int64_t(write_row) * num_vec_cols + write_col; - const int read_row = base_row + tile_read_row; - const int read_col = write_col; - - const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; - const int64_t write_offset = int64_t(write_row) * num_vec_cols + write_col; - - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } - } + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } } + } } -// Permutes the rows of B for Turing and Ampere. Throws an error for other architectures. -// The data is permuted such that: -// For int8, each group of 16 rows is permuted using the map below: +// Permutes the rows of B for Turing and Ampere. Throws an error for other +// architectures. The data is permuted such that: For int8, each group of 16 +// rows is permuted using the map below: // 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 // 0 1 2 3 4 5 6 7 -template -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, - const int8_t* quantized_tensor, +template +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, + const int8_t* quantized_tensor, const std::vector& shape, - const int64_t arch_version) -{ + const int64_t arch_version) { + // We only want to run this step for weight only quant. + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - // We only want to run this step for weight only quant. - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + const int BITS_PER_ELT = bits; + const int K = 16 / BITS_PER_ELT; + const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; + const int ELTS_PER_REG = 32 / BITS_PER_ELT; - const int BITS_PER_ELT = bits; - const int K = 16 / BITS_PER_ELT; - const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; - const int ELTS_PER_REG = 32 / BITS_PER_ELT; + const uint32_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + const int elts_in_int32 = 32 / BITS_PER_ELT; - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; - const int elts_in_int32 = 32 / BITS_PER_ELT; + const int num_vec_cols = num_cols / elts_in_int32; - const int num_vec_cols = num_cols / elts_in_int32; - - // The code is written as below so it works for both int8 and packed int4. - for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { - - for (int write_col = 0; write_col < num_vec_cols; ++write_col) { - const int write_row = base_row + tile_row; - const int tile_read_row = - 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); - if(base_row == 0 && write_col == 0){ - std::cout<<"tile_read_row:"< -void permute_B_rows_for_mixed_gemm_int4(int8_t* permuted_quantized_tensor, - const int8_t* quantized_tensor, - const std::vector& shape, - const int64_t arch_version) -{ +template +void permute_B_rows_for_mixed_gemm_int4(int8_t* permuted_quantized_tensor, + const int8_t* quantized_tensor, + const std::vector& shape, + const int64_t arch_version) { + // We only want to run this step for weight only quant. + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - // We only want to run this step for weight only quant. - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + const int BITS_PER_ELT = bits; // 4 + const int K = 16 / BITS_PER_ELT; // 4 + const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; // 2 + const int ELTS_PER_REG = 32 / BITS_PER_ELT; // 8 - const int BITS_PER_ELT = bits; //4 - const int K = 16 / BITS_PER_ELT; // 4 - const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; // 2 - const int ELTS_PER_REG = 32 / BITS_PER_ELT; // 8 + const uint32_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; // 32 + const int elts_in_int32 = 32 / BITS_PER_ELT; - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; // 32 - const int elts_in_int32 = 32 / BITS_PER_ELT; + const int num_vec_cols = num_cols / elts_in_int32; + const std::vector tile_col_map{0, 2, 16, 18, 1, 3, 17, 19, 4, 6, 20, + 22, 5, 7, 21, 23, 8, 10, 24, 26, 9, 11, + 25, 27, 12, 14, 28, 30, 13, 15, 29, 31}; - const int num_vec_cols = num_cols / elts_in_int32; - const std::vector tile_col_map{ - 0,2,16,18, - 1,3,17,19, - 4,6,20,22, - 5,7,21,23, - 8,10,24,26, - 9,11,25,27, - 12,14,28,30, - 13,15,29,31}; + // const std::vector tile_col_map{ + // 0 0,2,16,18, + // 4 1,3,17,19, + // 8 4,6,20,22, + // 12 5,7,21,23, + // 16 8,10,24,26, + // 20 9,11,25,27, + // 24 12,14,28,30, + // 28 13,15,29,31}; + // std::vector tile_col_map(32); + // for(int i=0;i<32;i++){ + // tile_col_map[i]=i; + // } + // // tile_col_map[1]=4; + // tile_col_map[0]=0; + // tile_col_map[4]=1; + // tile_col_map[1]=2; + // tile_col_map[5]=3; + // tile_col_map[8]=4; + // tile_col_map[12]=5; + // tile_col_map[9]=6; + // tile_col_map[13]=7; + // tile_col_map[16]=8; + // tile_col_map[20]=9; + // tile_col_map[17]=10; + // tile_col_map[21]=11; + // tile_col_map[24]=12; + // tile_col_map[28]=13; + // tile_col_map[25]=14; + // tile_col_map[29]=15; - // const std::vector tile_col_map{ - // 0 0,2,16,18, - // 4 1,3,17,19, - // 8 4,6,20,22, - // 12 5,7,21,23, - // 16 8,10,24,26, - // 20 9,11,25,27, - // 24 12,14,28,30, - // 28 13,15,29,31}; - // std::vector tile_col_map(32); - // for(int i=0;i<32;i++){ - // tile_col_map[i]=i; - // } - // // tile_col_map[1]=4; - // tile_col_map[0]=0; - // tile_col_map[4]=1; - // tile_col_map[1]=2; - // tile_col_map[5]=3; - // tile_col_map[8]=4; - // tile_col_map[12]=5; - // tile_col_map[9]=6; - // tile_col_map[13]=7; - // tile_col_map[16]=8; - // tile_col_map[20]=9; - // tile_col_map[17]=10; - // tile_col_map[21]=11; - // tile_col_map[24]=12; - // tile_col_map[28]=13; - // tile_col_map[25]=14; - // tile_col_map[29]=15; + // tile_col_map[4]=1; + // tile_col_map[4]=1; + // tile_col_map[4]=2; - // tile_col_map[4]=1; - // tile_col_map[4]=1; - // tile_col_map[4]=2; - - // The code is written as below so it works for both int8 and packed int4. - for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { - - for (int write_col = 0; write_col < num_vec_cols; ++write_col) { - const int write_row = base_row + tile_row; - // const int tile_read_row = - // 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); - // const int tile_read_row = std::distance(tile_col_map.begin(), std::find(tile_col_map.begin(),tile_col_map.end(), tile_row)); - const int tile_read_row = tile_col_map[tile_row]; - if(base_row == 0 && write_col == 0){ - std::cout<<" write_row:"<& shape) { + // We only want to run this step for weight only quant. + std::cout << "### in interleave_column_major_tensor" << std::endl; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; -void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, - const int8_t* quantized_tensor, - const std::vector& shape) -{ + const size_t BITS_PER_ELT = 8; + const size_t elts_in_int32 = 32 / BITS_PER_ELT; - // We only want to run this step for weight only quant. - std::cout<<"### in interleave_column_major_tensor"<(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(interleaved_quantized_tensor); - const size_t rows_per_tile = 64; - std::cout<<"running interleave_column_major_tensor"<(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); - - - const size_t num_vec_rows = num_rows / elts_in_int32; - const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32; - const size_t interleave = 2; - std::cout<<"num_vec_rows:"<& shape) { + // We only want to run this step for weight only quant. + std::cout << "### in interleave_column_major_tensor" << std::endl; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; -void interleave_column_major_tensor_int4(int8_t* interleaved_quantized_tensor, - const int8_t* quantized_tensor, - const std::vector& shape) -{ + const size_t BITS_PER_ELT = 4; + const size_t elts_in_int32 = 32 / BITS_PER_ELT; - // We only want to run this step for weight only quant. - std::cout<<"### in interleave_column_major_tensor"<(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(interleaved_quantized_tensor); - const size_t rows_per_tile = 64; - std::cout<<"running interleave_column_major_tensor"<(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); - - - const size_t num_vec_rows = num_rows / elts_in_int32; - const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32; - const size_t interleave = 4; - std::cout<<"num_vec_rows:"< -get_problem_shape(paddle::Tensor const &a, paddle::Tensor const &b) { +static inline cute::Shape get_problem_shape( + paddle::Tensor const &a, paddle::Tensor const &b) { int32_t m = a.dims()[0], n = b.dims()[0], k = a.dims()[1]; return {m, n, k, 1}; } template void cutlass_gemm_caller( - phi::Place device, cute::Shape prob_shape, + phi::Place device, + cute::Shape prob_shape, typename GemmKernel::MainloopArguments mainloop_args, typename GemmKernel::EpilogueArguments epilogue_args, typename GemmKernel::TileSchedulerArguments scheduler = {}) { @@ -57,7 +59,8 @@ void cutlass_gemm_caller( } template -void cutlass_gemm_caller(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_gemm_caller(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, EpilogueArgs &&...epilogue_params) { using ElementAB = typename Gemm::ElementAB; @@ -86,17 +89,20 @@ void cutlass_gemm_caller(paddle::Tensor &out, paddle::Tensor const &a, auto a_ptr = static_cast(const_cast(a.data())); auto b_ptr = static_cast(const_cast(b.data())); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride}; auto c_ptr = static_cast(const_cast(out.data())); typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, d_stride}; + c_ptr, + c_stride, + c_ptr, + d_stride}; - cutlass_gemm_caller(a.place(), prob_shape, mainloop_args, - epilogue_args); + cutlass_gemm_caller( + a.place(), prob_shape, mainloop_args, epilogue_args); } -} // namespace fastdeploy::c3x +} // namespace fastdeploy::c3x diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh index 26278a79fd..4d8edbd621 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh #pragma once @@ -31,16 +32,19 @@ using namespace cute; namespace fastdeploy { -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, +template + typename Epilogue_, + typename TileShape, + typename ClusterShape, + typename KernelSchedule, typename EpilogueSchedule> struct cutlass_3x_gemm { using ElementAB = ElementAB_; using ElementD = ElementD_; - using ElementAcc = - typename std::conditional, int32_t, - float>::type; + using ElementAcc = typename std:: + conditional, int32_t, float>::type; using Epilogue = Epilogue_; @@ -57,10 +61,21 @@ struct cutlass_3x_gemm { using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD, - AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, + float, + ElementC, + StrideC, + AlignmentCD, + ElementD, + StrideD, + AlignmentCD, + EpilogueSchedule, + EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); @@ -78,16 +93,22 @@ struct cutlass_3x_gemm { KernelSchedule>::CollectiveOp; // clang-format on - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; + using KernelType = enable_sm90_or_later< + cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; struct GemmKernel : public KernelType {}; }; -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, +template + typename Epilogue_, + typename TileShape, + typename ClusterShape, + typename KernelSchedule, typename EpilogueSchedule> struct cutlass_3x_gemm_sm100 { using ElementAB = ElementAB_; @@ -108,9 +129,8 @@ struct cutlass_3x_gemm_sm100 { using LayoutD = cutlass::layout::RowMajor; static constexpr int AlignmentD = AlignmentC; - using ElementAcc = - typename std::conditional, int32_t, - float>::type; + using ElementAcc = typename std:: + conditional, int32_t, float>::type; using Epilogue = Epilogue_; // MMA type @@ -127,23 +147,44 @@ struct cutlass_3x_gemm_sm100 { using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, EpilogueSchedule, + cutlass::arch::Sm100, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueSchedule, EVTCompute>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, - LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, - ElementAccumulator, TileShape, ClusterShape, + cutlass::arch::Sm100, + cutlass::arch::OpClassTensorOp, + ElementAB, + LayoutA, + AlignmentA, + ElementAB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, CollectiveMainloop, CollectiveEpilogue, void>; + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, + CollectiveEpilogue, + void>; }; -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu index f5d4d6aa28..704f1021f8 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu // clang-format will break include orders // clang-format off @@ -10,18 +11,22 @@ namespace fastdeploy { void cutlass_scaled_mm_azp_sm90_int8( - paddle::Tensor &out, paddle::Tensor const &a, paddle::Tensor const &b, - paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, - paddle::Tensor const &azp_adj, paddle::optional const &azp, + paddle::Tensor &out, + paddle::Tensor const &a, + paddle::Tensor const &b, + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &azp, paddle::optional const &bias) { if (azp) { return cutlass_scaled_mm_sm90_int8_epilogue< - c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj, - *azp, bias); + c3x::ScaledEpilogueBiasAzpToken>( + out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { return cutlass_scaled_mm_sm90_int8_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp index 9a601f75ad..2bfa58231f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp @@ -1,34 +1,38 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp #include "helper.h" template -void dispatch_scaled_mm(paddle::Tensor &c, paddle::Tensor const &a, - paddle::Tensor const &b, paddle::Tensor const &a_scales, +void dispatch_scaled_mm(paddle::Tensor &c, + paddle::Tensor const &a, + paddle::Tensor const &b, + paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, paddle::optional const &bias, - Fp8Func fp8_func, Int8Func int8_func) { - PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32); - PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32); + Fp8Func fp8_func, + Int8Func int8_func) { + PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32); + PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32); - int M = a.dims()[0], N = b.dims()[0], K = a.dims()[1]; + int M = a.dims()[0], N = b.dims()[0], K = a.dims()[1]; - if ((a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]) && - (b_scales.numel() == 1 || b_scales.numel() == b.dims()[0])) { - // Standard per-tensor/per-token/per-channel scaling - PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - if (a.dtype() == phi::DataType::FLOAT8_E4M3FN) { - fp8_func(c, a, b, a_scales, b_scales, bias); - } else { - PD_CHECK(a.dtype() == paddle::DataType::INT8); - if constexpr (!std::is_same_v) { - int8_func(c, a, b, a_scales, b_scales, bias); - } else { - PD_CHECK(false, "Int8 not supported for this architecture"); - } - } + if ((a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]) && + (b_scales.numel() == 1 || b_scales.numel() == b.dims()[0])) { + // Standard per-tensor/per-token/per-channel scaling + PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.dtype() == phi::DataType::FLOAT8_E4M3FN) { + fp8_func(c, a, b, a_scales, b_scales, bias); } else { - PADDLE_THROW(phi::errors::Unimplemented( - "No kernel for this combination of input dtypes is implemented.")); + PD_CHECK(a.dtype() == paddle::DataType::INT8); + if constexpr (!std::is_same_v) { + int8_func(c, a, b, a_scales, b_scales, bias); + } else { + PD_CHECK(false, "Int8 not supported for this architecture"); + } } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "No kernel for this combination of input dtypes is implemented.")); + } } diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp index 75472ea805..4e3e67ac5a 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp #pragma once @@ -6,30 +7,35 @@ namespace fastdeploy { -void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, paddle::optional const &bias); -void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, paddle::optional const &bias); -void cutlass_scaled_mm_azp_sm90_int8(paddle::Tensor& out, paddle::Tensor const& a, - paddle::Tensor const& b, - paddle::Tensor const& a_scales, - paddle::Tensor const& b_scales, - paddle::Tensor const& azp_adj, - paddle::optional const& azp, - paddle::optional const& bias); +void cutlass_scaled_mm_azp_sm90_int8( + paddle::Tensor &out, + paddle::Tensor const &a, + paddle::Tensor const &b, + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &azp, + paddle::optional const &bias); -void cutlass_scaled_mm_sm100_fp8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_sm100_fp8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, paddle::optional const &bias); -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu index 801e90fd73..1b197b8028 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu // clang-format will break include orders // clang-format off @@ -9,7 +10,8 @@ namespace fastdeploy { -void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, @@ -17,7 +19,8 @@ void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a, PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); if (bias) { PD_CHECK(bias->dtype() == out.dtype(), - "currently bias dtype must match output dtype ", out.dtype()); + "currently bias dtype must match output dtype ", + out.dtype()); return cutlass_scaled_mm_sm90_fp8_epilogue( out, a, b, a_scales, b_scales, *bias); } else { @@ -25,4 +28,4 @@ void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a, out, a, b, a_scales, b_scales); } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh index ac86aeba85..cd0eda3ad9 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh #pragma once @@ -17,8 +18,10 @@ namespace fastdeploy { using c3x::cutlass_gemm_caller; -template typename Epilogue> +template + typename Epilogue> struct sm90_fp8_config_default { // M in (128, inf) static_assert(std::is_same()); @@ -27,13 +30,19 @@ struct sm90_fp8_config_default { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_fp8_config_M128 { // M in (64, 128] static_assert(std::is_same()); @@ -42,13 +51,19 @@ struct sm90_fp8_config_M128 { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_fp8_config_M64 { // M in [1, 64] static_assert(std::is_same()); @@ -58,13 +73,19 @@ struct sm90_fp8_config_M64 { using TileShape = Shape<_64, _64, _128>; using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue, +template + typename Epilogue, typename... EpilogueArgs> inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out, paddle::Tensor const &a, @@ -75,8 +96,8 @@ inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out, PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN); using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; + typename sm90_fp8_config_default:: + Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm90_fp8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM128 = @@ -84,7 +105,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out, uint32_t const m = a.dims()[0]; uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 + std::max(static_cast(64), next_pow_2(m)); // next power of 2 if (mp2 <= 64) { // m in [1, 64] @@ -112,14 +133,16 @@ void cutlass_scaled_mm_sm90_fp8_epilogue(paddle::Tensor &out, if (out.dtype() == paddle::DataType::BFLOAT16) { return cutlass_gemm_sm90_fp8_dispatch( + cutlass::bfloat16_t, + Epilogue>( out, a, b, std::forward(epilogue_args)...); } else { PD_CHECK(out.dtype() == paddle::DataType::FLOAT16); return cutlass_gemm_sm90_fp8_dispatch( + cutlass::half_t, + Epilogue>( out, a, b, std::forward(epilogue_args)...); } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu index 633f76fd88..5256b27c14 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu // clang-format will break include orders // clang-format off @@ -9,7 +10,8 @@ namespace fastdeploy { -void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, @@ -17,7 +19,8 @@ void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a, PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); if (bias) { PD_CHECK(bias->dtype() == out.dtype(), - "currently bias dtype must match output dtype ", out.dtype()); + "currently bias dtype must match output dtype ", + out.dtype()); return cutlass_scaled_mm_sm90_int8_epilogue( out, a, b, a_scales, b_scales, *bias); } else { @@ -26,4 +29,4 @@ void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a, } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh index df63de0fa6..1b14e1b749 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh #pragma once @@ -17,8 +18,10 @@ namespace fastdeploy { using c3x::cutlass_gemm_caller; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_default { // For M > 128 and any N static_assert(std::is_same()); @@ -27,13 +30,19 @@ struct sm90_int8_config_default { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_M128 { // For M in (64, 128] and any N static_assert(std::is_same()); @@ -42,13 +51,19 @@ struct sm90_int8_config_M128 { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_M64 { // For M in (32, 64] and any N static_assert(std::is_same()); @@ -56,13 +71,19 @@ struct sm90_int8_config_M64 { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _64, _256>; using ClusterShape = Shape<_1, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_M32_NBig { // For M in [1, 32] and N >= 8192 static_assert(std::is_same()); @@ -70,13 +91,19 @@ struct sm90_int8_config_M32_NBig { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _256>; using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_M32_NSmall { // For M in [1, 32] and N < 8192 static_assert(std::is_same()); @@ -84,13 +111,19 @@ struct sm90_int8_config_M32_NSmall { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _64, _256>; using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue, +template + typename Epilogue, typename... EpilogueArgs> inline void cutlass_gemm_sm90_int8_dispatch(paddle::Tensor &out, paddle::Tensor const &a, @@ -101,25 +134,25 @@ inline void cutlass_gemm_sm90_int8_dispatch(paddle::Tensor &out, PD_CHECK(b.dtype() == paddle::DataType::INT8); using Cutlass3xGemmDefault = - typename sm90_int8_config_default::Cutlass3xGemm; + typename sm90_int8_config_default:: + Cutlass3xGemm; using Cutlass3xGemmM128 = typename sm90_int8_config_M128::Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm90_int8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM32NBig = - typename sm90_int8_config_M32_NBig::Cutlass3xGemm; + typename sm90_int8_config_M32_NBig:: + Cutlass3xGemm; using Cutlass3xGemmM32NSmall = - typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; + typename sm90_int8_config_M32_NSmall:: + Cutlass3xGemm; uint32_t const n = out.dims()[1]; bool const is_small_n = n < 8192; uint32_t const m = a.dims()[0]; uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 + std::max(static_cast(32), next_pow_2(m)); // next power of 2 if (mp2 <= 32) { // m in [1, 32] @@ -155,7 +188,8 @@ void cutlass_scaled_mm_sm90_int8_epilogue(paddle::Tensor &out, PD_CHECK(b.dtype() == paddle::DataType::INT8); if (out.dtype() == paddle::DataType::BFLOAT16) { - return cutlass_gemm_sm90_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { @@ -165,4 +199,4 @@ void cutlass_scaled_mm_sm90_int8_epilogue(paddle::Tensor &out, } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu index 55015ea3e9..e3ce7e1fcb 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu #include "helper.h" #include @@ -20,7 +21,8 @@ using namespace fastdeploy; template