mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[LLM] First commit the llm deployment code
This commit is contained in:
@@ -0,0 +1,190 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
// Per-architecture traits used to configure the mixed-input GEMM kernels
// (e.g. fp16/bf16 activations with quantized weights). The primary template is
// deliberately unusable: instantiating it for an unsupported
// (TypeA, TypeB, arch) combination fires the static_assert below.
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits
{
    // NOTE(review): `dependent_false<arch>` must be a dependent-false helper
    // declared in one of the included headers (e.g. mixed_gemm_B_layout.h) —
    // confirm it is visible here, otherwise even valid specializations of this
    // trait will fail to compile when the primary is instantiated.
    static_assert(dependent_false<arch>, "Unrecognised parameterization");
};
|
||||
|
||||
// Specialization for plain fp32 x fp32 GEMM. No tensor cores are used: the
// kernel falls back to the SIMT (CUDA-core) path, which exists on every
// architecture — hence the unconstrained `Arch` parameter.
template <typename Arch>
struct MixedGemmArchTraits<float, float, Arch>
{
    // Two-stage software pipeline (classic double-buffering, no async copies).
    static constexpr int Stages = 2;
    using OperatorClass = cutlass::arch::OpClassSimt;
    using AccType = float;
    using LayoutB = cutlass::layout::ColumnMajor;

    // Scalar (non-vectorized) global-memory accesses for the SIMT path.
    static constexpr int ElementsPerAccessA = 1;
    static constexpr int ElementsPerAccessB = 1;
    static constexpr int ElementsPerAccessC = 1;
    static constexpr int ThreadblockK = 8;
    // 1x1x1 instruction shape == one scalar FMA per thread per "instruction".
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

    using Operator = cutlass::arch::OpMultiplyAdd;
};
|
||||
|
||||
// ========================= Volta Traits ===========================
// Volta will always dequantize after the global memory load.
// This will instantiate any HMMA tensorcore kernels for Volta.
// Note that Volta does not have native bfloat16 support, so bf16 weights and
// activations will be cast to fp16, compute will happen in fp16, and results
// will be converted back for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm70,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    // Layout/packing details of the quantized B operand for this architecture.
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm70>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    // 128-bit vectorized accesses for A and C; B's access width is dictated by
    // the quantized-weight layout.
    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    // Volta HMMA (mma.sync) tile shape.
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

    using Operator = typename LayoutDetails::Operator;
};
|
||||
|
||||
// ======================= Turing Traits ==============================
// Note that Turing does not have native bfloat16 support, so bf16 weights and
// activations will be cast to fp16, compute will happen in fp16, and results
// will be converted back for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    // Layout/packing details of the quantized B operand for this architecture.
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    // 128-bit vectorized accesses for A and C; B's access width is dictated by
    // the quantized-weight layout.
    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    // Turing tensor-core tile shape.
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

    using Operator = typename LayoutDetails::Operator;
};
|
||||
|
||||
// ======================= Ampere Traits ==============================
// Ampere has native bf16 tensor cores, so both half_t and bfloat16_t
// activations run natively here.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    // Layout/packing details of the quantized B operand for this architecture.
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    // 128-bit vectorized accesses for A and C; B's access width is dictated by
    // the quantized-weight layout.
    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    // Ampere tensor-core tile shape.
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

    using Operator = typename LayoutDetails::Operator;
};
|
||||
|
||||
// ======================= Ada Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    // Layout/packing details of the quantized B operand for this architecture.
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    // 128-bit vectorized accesses for A and C; B's access width is dictated by
    // the quantized-weight layout.
    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    // K extent scales with element width: 16 for 16-bit TypeA (16x8x16 tile).
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;

    using Operator = typename LayoutDetails::Operator;
};
|
||||
|
||||
// FP8 specialization: A/B = fp8 (e4m3 or e5m2), accumulation in fp32.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::float_e5m2_t>::value>::type>
{
private:
    // Layout/packing details of the quantized B operand for this architecture.
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    // be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA>
    // NOTE(review): __nv_bfloat16 is the CUDA runtime type (from <cuda_bf16.h>),
    // not cutlass::bfloat16_t — confirm the header is pulled in transitively.
    using TypeC = __nv_bfloat16;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    // C access width is based on the (wider) output type, not the fp8 input type.
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeC>::value;
    // K extent scales with element width: 32 for 8-bit TypeA (16x8x32 tile).
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;

    using Operator = typename LayoutDetails::Operator;
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,57 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
// Per-architecture traits for the int8 GEMM kernels. The primary template is
// the fallback: a SIMT (CUDA-core) configuration with a scalar 1x1x1
// instruction shape, usable on architectures without int8 tensor cores.
template <typename arch>
struct Int8GemmArchTraits
{
    using OperatorClass = cutlass::arch::OpClassSimt;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
};
|
||||
|
||||
// ======================= Turing Traits ==============================
// Turing int8 tensor-core configuration (8x8x16 IMMA tile).
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm75>
{
    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
};
|
||||
|
||||
// ======================= Ampere Traits ==============================
// Ampere int8 tensor-core configuration (16x8x32 IMMA tile).
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm80>
{
    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,568 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail
{
// Dependent-false helper: always false, but made dependent on a template
// parameter so a static_assert using it only fires when the enclosing
// template is actually instantiated (not at definition time).
template <typename>
inline constexpr bool dependent_false_v = false;
} // namespace detail
|
||||
|
||||
/// Mixed-input GEMM kernel: fp (half/bf16) activations (A) times quantized
/// integer weights (B), with per-column or fine-grained (group-wise)
/// dequantization scales (and optional zero points) applied in the mainloop.
/// Supports serial split-K reduction via a semaphore when SplitKSerial is true.
template <typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,                   ///! Epilogue
    typename ThreadblockSwizzle_,         ///! Threadblock swizzling function
    typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
                         /// arch.
    bool SplitKSerial    ///! If true, code supporting split-K via serial reduction is enabled.
    >
struct GemmFpAIntB
{

    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    static bool const kSplitKSerial = SplitKSerial;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Mma::LayoutC;
    // Scales share the epilogue output element type.
    using ElementScale = ElementC;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    // NOTE(review): kTransformB is defined from Mma::kTransformA, not
    // Mma::kTransformB. Harmless when both are kNone (the usual case here),
    // but looks like a copy-paste — confirm against the mainloop definition.
    static ComplexTransform const kTransformB = Mma::kTransformA;

    // Type definitions about the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    // Row-interleave factor of the packed B operand (1 for plain row/col major).
    static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;

    /// Host-side launch arguments.
    struct Arguments
    {
        GemmUniversalMode mode = GemmUniversalMode::kGemm;

        cutlass::gemm::GemmCoord problem_size;
        // Dequantization group size along K (for fine-grained quantization).
        int group_size;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Mma::IteratorScale::TensorRef ref_zero;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;

        // Control serial split-k
        int batch_count;

        typename EpilogueOutputOp::Params output_op;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        // Included so we can use Gemm Universal
        int batch_stride_D = 0;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Arguments() {}

        CUTLASS_HOST_DEVICE
        Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
            typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
            typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
            typename Epilogue::OutputTileIterator::TensorRef ref_C,
            typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor,
            typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
            int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr,
            int const* scatter_D_indices = nullptr)
            : problem_size(problem_size)
            , group_size(group_size)
            , ref_A(ref_A)
            , ref_B(ref_B)
            , ref_scale(ref_scale)
            , ref_zero(ref_zero)
            , ref_C(ref_C)
            , ref_D(ref_D)
            , batch_count(serial_split_k_factor)
            , output_op(output_op)
            , gather_A_indices(gather_A_indices)
            , gather_B_indices(gather_B_indices)
            , scatter_D_indices(scatter_D_indices)
        {
        }
    };

    /// Device-side parameters, precomputed on the host from Arguments.
    struct Params
    {
        cutlass::gemm::GemmCoord problem_size;
        int group_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;
        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::Params params_B;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::Params params_scale;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Mma::IteratorScale::TensorRef ref_zero;
        typename Epilogue::OutputTileIterator::Params params_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::Params params_D;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;
        typename EpilogueOutputOp::Params output_op;
        // Per-output-tile locks in workspace, used for serial split-K ordering.
        int* semaphore;
        // K extent assigned to each split-K slice.
        int gemm_k_size;
        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Params()
            : swizzle_log_tile(0)
            , semaphore(0)
            , gemm_k_size(0)
        {
        }

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
            void* workspace = nullptr)
            : problem_size(args.problem_size)
            , group_size(args.group_size)
            , grid_tiled_shape(grid_tiled_shape)
            , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape))
            , params_A(args.ref_A.layout())
            , ref_A(args.ref_A)
            , params_B(args.ref_B.layout())
            , ref_B(args.ref_B)
            , params_scale(args.ref_scale.layout())
            , ref_scale(args.ref_scale)
            , ref_zero(args.ref_zero)
            , params_C(args.ref_C.layout())
            , ref_C(args.ref_C)
            , params_D(args.ref_D.layout())
            , ref_D(args.ref_D)
            , output_op(args.output_op)
            , semaphore(static_cast<int*>(workspace))
            , gemm_k_size(gemm_k_size)
            , gather_A_indices(args.gather_A_indices)
            , gather_B_indices(args.gather_B_indices)
            , scatter_D_indices(args.scatter_D_indices)
        {
        }
    };

    /// Shared memory storage structure.
    /// A union: the mainloop and epilogue phases never overlap, so they reuse
    /// the same shared-memory allocation.
    union SharedStorage
    {
        typename Mma::SharedStorage main_loop;
        typename Epilogue::SharedStorage epilogue;
    };

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    GemmFpAIntB() {}

    /// Determines whether kernel satisfies alignment
    static Status can_implement(Arguments const& args)
    {
        // Interleaved layouts require operand alignment to the interleave
        // width; otherwise use the iterator's native vector width.
        static int const kAlignmentA
            = (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
            : (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
            ? 64
            : Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB
            = (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32
            : (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value)
            ? 64
            : Mma::IteratorB::AccessType::kElements;

        static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;

        static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                                           layout::ColumnMajorInterleaved<32>>::value)
            ? 32
            : (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                  layout::ColumnMajorInterleaved<64>>::value)
            ? 64
            : Epilogue::OutputTileIterator::kElementsPerAccess;

        if (!TensorRef_aligned(args.ref_A, kAlignmentA))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_B, kAlignmentB))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_scale, kAlignmentScale))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_zero, kAlignmentScale))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_C, kAlignmentC))
        {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_D, kAlignmentC))
        {
            return Status::kErrorMisalignedOperand;
        }

        // Scales are mandatory for the dequantizing mainloop.
        if (!args.ref_scale.good())
        {
            return Status::kErrorNotSupported;
        }

        // Zero points must be present exactly when the quant op uses them.
        if constexpr (hasZero(Mma::QuantOp))
        {
            if (!args.ref_zero.good())
            {
                return Status::kErrorNotSupported;
            }
        }
        else
        {
            if (args.ref_zero.good())
            {
                return Status::kErrorNotSupported;
            }
        }

        // Fine-grained quantization only supports group sizes 64 and 128.
        if constexpr (isFinegrained(Mma::QuantOp))
        {
            if (args.group_size != 64 && args.group_size != 128)
            {
                return Status::kErrorNotSupported;
            }
        }

        return Status::kSuccess;
    }

    // No extra workspace beyond the split-K semaphore allocated by the caller.
    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
    {

        return 0;
    }

    // Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator
    // has a different constructor signature than a regular cutlass iterator
    template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
    CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
        typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
        typename IteratorScale::TensorCoord extent, int thread_id,
        typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
    {

        return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size);
    }

    // Non-fine-grained overload: the iterator takes neither zero pointer nor
    // group size (those arguments are accepted but ignored here so both
    // overloads share one call site).
    template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<!isFinegrained(op), bool> = true>
    CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
        typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
        typename IteratorScale::TensorCoord extent, int thread_id,
        typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
    {

        return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
    }

    /// Device-side body: executes one GEMM tile (mainloop + epilogue).
    CUTLASS_DEVICE
    void run_kernel_(Params const& params, SharedStorage& shared_storage)
    {
        using LayoutB = typename Mma::IteratorB::Layout;
        static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
                || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
            "B must be row major/col major OR col major interleaved.");

        // Compute threadblock location
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // Early exit if CTA is out of range
        if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
            || params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
        {

            return;
        }

        // Compute initial location in logical coordinates
        cutlass::MatrixCoord tb_offset_A{
            threadblock_tile_offset.m() * Mma::Shape::kM,
            threadblock_tile_offset.k() * params.gemm_k_size,
        };

        // B offsets account for the interleaved packing of the quantized weights.
        cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
            threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};

        // NOTE(review): the fine-grained scale row offset hard-codes a group
        // size of 64 here, while can_implement accepts 64 and 128 — confirm
        // the iterator rescales internally for group_size == 128.
        typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
        typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
        cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};

        // Problem size is a function of threadblock index in the K dimension
        int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);

        // Compute threadblock-scoped matrix multiply-add
        int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

        // Compute position within threadblock
        int thread_idx = threadIdx.x;

        // Construct iterators to A and B operands
        typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
            {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);

        typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
            {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
            params.gather_B_indices);

        // Fine-grained: one scale row per K-group; otherwise a single row.
        typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
        typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
            params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
            {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
        int lane_idx = threadIdx.x % 32;

        //
        // Main loop
        //
        // Construct thread-scoped matrix multiply
        Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);

        typename Mma::FragmentC accumulators;

        accumulators.clear();

        if (!kSplitKSerial || gemm_k_iterations > 0)
        {
            // Compute threadblock-scoped matrix multiply-add
            mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
        }

        //
        // Epilogue
        //

        EpilogueOutputOp output_op(params.output_op);

        //
        // Masked tile iterators constructed from members
        //

        threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // assume identity swizzle
        MatrixCoord threadblock_offset(
            threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);

        int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

        // Construct the semaphore.
        Semaphore semaphore(params.semaphore + block_idx, thread_idx);

        // If performing a reduction via split-K, fetch the initial synchronization
        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {

            // Fetch the synchronization lock initially but do not block.
            semaphore.fetch();

            // Indicate which position in a serial reduction the output operator is currently updating
            output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
        }

        // Tile iterator loading from source tensor.
        typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(),
            thread_idx, threadblock_offset, params.scatter_D_indices);

        // Tile iterator writing to destination tensor.
        typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(),
            thread_idx, threadblock_offset, params.scatter_D_indices);

        Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

        // Wait on the semaphore - this latency may have been covered by iterator construction
        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {

            // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
            if (threadblock_tile_offset.k())
            {
                iterator_C = iterator_D;
            }

            semaphore.wait(threadblock_tile_offset.k());
        }

        // Execute the epilogue operator to update the destination tensor.
        epilogue(output_op, iterator_D, accumulators, iterator_C);

        //
        // Release the semaphore
        //

        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {

            int lock = 0;
            if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
            {

                // The final threadblock resets the semaphore for subsequent grids.
                lock = 0;
            }
            else
            {
                // Otherwise, the semaphore is incremented
                lock = threadblock_tile_offset.k() + 1;
            }

            semaphore.release(lock);
        }
    }

    // Dispatch guard: only emit the kernel body when the compilation arch
    // matches the arch this kernel was built for; otherwise trap.
    template <typename CompilationArch>
    CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
    {
        if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
        {
            run_kernel_(params, shared_storage);
        }
        else
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    }

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
      */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
        run_kernel<arch::Sm70>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
        run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
        run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
        CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#else
        static_assert(
            false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,93 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Scheduler for grouped GEMM
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles.
///
/// Thin adapter that instantiates MoeProblemVisitor with the grouped-GEMM
/// problem-size helper so MoE grouped GEMM kernels can walk per-expert
/// problems tile by tile. All scheduling logic lives in the MoeProblemVisitor
/// base; this type only fixes the ProblemSizeHelper and re-exports the base's
/// Params/SharedStorage types.
template <typename ThreadblockShape,
          GroupScheduleMode GroupScheduleMode_,
          int PrefetchTileCount,
          int ThreadCount,
          bool Transposed = false>
struct GemmMoeProblemVisitor
    : public MoeProblemVisitor<
          detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
          ThreadblockShape,
          GroupScheduleMode_,
          PrefetchTileCount,
          ThreadCount> {
  // Whether each problem's M/N extents are swapped by the size helper.
  static bool const kTransposed = Transposed;

  using ProblemSizeHelper =
      detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
  using Base = MoeProblemVisitor<ProblemSizeHelper,
                                 ThreadblockShape,
                                 GroupScheduleMode_,
                                 PrefetchTileCount,
                                 ThreadCount>;
  using Params = typename Base::Params;
  using SharedStorage = typename Base::SharedStorage;

  //
  // Methods
  //

  /// Constructs the visitor; simply forwards to the base.
  CUTLASS_DEVICE
  GemmMoeProblemVisitor(Params const& params_,
                        SharedStorage& shared_storage_,  // NOLINT
                        int32_t block_idx)
      : Base(params_, shared_storage_, block_idx) {}
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,587 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief GEMM kernel to support the epilogue visitor model
|
||||
for customized softmax partial reduction epilogue fusion.
|
||||
|
||||
This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
|
||||
its usage has been stabilized. For now, it is included in this example to demonstrate
|
||||
some basic output fusion options.
|
||||
|
||||
original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h"
|
||||
|
||||
namespace tk = common;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct GemmWithEpilogueVisitor
|
||||
{
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using TensorRefB = TensorRef<ElementB, LayoutB>;
|
||||
|
||||
using ElementCompute = typename EpilogueVisitor::ElementCompute;
|
||||
using LayoutAlphaCol = cutlass::layout::RowMajor;
|
||||
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
|
||||
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
|
||||
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
|
||||
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename Epilogue::Layout;
|
||||
using TensorRefC = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformB;
|
||||
using Operator = typename Mma::Operator;
|
||||
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
using EpilogueOutputOp =
|
||||
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment
|
||||
= const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
///
/// Host-side, user-facing arguments for the GEMM-with-epilogue-visitor
/// kernel. Converted into the device-side Params structure before launch.
struct Arguments
{

    //
    // Data members
    //

    GemmUniversalMode mode;       // kGemm / kGemmSplitKParallel / kBatched / kArray
    GemmCoord problem_size;       // GEMM extents (M, N, K)
    int batch_count;              // batch count, or split-K slice count in kGemm mode

    TensorRefA ref_A;             // input activation matrix
    TensorRefB ref_B;             // (possibly quantized) weight matrix
    tk::QuantMode quant_option;   // which per-row/per-column scales are applied
    TensorRefAlphaCol ref_alpha_col; // per-column dequant scales (row-major vector)
    TensorRefAlphaRow ref_alpha_row; // per-row dequant scales (column-major vector)
    TensorRefC ref_C;             // source operand for the epilogue
    TensorRefC ref_D;             // destination tensor

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_D;       // NOTE: always set to 0 by the full constructor below

    typename EpilogueVisitor::Arguments epilogue_visitor;

    //
    // Methods
    //

    // Default: plain GEMM with a single batch; tensor refs left empty.
    Arguments()
        : mode(GemmUniversalMode::kGemm)
        , batch_count(1)
    {
    }

    /// constructs an arguments structure
    Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_,
        TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_,
        TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_,
        int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_)
        : mode(mode_)
        , problem_size(problem_size_)
        , batch_count(batch_count_)
        , ref_A(ref_A_)
        , ref_B(ref_B_)
        , quant_option(quant_option_)
        , ref_alpha_col(ref_alpha_col_)
        , ref_alpha_row(ref_alpha_row_)
        , ref_C(ref_C_)
        , ref_D(ref_D_)
        , batch_stride_A(batch_stride_A_)
        , batch_stride_B(batch_stride_B_)
        , batch_stride_D(0)
        , epilogue_visitor(epilogue_visitor_)
    {
    }
};
|
||||
|
||||
//
// Structure for precomputing values in host memory and passing to kernels
//

/// Parameters structure
///
/// Device-side kernel parameters, precomputed on host from Arguments.
struct Params
{

    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;   // grid size in threadblock tiles (M, N, split-K)
    int swizzle_log_tile;                        // log tile size used by ThreadblockSwizzle

    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorB::Params params_B;
    typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
    typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
    typename EpilogueVisitor::OutputTileIterator::Params params_C;
    typename EpilogueVisitor::OutputTileIterator::Params params_D;

    GemmUniversalMode mode;
    int batch_count;
    int gemm_k_size;   // K extent handled by each split-K slice

    void* ptr_A;       // in kArray mode this actually holds an ElementA* const* array
    void* ptr_B;       // in kArray mode this actually holds an ElementB* const* array
    tk::QuantMode quant_option;
    typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
    typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
    ElementC* ptr_C;
    ElementC* ptr_D;

    int64_t batch_stride_A;
    int64_t batch_stride_B;

    typename EpilogueVisitor::Params epilogue_visitor;

    //
    // Methods
    //

    // NOTE(review): initializer order below does not follow declaration order
    // (mode is declared before batch_count), and quant_option is left
    // uninitialized here — harmless for a zero-state default but worth tidying.
    CUTLASS_HOST_DEVICE
    Params()
        : swizzle_log_tile(0)
        , params_A(0)
        , params_B(0)
        , params_alpha_col(0)
        , params_C(0)
        , params_D(0)
        , batch_count(0)
        , gemm_k_size(0)
        , mode(cutlass::gemm::GemmUniversalMode::kGemm)
        , ptr_A(nullptr)
        , ptr_B(nullptr)
        , ptr_alpha_col(nullptr)
        , ptr_alpha_row(nullptr)
        , ptr_C(nullptr)
        , ptr_D(nullptr)
        , batch_stride_A(0)
        , batch_stride_B(0)
    {
    }

    /// Builds device parameters from host arguments and derives the tiled grid
    /// shape and split-K slicing. `workspace_` is accepted for interface
    /// compatibility but unused.
    Params(
        Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
        : problem_size(args.problem_size)
        , swizzle_log_tile(0)
        , params_A(args.ref_A.layout())
        , params_B(args.ref_B.layout())
        , params_alpha_col(args.ref_alpha_col.layout())
        // NOTE(review): alpha_row params are built from ref_alpha_col's layout,
        // not ref_alpha_row's. Presumably benign because both scale vectors are
        // 1-D, but confirm this is intentional rather than a copy-paste slip.
        , params_alpha_row(args.ref_alpha_col.layout())
        , params_C(args.ref_C.layout())
        , params_D(args.ref_D.layout())
        , mode(args.mode)
        , batch_count(args.batch_count)
        , gemm_k_size(args.problem_size.k())
        , ptr_A(args.ref_A.data())
        , ptr_B(args.ref_B.data())
        , quant_option(args.quant_option)
        , ptr_alpha_col(args.ref_alpha_col.data())
        , ptr_alpha_row(args.ref_alpha_row.data())
        , ptr_C(args.ref_C.data())
        , ptr_D(args.ref_D.data())
        , batch_stride_A(args.batch_stride_A)
        , batch_stride_B(args.batch_stride_B)
        , epilogue_visitor(args.epilogue_visitor)
    {

        ThreadblockSwizzle threadblock_swizzle;

        // Grid shape in threadblock tiles; batch_count maps to the k dimension.
        grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size,
            {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);

        if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
        {

            // Split-K slices must be 128b-aligned in K.
            int const kAlignK
                = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

            gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

            if (gemm_k_size)
            {
                grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
            }
        }

        swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
    }
};
|
||||
|
||||
/// Shared memory storage structure
///
/// Mainloop and epilogue storage are overlapped via a union to cut the
/// per-CTA shared-memory footprint (assumes the two phases are never live
/// at the same time, the usual CUTLASS pattern).
union SharedStorage
{

    typename Mma::SharedStorage main_loop;

    struct
    {
        typename Epilogue::SharedStorage epilogue;
        typename EpilogueVisitor::SharedStorage visitor;
    } epilogue;
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor — the functor is stateless; all kernel state is
/// supplied through Params/SharedStorage in operator().
CUTLASS_DEVICE
GemmWithEpilogueVisitor() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
|
||||
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value)
|
||||
{
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
else if (platform::is_same<LayoutA, layout::ColumnMajor>::value)
|
||||
{
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
}
|
||||
else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value)
|
||||
{
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value)
|
||||
{
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
}
|
||||
else if (platform::is_same<LayoutB, layout::ColumnMajor>::value)
|
||||
{
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value)
|
||||
{
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value)
|
||||
{
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
else if (platform::is_same<LayoutC, layout::ColumnMajor>::value)
|
||||
{
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
}
|
||||
else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value)
|
||||
{
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kSuccess");
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Determines whether the kernel can run for the given arguments; delegates
/// to the problem-size alignment check above.
static Status can_implement(Arguments const& args)
{
    return can_implement(args.problem_size);
}
|
||||
|
||||
/// This kernel needs no extra global-memory workspace; both parameters are
/// accepted only for interface compatibility with GemmUniversalBase.
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{

    return 0;
}
|
||||
|
||||
#define SPLIT_K_ENABLED 1

/// Executes one GEMM
///
/// Device-side body: maps this CTA to a tile via the swizzle, runs the
/// threadblock-scoped mainloop over its K range, then applies the epilogue
/// visitor (per-row/per-column dequant scales, C source, D destination).
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
        || params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
    {

        return;
    }

    // [offset_k, problem_size_k) is the K range this CTA reduces over;
    // it is narrowed below in split-K modes.
    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
    ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);

#if SPLIT_K_ENABLED
    //
    // Fetch pointers based on mode.
    //
    if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel)
    {

        // All slices but the last cover exactly gemm_k_size of K; the last
        // slice keeps the full problem K as its upper bound.
        if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k())
        {

            problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
        }

        offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    }
    else if (params.mode == GemmUniversalMode::kBatched)
    {
        // Batched mode: grid k indexes the batch; advance by batch strides.
        ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
        ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
    }
    else if (params.mode == GemmUniversalMode::kArray)
    {
        // Array mode: ptr_A/ptr_B are arrays of per-batch pointers.
        ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
        ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
    }
#endif

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
        params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);

    typename Mma::IteratorB iterator_B(
        params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Compute threadblock-scoped matrix multiply-add
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Compute threadblock-scoped matrix multiply-add
    mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // assume identity swizzle
    MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);

    // NOTE(review): block_idx is computed but never used below.
    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    //
    // Construct the epilogue visitor
    //

    EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor,
        params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
        params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C,
        params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m());

    if (params.mode == GemmUniversalMode::kGemm)
    {
        // Indicate which position in a serial reduction the output operator is currently updating
        epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
    }
    else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray)
    {
        epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
    }

    // Construct the epilogue
    Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);

    // Execute the epilogue operator to update the destination tensor.
    epilogue(epilogue_visitor, accumulators);
}
|
||||
|
||||
/// Arch-gating dispatch helper: only instantiates the device body when the
/// architecture selected by the __CUDA_ARCH__ ladder in operator() matches
/// this kernel's ArchTag; the other branch is discarded at compile time to
/// keep build times down.
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
    if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
    {
        run_kernel_(params, shared_storage);
    }
    else
    {
        // Compiled for an architecture this kernel does not target: trap at runtime.
        CUTLASS_NOT_IMPLEMENTED();
    }
}
|
||||
|
||||
/*
    To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
    to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
///
/// Selects the architecture tag matching the __CUDA_ARCH__ this translation
/// unit is compiled for and forwards to the gated run_kernel<Arch>. SM90+
/// currently reuses the Sm80 kernel (see TODO below) rather than trapping.
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 720)
    run_kernel<arch::Sm70>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
    run_kernel<arch::Sm72>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
    run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
    run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
    // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
    run_kernel<arch::Sm80>(params, shared_storage);
#else
    // Only reachable when compiling for a pre-Volta device arch.
    static_assert(
        false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
    // Host-side compilation of this device-only body: never executed.
    CUTLASS_NOT_IMPLEMENTED();
#endif
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*
|
||||
This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
|
||||
quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
|
||||
to be consumed by CUTLASS.
|
||||
|
||||
Note that for int4, ThreadBlockK MUST be 64.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/// Primary template: intentionally empty. Each supported (activation type,
/// weight type, architecture) combination is described by one of the partial
/// specializations below, which define ThreadblockK, Layout, ElementsPerAccess
/// and Operator for the B (weight) operand.
template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB
{
};
||||
|
||||
// Volta specializations. Volta will dequantize before STS, so we need a different operator
template <typename TypeA, typename TypeB>
struct LayoutDetailsB<TypeA, TypeB, arch::Sm70>
{
    // K extent of a threadblock tile, sized so a tile row spans 128 bytes of TypeA.
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    // Fixed 8-element vector accesses on Volta.
    static constexpr int ElementsPerAccess = 8;
    using Operator = cutlass::arch::OpMultiplyAdd;
};
|
||||
|
||||
// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
// TODO - Switch this to column major for weights since gemms should be more performant.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
    // K extent of a threadblock tile, sized so a tile row spans 128 bytes of TypeA.
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    // 128-bit vector accesses of half_t elements.
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};
|
||||
|
||||
// Same as the FP16 specialization above, for BF16 weights on Turing+.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
    // K extent of a threadblock tile, sized so a tile row spans 128 bytes of TypeA.
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    // 128-bit vector accesses of bfloat16_t elements.
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};
|
||||
|
||||
// FP8 (e4m3) weights on Ada (SM89). Note the int4 constraint from the file
// header does not apply here: ThreadblockK is fixed at 64.
template <typename TypeA>
struct LayoutDetailsB<TypeA, cutlass::float_e4m3_t, arch::Sm89>
{
    static constexpr int ThreadblockK = 64;

private:
    // NOTE(review): these interleave constants are computed but unused since
    // Layout below is plain ColumnMajor — presumably kept to mirror the
    // quantized specializations; confirm before removing.
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajor;
    // 128-bit vector accesses of e4m3 elements.
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
    // for fast accumulation
    // using Operator = cutlass::arch::OpMultiplyAddFastAccum;
};
|
||||
|
||||
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB < TypeA,
    uint8_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
    // K extent of a threadblock tile, sized so a tile row spans 128 bytes of TypeA.
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;

private:
    // Interleave factor chosen so one interleaved tile fills a 128-byte cache line.
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    // Tile-interleaved layout: the offline weight preprocessor must produce
    // weights in this same layout (see file header comment).
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    // 128-bit vector accesses of uint8_t elements.
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
|
||||
|
||||
// B = 4-bit quantized weights on Turing..Ada: same interleaved scheme as the
// int8 specialization above, with counts in uint4 units.
template <typename TypeA, typename Arch>
struct LayoutDetailsB < TypeA,
    uint4b_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
    // Threadblock K extent: number of TypeA elements in 128 bytes.
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;

private:
    // uint4 weights that fit in one 128-byte cache line.
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
    // Columns packed per ThreadblockK-element tile to fill a cache line.
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    // 128-bit vectorized accesses, measured in uint4 weights (32 per access).
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,357 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Base scheduler for grouped problems, using MoE
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles.
/// Specialization of the grouped-problem visitor for MoE: every sub-problem
/// (expert) shares the same N (params.gemm_n) and K (params.gemm_k); only the
/// row count M varies per problem.
template <typename ProblemSizeHelper, typename ThreadblockShape_>
struct BaseMoeProblemVisitor {
  using ThreadblockShape = ThreadblockShape_;

  // One precomputed schedule entry; kNoPrefetchEntry marks an unused slot.
  struct ProblemInfo {
    static int32_t const kNoPrefetchEntry = -1;
    int32_t problem_idx;    // which sub-problem this entry refers to
    int32_t problem_start;  // first global tile index of that sub-problem

    CUTLASS_DEVICE
    ProblemInfo()
        : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {}

    CUTLASS_DEVICE
    ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
        : problem_idx(problem_idx_), problem_start(problem_start_) {}
  };

  struct Params {
    // Per-problem row bookkeeping. When total_rows < 0 the entries are
    // cumulative (M_i = entry[i] - entry[i-1]); otherwise each entry is the
    // row count of problem i directly (see problem_size(idx) below).
    int64_t const* last_row_for_problem;
    int64_t total_rows;
    int64_t gemm_n;  // N extent shared by all sub-problems
    int64_t gemm_k;  // K extent shared by all sub-problems
    int32_t problem_count;
    void const* workspace;  // optional host-precomputed schedule
    int32_t tile_count;

    //
    // Methods
    //

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params()
        : last_row_for_problem(nullptr),
          total_rows(-1),
          gemm_n(0),
          gemm_k(0),
          problem_count(0),
          workspace(nullptr),
          tile_count(0) {}

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params(int64_t const* last_row_for_problem,
           int64_t total_rows,
           int64_t gemm_n,
           int64_t gemm_k,
           int32_t problem_count,
           void const* workspace = nullptr,
           int32_t tile_count = 0)
        : last_row_for_problem(last_row_for_problem),
          total_rows(total_rows),
          gemm_n(gemm_n),
          gemm_k(gemm_k),
          problem_count(problem_count),
          workspace(workspace),
          tile_count(tile_count) {}
  };

  Params const& params;
  int32_t tile_idx;            // global tile currently assigned to this block
  int32_t problem_tile_start;  // first global tile of the current problem
  int32_t problem_idx;         // index of the current problem

  //
  // Methods
  //
  CUTLASS_DEVICE
  BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
      : params(params_),
        tile_idx(block_idx),
        problem_tile_start(0),
        problem_idx(0) {}

  /// Get the grid shape (ceil-division of the problem extent by the
  /// threadblock tile; k is always 1 here).
  CUTLASS_HOST_DEVICE
  static cutlass::gemm::GemmCoord grid_shape(
      const cutlass::gemm::GemmCoord& problem) {
    return cutlass::gemm::GemmCoord(
        ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
        ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN),
        1);
  }

  /// Gets the global tile index
  CUTLASS_HOST_DEVICE
  int32_t tile_index() const { return tile_idx; }

  /// Gets the index of the problem
  CUTLASS_HOST_DEVICE
  int32_t problem_index() const { return problem_idx; }

  /// Tile index relative to the start of the current problem.
  CUTLASS_HOST_DEVICE
  int32_t threadblock_idx() const { return tile_idx - problem_tile_start; }

  // Move on by one full wave of persistent threadblocks.
  CUTLASS_DEVICE
  void advance(int32_t grid_size) { tile_idx += grid_size; }

  CUTLASS_HOST_DEVICE
  static void possibly_transpose_problem(
      cutlass::gemm::GemmCoord& problem) {  // NOLINT
    ProblemSizeHelper::possibly_transpose_problem(problem);
  }

  /// Returns the problem size for the current problem
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size() const {
    return problem_size(problem_idx);
  }

  /// Returns the (possibly transposed) size of problem `idx`.
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size(int idx) const {

    int64_t gemm_m = 0;

    if (params.total_rows < 0) {
      // Entries are cumulative last-row indices: M is the difference between
      // consecutive entries.
      const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
      const int64_t current_problem_row = params.last_row_for_problem[idx];
      gemm_m = current_problem_row - prev_problem_row;
    } else {
      // Entries hold each problem's row count directly.
      gemm_m = params.last_row_for_problem[idx];
    }

    GemmCoord problem(GemmCoord::Index(gemm_m),
                      GemmCoord::Index(params.gemm_n),
                      GemmCoord::Index(params.gemm_k));
    ProblemSizeHelper::possibly_transpose_problem(problem);
    return problem;
  }

  CUTLASS_HOST_DEVICE
  static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) {
    return ProblemSizeHelper::tile_count(grid);
  }

  /// Host-side helper: total tile count across all problems.
  static int32_t group_tile_count(
      const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
      int32_t problem_count) {
    int32_t total_tiles = 0;
    for (int32_t i = 0; i < problem_count; ++i) {
      auto problem = host_problem_sizes_ptr[i];
      possibly_transpose_problem(problem);
      auto grid = grid_shape(problem);
      total_tiles += tile_count(grid);
    }

    return total_tiles;
  }
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Forward declaration; specialized below per GroupScheduleMode.
template <typename ProblemSizeHelper,
          typename ThreadblockShape,
          GroupScheduleMode GroupScheduleMode_,
          int PrefetchTileCount,
          int ThreadCount>
struct MoeProblemVisitor;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// ProblemVisitor that performs all scheduling on device
//
// A warp cooperates to walk the problem list: each lane computes the tile
// count of one problem in the current group of kThreadsPerWarp problems, a
// warp-wide inclusive prefix sum turns those counts into ending-tile offsets,
// and a ballot locates the problem that owns the block's current tile.
template <typename ProblemSizeHelper,
          typename ThreadblockShape,
          int PrefetchTileCount,
          int ThreadCount>
struct MoeProblemVisitor<ProblemSizeHelper,
                         ThreadblockShape,
                         GroupScheduleMode::kDeviceOnly,
                         PrefetchTileCount,
                         ThreadCount>
    : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape> {
  using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
  using Params = typename Base::Params;
  static int const kThreadCount = ThreadCount;
  static bool const kRequiresPrecomputation = false;
  static int const kThreadsPerWarp = 32;

  // Device-only scheduling needs no shared memory.
  struct SharedStorage {};

  // Final tile of the problem loaded by this thread. Each thread will hold
  // a separate value.
  int32_t problem_ending_tile;

  SharedStorage& shared_storage;

  //
  // Methods
  //
  CUTLASS_DEVICE
  MoeProblemVisitor(Params const& params_,
                    SharedStorage& shared_storage_,  // NOLINT
                    int32_t block_idx)
      : Base(params_, block_idx),
        problem_ending_tile(0),
        shared_storage(shared_storage_) {
    // Start one group "before" problem 0 so the first next_tile() call
    // fetches the initial group of problems.
    this->problem_idx = -1 * kThreadsPerWarp;
    this->problem_tile_start = 0;
  }

  /// Advances problem_idx / problem_tile_start to the problem owning
  /// this->tile_idx; returns false once all problems are exhausted.
  /// Uses full-mask shuffles/ballots, so the whole warp must execute it.
  CUTLASS_DEVICE
  bool next_tile() {
    // Check whether the tile to compute is within the range of the current
    // problem.
    int32_t problem_tile_end = __shfl_sync(
        0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
    if (this->tile_idx < problem_tile_end) {
      return true;
    }

    // Check whether the tile to compute is within the current group of problems
    // fetched by the warp. The last tile for this group is the final tile of
    // the problem held by the final thread in the warp.
    int32_t group_tile_end =
        __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);

    // Keep the starting problem for this group in `problem_idx`. This is done
    // to reduce register pressure. The starting problem for this group is
    // simply the first problem in the group most recently fetched by the warp.
    int32_t& group_problem_start = this->problem_idx;
    group_problem_start =
        (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;

    // Keep the starting tile for this group in `problem_tile_start`. This is
    // done to reduce register pressure.
    int32_t& group_tile_start = this->problem_tile_start;

    // Each thread in the warp processes a separate problem to advance until
    // reaching a problem whose starting tile is less than tile_idx.
    while (group_tile_end <= this->tile_idx) {
      group_problem_start += kThreadsPerWarp;
      if (group_problem_start > this->params.problem_count) {
        return false;
      }

      // Since `group_tile_start` is a reference to `this->problem_tile_start`,
      // this also sets `this->problem_tile_start`. The fact that
      // `this->problem_tile_start` is also set here is used later in
      // `next_tile`.
      group_tile_start = group_tile_end;

      int lane_idx = threadIdx.x % kThreadsPerWarp;
      int32_t lane_problem = group_problem_start + lane_idx;

      // Compute the number of tiles in the problem assigned to each thread.
      // Lanes past the last problem contribute zero tiles.
      problem_ending_tile = 0;
      if (lane_problem < this->params.problem_count) {
        cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
        cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
        problem_ending_tile = this->tile_count(grid);
      }

      // Compute a warp-wide inclusive prefix sum to compute the ending tile
      // index of each thread's problem.
      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < kThreadsPerWarp; i <<= 1) {
        int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
        if (lane_idx >= i) {
          problem_ending_tile += val;
        }
      }

      // The total tile count for this group is now in the final position of
      // the prefix sum.
      int32_t tiles_in_group =
          __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);

      // Shift the per-lane ending tiles from group-relative to global indices.
      problem_ending_tile += group_tile_start;
      group_tile_end += tiles_in_group;
    }

    // The next problem to process is the first one whose ending tile position
    // is greater than the tile index (popcount of the lanes that ended at or
    // before it).
    int32_t problem_idx_in_group = __popc(
        __ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));

    this->problem_idx = group_problem_start + problem_idx_in_group;

    // The starting tile for this problem is the ending tile of the previous
    // problem. In cases where `problem_idx_in_group` is the first problem in
    // the group, we do not need to reset `problem_tile_start`, because it is
    // set to the previous group's ending tile in the while loop above.
    if (problem_idx_in_group > 0) {
      this->problem_tile_start = __shfl_sync(
          0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
    }

    return true;
  }

  // Device-only scheduling requires no workspace.
  static size_t get_workspace_size(
      const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
      int32_t problem_count,
      int32_t block_count) {
    return 0;
  }

  // Nothing to precompute on the host for device-only scheduling.
  static void host_precompute(
      const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
      int32_t problem_count,
      int32_t block_count,
      void* host_workspace_ptr) {}
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,494 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
|
||||
bool Transposed = false>
|
||||
struct SplitkGemmGrouped
|
||||
{
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
|
||||
static bool const kTransposed = Transposed;
|
||||
|
||||
// Optional transpose
|
||||
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
|
||||
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
|
||||
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
|
||||
kTransposed>;
|
||||
|
||||
// Public-facing type definitions related to operand element type, layout, and complex conjugate
|
||||
// operation. Must interact with the 'kTransposed' notion.
|
||||
using ElementA = typename MapArguments::ElementA;
|
||||
using LayoutA = typename MapArguments::LayoutA;
|
||||
using ElementB = typename MapArguments::ElementB;
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
|
||||
using ElementFinalOutput = typename MapArguments::ElementA;
|
||||
|
||||
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = MapArguments::kAlignmentA;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor
|
||||
= GemmGroupedProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
    /// Host-side arguments for the split-k grouped GEMM. The per-problem
    /// arrays (pointers and leading dimensions) are indexed by problem id.
    struct Arguments
    {

        //
        // Data members
        //

        // NOTE(review): the default constructor below leaves problem_sizes
        // uninitialized — confirm callers always use the full constructor.
        GemmCoord* problem_sizes;
        int problem_count;
        int threadblock_count;

        typename EpilogueOutputOp::Params output_op;

        ElementA** ptr_A;
        ElementB** ptr_B;
        ElementFinalOutput** ptr_C;
        ElementFinalOutput** ptr_D;

        // Per-problem leading dimensions.
        typename LayoutA::Stride::LongIndex* lda;
        typename LayoutB::Stride::LongIndex* ldb;
        typename LayoutC::Stride::LongIndex* ldc;
        typename LayoutC::Stride::LongIndex* ldd;

        // Only used by device-level operator
        GemmCoord* host_problem_sizes;

        // splitK
        int split_k_slices;              // number of K slices per problem
        int64_t* splitk_buffer_offsets;  // per-problem offsets into the split-k workspace

        //
        // Methods
        //

        /// Default ctor
        CUTLASS_HOST_DEVICE
        Arguments()
            : problem_count(0)
            , threadblock_count(0)
            , ptr_A(nullptr)
            , ptr_B(nullptr)
            , ptr_C(nullptr)
            , ptr_D(nullptr)
            , lda(nullptr)
            , ldb(nullptr)
            , ldc(nullptr)
            , ldd(nullptr)
            , host_problem_sizes(nullptr)
            , split_k_slices(1)
            , splitk_buffer_offsets(nullptr)
        {
        }

        /// Ctor
        CUTLASS_HOST_DEVICE
        Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count,
            typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C,
            ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda,
            typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc,
            typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices,
            int64_t* splitk_buffer_offsets)
            : problem_sizes(problem_sizes)
            , problem_count(problem_count)
            , threadblock_count(threadblock_count)
            , output_op(output_op)
            , ptr_A(ptr_A)
            , ptr_B(ptr_B)
            , ptr_C(ptr_C)
            , ptr_D(ptr_D)
            , lda(lda)
            , ldb(ldb)
            , ldc(ldc)
            , ldd(ldd)
            , host_problem_sizes(host_problem_sizes)
            , split_k_slices(split_k_slices)
            , splitk_buffer_offsets(splitk_buffer_offsets)
        {
        }
    };
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
|
||||
typename ProblemVisitor::Params problem_visitor;
|
||||
int threadblock_count;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA** ptr_A;
|
||||
ElementB** ptr_B;
|
||||
ElementFinalOutput** ptr_C;
|
||||
ElementFinalOutput** ptr_D;
|
||||
ElementC* ptr_C_split;
|
||||
ElementC* ptr_D_split;
|
||||
|
||||
typename LayoutA::Stride::LongIndex* lda;
|
||||
typename LayoutB::Stride::LongIndex* ldb;
|
||||
typename LayoutC::Stride::LongIndex* ldc;
|
||||
typename LayoutC::Stride::LongIndex* ldd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// splitk
|
||||
GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
int gemm_k_size;
|
||||
GemmCoord* host_problem_sizes;
|
||||
int split_k_slices;
|
||||
int64_t* splitk_buffer_offsets;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, ptr_C_split(nullptr)
|
||||
, ptr_D_split(nullptr)
|
||||
, lda(nullptr)
|
||||
, ldb(nullptr)
|
||||
, ldc(nullptr)
|
||||
, ldd(nullptr)
|
||||
, swizzle_log_tile(0)
|
||||
, gemm_k_size(0)
|
||||
, host_problem_sizes(nullptr)
|
||||
, split_k_slices(1)
|
||||
, splitk_buffer_offsets(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
: problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count)
|
||||
, host_problem_sizes(args.host_problem_sizes)
|
||||
, threadblock_count(args.threadblock_count)
|
||||
, output_op(args.output_op)
|
||||
, ptr_A(args.ptr_A)
|
||||
, ptr_B(args.ptr_B)
|
||||
, ptr_C(args.ptr_C)
|
||||
, ptr_D(args.ptr_D)
|
||||
, ptr_C_split((ElementC*) workspace)
|
||||
, ptr_D_split((ElementC*) workspace)
|
||||
, lda(args.lda)
|
||||
, ldb(args.ldb)
|
||||
, ldc(args.ldc)
|
||||
, ldd(args.ldd)
|
||||
, split_k_slices(args.split_k_slices)
|
||||
, splitk_buffer_offsets(args.splitk_buffer_offsets)
|
||||
{
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0],
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices);
|
||||
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
|
||||
|
||||
// only support same k
|
||||
int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK;
|
||||
int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();
|
||||
|
||||
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
{
|
||||
|
||||
problem_visitor =
|
||||
typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
|
||||
threadblock_count = args.threadblock_count;
|
||||
output_op = args.output_op;
|
||||
ptr_A = args.ptr_A;
|
||||
ptr_B = args.ptr_B;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
ptr_C_split = workspace;
|
||||
ptr_D_split = workspace;
|
||||
|
||||
lda = args.lda;
|
||||
ldb = args.ldb;
|
||||
ldc = args.ldc;
|
||||
ldd = args.ldd;
|
||||
}
|
||||
};
|
||||
|
||||
    /// Shared memory storage structure
    struct SharedStorage
    {
        // Mainloop and epilogue shared memory are placed in a union: the two
        // phases are separated by a __syncthreads() in operator(), so the
        // storage may alias.
        union
        {
            typename Mma::SharedStorage main_loop;
            typename Epilogue::SharedStorage epilogue;
        } kernel;

        // ProblemVisitor shared storage can't be overlapped with others
        typename ProblemVisitor::SharedStorage problem_visitor;
    };
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
    // Trivial constructor: all state arrives via Params/SharedStorage.
    CUTLASS_DEVICE
    SplitkGemmGrouped() {}

    /// Determines whether kernel satisfies alignment
    // NOTE(review): both overloads unconditionally report success — no
    // alignment or size validation is actually performed here.
    static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
    {
        return Status::kSuccess;
    }

    static Status can_implement(Arguments const& args)
    {
        return Status::kSuccess;
    }
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
|
||||
//
|
||||
// These types shadow the type-level definitions and support the ability to implement
|
||||
// a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t problem_idx = problem_visitor.problem_index();
|
||||
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on the transpose
|
||||
ElementA* ptr_A
|
||||
= reinterpret_cast<ElementA*>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
|
||||
typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);
|
||||
|
||||
ElementB* ptr_B
|
||||
= reinterpret_cast<ElementB*>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
|
||||
typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
|
||||
int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0);
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k;
|
||||
if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k())
|
||||
{
|
||||
problem_size_k = problem_size.k();
|
||||
}
|
||||
else
|
||||
{
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Matrix multiply phase
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Wait for all threads to finish their epilogue phases from the previous tile.
|
||||
__syncthreads();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
ElementC* ptr_C = params.ptr_C_split;
|
||||
ElementC* ptr_D = params.ptr_D_split;
|
||||
|
||||
LayoutC layout_C(params.ldc[problem_idx]);
|
||||
LayoutC layout_D(params.ldd[problem_idx]);
|
||||
|
||||
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
|
||||
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n());
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C);
|
||||
|
||||
iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
|
||||
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C);
|
||||
iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
|
||||
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
|
||||
|
||||
Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
Reference in New Issue
Block a user