mirror of https://github.com/PaddlePaddle/FastDeploy.git
@@ -1,6 +1,6 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -19,18 +19,15 @@
 #include "cutlass_extensions/arch/mma.h"
 #include "cutlass_extensions/interleaved_numeric_conversion.h"

-namespace cutlass
-{
-namespace gemm
-{
-namespace threadblock
-{
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
 ////////////////////////////////////////////////////////////////////////////////

-// We need to distinguish here, since we want volta support. It is too much effort
-// to write shared memory iterators that are probably needed for volta to function
-// properly. As a result, we allow converters both after the LDG (for volta) and after
-// the LDS for Turing+.
+// We need to distinguish here, since we want volta support. It is too much
+// effort to write shared memory iterators that are probably needed for volta to
+// function properly. As a result, we allow converters both after the LDG (for
+// volta) and after the LDS for Turing+.
 template <
     /// Iterator for B matrix in global memory
     typename IteratorB,
@@ -38,9 +35,7 @@ template <
     typename MmaOperator,
     /// Math operation perform by warp level operator
     typename MathOperator>
-struct SetConverters
-{
-};
+struct SetConverters {};

 // Dequantize after LDG, so set transforms accordingly
 template <
@@ -48,14 +43,16 @@ template <
     typename IteratorB,
     /// Mma Policy
     typename MmaOperator>
-struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd>
-{
-    using TransformAfterLDG
-        = FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
-            typename IteratorB::Element, IteratorB::Fragment::kElements>;
+struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> {
+  using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter<
+      typename MmaOperator::ArchMmaOperator::ElementB,
+      typename IteratorB::Element,
+      IteratorB::Fragment::kElements>;

-    using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
-        typename MmaOperator::ArchMmaOperator::ElementB, MmaOperator::FragmentB::kElements>;
+  using TransformAfterLDS =
+      NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
+                            typename MmaOperator::ArchMmaOperator::ElementB,
+                            MmaOperator::FragmentB::kElements>;
 };

 // Dequantize after LDS, so set transforms accordingly
@@ -65,14 +62,18 @@ template <
     typename IteratorB,
     /// Mma Policy
     typename MmaOperator>
-struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA>
-{
-    using TransformAfterLDG = NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element,
-        IteratorB::Fragment::kElements>;
+struct SetConverters<IteratorB,
+                     MmaOperator,
+                     arch::OpMultiplyAddDequantizeInterleavedBToA> {
+  using TransformAfterLDG =
+      NumericArrayConverter<typename IteratorB::Element,
+                            typename IteratorB::Element,
+                            IteratorB::Fragment::kElements>;

-    using TransformAfterLDS
-        = FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
-            typename TransformAfterLDG::result_type::Element, MmaOperator::FragmentB::kElements>;
+  using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<
+      typename MmaOperator::ArchMmaOperator::ElementB,
+      typename TransformAfterLDG::result_type::Element,
+      MmaOperator::FragmentB::kElements>;
 };

 ////////////////////////////////////////////////////////////////////////////////
@@ -120,6 +121,6 @@ template <
     typename Enable = void>
 struct DqMma;

-} // namespace threadblock
-} // namespace gemm
-} // namespace cutlass
+}  // namespace threadblock
+}  // namespace gemm
+}  // namespace cutlass
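Note: SetConverters is compile-time tag dispatch. The primary template is empty, and each partial specialization, keyed on the warp-level math-operator tag, decides where dequantization happens: after the global load (LDG) for Volta, or after the shared-memory load (LDS) for Turing and newer. A minimal standalone sketch of the pattern, using hypothetical stand-in types rather than the real CUTLASS converters:

    #include <type_traits>

    // Hypothetical operator tags, standing in for the cutlass::arch tags.
    struct OpMultiplyAdd {};                           // dequantize after LDG
    struct OpMultiplyAddDequantizeInterleavedBToA {};  // dequantize after LDS

    struct PassThroughConverter {};  // stand-in for NumericArrayConverter
    struct DequantConverter {};      // stand-in for the interleaved converter

    template <typename MathOperator>
    struct SetConvertersSketch {};  // primary template defines nothing

    template <>
    struct SetConvertersSketch<OpMultiplyAdd> {
        using TransformAfterLDG = DequantConverter;      // convert on the way into smem
        using TransformAfterLDS = PassThroughConverter;  // smem already holds fp values
    };

    template <>
    struct SetConvertersSketch<OpMultiplyAddDequantizeInterleavedBToA> {
        using TransformAfterLDG = PassThroughConverter;  // keep B quantized in smem
        using TransformAfterLDS = DequantConverter;      // convert right before the MMA
    };

    static_assert(std::is_same_v<
        SetConvertersSketch<OpMultiplyAdd>::TransformAfterLDG, DequantConverter>);

An unmatched operator tag simply leaves the nested aliases undefined, so a misconfigured instantiation fails at compile time rather than at run time.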
+314 -152
@@ -1,6 +1,6 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -27,49 +27,77 @@
 #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
 #include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"

-namespace cutlass
-{
-namespace gemm
-{
-namespace threadblock
-{
+namespace cutlass {
+namespace gemm {
+namespace threadblock {

 ////////////////////////////////////////////////////////////////////////////////

-template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
-    typename Enable = void>
+template <typename MmaShape,
+          typename Element,
+          typename Layout,
+          WeightOnlyQuantOp QuantOp,
+          int Alignment,
+          typename Enable = void>
 struct DefaultScaleIteratorsMultistage;

 // Fine grained iterators
-template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
-struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
-    std::enable_if_t<isFinegrained(QuantOp)>>
-{
-    using IteratorScale
-        = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
-            Layout, 0, Alignment>;
+template <typename MmaShape,
+          typename Element,
+          typename Layout,
+          WeightOnlyQuantOp QuantOp,
+          int Alignment>
+struct DefaultScaleIteratorsMultistage<
+    MmaShape,
+    Element,
+    Layout,
+    QuantOp,
+    Alignment,
+    std::enable_if_t<isFinegrained(QuantOp)>> {
+  using IteratorScale =
+      cutlass::transform::threadblock::FineGrainedScaleZeroIterator<
+          cutlass::MatrixShape<1, MmaShape::kN>,
+          Element,
+          Layout,
+          0,
+          Alignment>;

-    using SmemIteratorScale = IteratorScale;
+  using SmemIteratorScale = IteratorScale;
 };

 // Per column iterators
-template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
-struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
-    std::enable_if_t<!isFinegrained(QuantOp)>>
-{
-    // ThreadMap for scale iterator
-    static_assert((MmaShape::kN % Alignment) == 0, "");
+template <typename MmaShape,
+          typename Element,
+          typename Layout,
+          WeightOnlyQuantOp QuantOp,
+          int Alignment>
+struct DefaultScaleIteratorsMultistage<
+    MmaShape,
+    Element,
+    Layout,
+    QuantOp,
+    Alignment,
+    std::enable_if_t<!isFinegrained(QuantOp)>> {
+  // ThreadMap for scale iterator
+  static_assert((MmaShape::kN % Alignment) == 0, "");

-private:
-    using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
-        MmaShape::kN / Alignment, Alignment>;
+ private:
+  using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
+      layout::PitchLinearShape<MmaShape::kN, 1>,
+      MmaShape::kN / Alignment,
+      Alignment>;

-public:
-    // Define iterators over tiles from the scale operand
-    using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
-        Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
+ public:
+  // Define iterators over tiles from the scale operand
+  using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
+      cutlass::MatrixShape<1, MmaShape::kN>,
+      Element,
+      Layout,
+      0,
+      IteratorScaleThreadMap,
+      Alignment>;

-    using SmemIteratorScale = IteratorScale;
+  using SmemIteratorScale = IteratorScale;
 };

 ////////////////////////////////////////////////////////////////////////////////
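Note: both specializations iterate scales as a 1 x N row (MatrixShape<1, MmaShape::kN>); the per-column case loads that row once per tile, while the fine-grained case advances it every quantization group along K. Numerically, a per-column scale means every element in column n of the quantized B tile is multiplied by scale[n]. A hypothetical host-side reference of that dequantization (the device converters do the same math in registers; the +128 bias assumes the unsigned-with-bias packing these weight-only kernels use):

    #include <cstdint>
    #include <vector>

    // Hypothetical reference: dequantize a K x N uint8 weight tile with one
    // scale per column (PER_COLUMN_SCALE_ONLY).
    std::vector<float> dequantize_per_column(const std::vector<std::uint8_t>& b_quant,
                                             const std::vector<float>& scale,  // size N
                                             int K, int N) {
        std::vector<float> b(K * N);
        for (int k = 0; k < K; ++k) {
            for (int n = 0; n < N; ++n) {
                // Remove the unsigned-storage bias, then apply the column scale;
                // the in-register converters perform the same math on fragments.
                int q = static_cast<int>(b_quant[k * N + n]) - 128;
                b[k * N + n] = q * scale[n];
            }
        }
        return b;
    }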
@@ -111,69 +139,133 @@ template <
     typename Operator_,
     /// Use zfill or predicate for out-of-bound cp.async
     SharedMemoryClearOption SharedMemoryClear>
-struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
-    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
-    kStages, Operator_, SharedMemoryClear,
-    typename platform::enable_if<(
-        ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
-{
-    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
-            || platform::is_same<ElementA, float_e4m3_t>::value,
-        "Element A must be fp16, fp8 or bf16");
+struct DqMma<ElementA,
+             LayoutA,
+             kAlignmentA,
+             ElementB,
+             LayoutB,
+             kAlignmentB,
+             ElementScale,
+             LayoutScale,
+             kAlignmentScale,
+             ElementAccumulator,
+             layout::RowMajor,
+             OperatorClass,
+             ArchTag,
+             ThreadblockShape,
+             WarpShape,
+             InstructionShape,
+             kStages,
+             Operator_,
+             SharedMemoryClear,
+             typename platform::enable_if<(
+                 ArchTag::kMinComputeCapability >= 80 &&
+                 !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type> {
+  static_assert(platform::is_same<ElementA, half_t>::value ||
+                    platform::is_same<ElementA, bfloat16_t>::value ||
+                    platform::is_same<ElementA, float_e4m3_t>::value,
+                "Element A must be fp16, fp8 or bf16");

-    using OperatorInfo = arch::DetagOperator<Operator_>;
-    using Operator = typename OperatorInfo::Operator;
-    static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
-        "Mma multistage must dequantize after ldsm");
+  using OperatorInfo = arch::DetagOperator<Operator_>;
+  using Operator = typename OperatorInfo::Operator;
+  static_assert(
+      platform::is_same<Operator,
+                        arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
+      "Mma multistage must dequantize after ldsm");

-    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
-        "Element B must be uint8 or uint4");
+  static_assert(platform::is_same<ElementB, uint8_t>::value ||
+                    platform::is_same<ElementB, uint4b_t>::value,
+                "Element B must be uint8 or uint4");

-    static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
-        ? cutlass::arch::CacheOperation::Global
-        : cutlass::arch::CacheOperation::Always;
+  static cutlass::arch::CacheOperation::Kind const CacheOpA =
+      ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
+          ? cutlass::arch::CacheOperation::Global
+          : cutlass::arch::CacheOperation::Always;

-    static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
-        ? cutlass::arch::CacheOperation::Global
-        : cutlass::arch::CacheOperation::Always;
+  static cutlass::arch::CacheOperation::Kind const CacheOpB =
+      ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
+          ? cutlass::arch::CacheOperation::Global
+          : cutlass::arch::CacheOperation::Always;

-    // Define the MmaCore components
-    // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
-    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
-        ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
-        Operator, false, CacheOpA, CacheOpB>;
+  // Define the MmaCore components
+  // Mma core does not depend on stages, so pass in at least 3 here to mma
+  // multistage pieces are created
+  using MmaCore =
+      typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
+                                                          WarpShape,
+                                                          InstructionShape,
+                                                          ElementA,
+                                                          LayoutA,
+                                                          ElementB,
+                                                          LayoutB,
+                                                          ElementAccumulator,
+                                                          layout::RowMajor,
+                                                          OperatorClass,
+                                                          std::max(kStages, 3),
+                                                          Operator,
+                                                          false,
+                                                          CacheOpA,
+                                                          CacheOpB>;

-    // Define iterators over tiles from the A operand
-    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
-    using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
-    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
-        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
-        AccessTypeA>;
+  // Define iterators over tiles from the A operand
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
+  using IteratorA =
+      cutlass::transform::threadblock::PredicatedTileAccessIterator<
+          cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+          ElementA,
+          LayoutA,
+          1,
+          ThreadMapA,
+          AccessTypeA>;

-    // Define iterators over tiles from the B operand
-    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
-    using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
-    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
-        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
-        AccessTypeB>;
+  // Define iterators over tiles from the B operand
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
+  using IteratorB =
+      cutlass::transform::threadblock::PredicatedTileAccessIterator<
+          cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
+          ElementB,
+          LayoutB,
+          0,
+          ThreadMapB,
+          AccessTypeB>;

-    using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
-        OperatorInfo::QuantOp, kAlignmentScale>;
+  using ScaleIterators =
+      DefaultScaleIteratorsMultistage<typename MmaCore::Shape,
+                                      ElementScale,
+                                      LayoutScale,
+                                      OperatorInfo::QuantOp,
+                                      kAlignmentScale>;

-    // Define iterators over tiles from the scale operand
-    using IteratorScale = typename ScaleIterators::IteratorScale;
+  // Define iterators over tiles from the scale operand
+  using IteratorScale = typename ScaleIterators::IteratorScale;

-    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
+  using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

-    using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
-        MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
+  using Converter = FastInterleavedAndBiasedNumericArrayConverter<
+      ElementScale,
+      ElementB,
+      MmaCore::MmaPolicy::Operator::FragmentB::kElements>;

-    // Define the threadblock-scoped pipelined matrix multiply
-    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
-        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
-        MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
-        typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
+  // Define the threadblock-scoped pipelined matrix multiply
+  using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<
+      typename MmaCore::Shape,
+      IteratorA,
+      typename MmaCore::SmemIteratorA,
+      MmaCore::kCacheOpA,
+      IteratorB,
+      typename MmaCore::SmemIteratorB,
+      MmaCore::kCacheOpB,
+      IteratorScale,
+      SmemIteratorScale,
+      ElementAccumulator,
+      layout::RowMajor,
+      typename MmaCore::MmaPolicy,
+      kStages,
+      Converter,
+      OperatorInfo::QuantOp,
+      SharedMemoryClear>;
 };
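Note: CacheOpA/CacheOpB pick cp.async's cache hint from the access width: only full 128-bit (16-byte) per-thread global loads use the cache-global variant; narrower accesses fall back to cache-always. A small constexpr sketch of the same rule, with a hypothetical helper (not this file's code):

    enum class CacheOperation { Always, Global };

    // Mirror of the selection rule: sizeof_bits<Element>::value * kAlignment
    // must equal 128 bits (one 16-byte cp.async transaction) for cache-global.
    template <int ElementBits, int Alignment>
    constexpr CacheOperation pick_cache_op() {
        return (ElementBits * Alignment == 128) ? CacheOperation::Global
                                                : CacheOperation::Always;
    }

    // half_t (16 bits) at alignment 8 -> 128-bit accesses -> Global.
    static_assert(pick_cache_op<16, 8>() == CacheOperation::Global);
    // uint4b_t (4 bits) at alignment 16 -> only 64-bit accesses -> Always.
    static_assert(pick_cache_op<4, 16>() == CacheOperation::Always);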
 // Specialization to handle column major interleave B
@@ -214,89 +306,159 @@ template <
     typename Operator_,
     /// Use zfill or predicate for out-of-bound cp.async
     SharedMemoryClearOption SharedMemoryClear>
-struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
-    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
-    kStages, Operator_, SharedMemoryClear,
-    typename platform::enable_if<(
-        ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
-{
-    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
-            || platform::is_same<ElementA, float_e4m3_t>::value,
-        "Element A must be fp16, fp8 or bf16");
+struct DqMma<ElementA,
+             LayoutA,
+             kAlignmentA,
+             ElementB,
+             LayoutB,
+             kAlignmentB,
+             ElementScale,
+             LayoutScale,
+             kAlignmentScale,
+             ElementAccumulator,
+             layout::RowMajor,
+             OperatorClass,
+             ArchTag,
+             ThreadblockShape,
+             WarpShape,
+             InstructionShape,
+             kStages,
+             Operator_,
+             SharedMemoryClear,
+             typename platform::enable_if<(
+                 ArchTag::kMinComputeCapability >= 80 &&
+                 layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type> {
+  static_assert(platform::is_same<ElementA, half_t>::value ||
+                    platform::is_same<ElementA, bfloat16_t>::value ||
+                    platform::is_same<ElementA, float_e4m3_t>::value,
+                "Element A must be fp16, fp8 or bf16");

-    using OperatorInfo = arch::DetagOperator<Operator_>;
-    using Operator = typename OperatorInfo::Operator;
-    static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
-        "Mma multistage must dequantize after ldsm");
+  using OperatorInfo = arch::DetagOperator<Operator_>;
+  using Operator = typename OperatorInfo::Operator;
+  static_assert(
+      platform::is_same<Operator,
+                        arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
+      "Mma multistage must dequantize after ldsm");

-    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
-        "Element B must be uint8 or uint4");
+  static_assert(platform::is_same<ElementB, uint8_t>::value ||
+                    platform::is_same<ElementB, uint4b_t>::value,
+                "Element B must be uint8 or uint4");

-    static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
-        ? cutlass::arch::CacheOperation::Global
-        : cutlass::arch::CacheOperation::Always;
+  static cutlass::arch::CacheOperation::Kind const CacheOpA =
+      ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
+          ? cutlass::arch::CacheOperation::Global
+          : cutlass::arch::CacheOperation::Always;

-    static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
-        ? cutlass::arch::CacheOperation::Global
-        : cutlass::arch::CacheOperation::Always;
+  static cutlass::arch::CacheOperation::Kind const CacheOpB =
+      ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
+          ? cutlass::arch::CacheOperation::Global
+          : cutlass::arch::CacheOperation::Always;

-    // Define the MmaCore components
-    // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
-    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
-        ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
-        std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
+  // Define the MmaCore components
+  // Mma core does not depend on stages, so pass in at least 3 here to mma
+  // multistage pieces are created
+  using MmaCore =
+      typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
+                                                          WarpShape,
+                                                          InstructionShape,
+                                                          ElementA,
+                                                          LayoutA,
+                                                          ElementB,
+                                                          layout::ColumnMajor,
+                                                          ElementAccumulator,
+                                                          layout::RowMajor,
+                                                          OperatorClass,
+                                                          std::max(kStages, 3),
+                                                          Operator,
+                                                          false,
+                                                          CacheOpA,
+                                                          CacheOpB>;

-    // Define iterators over tiles from the A operand
-    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
-    using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
-    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
-        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
-        AccessTypeA>;
+  // Define iterators over tiles from the A operand
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
+  using IteratorA =
+      cutlass::transform::threadblock::PredicatedTileAccessIterator<
+          cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+          ElementA,
+          LayoutA,
+          1,
+          ThreadMapA,
+          AccessTypeA>;

-private:
-    static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
-    static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
-    static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
-    static_assert(RowsPerTile == MmaCore::Shape::kK, "");
+ private:
+  static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
+  static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
+  static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
+  static_assert(RowsPerTile == MmaCore::Shape::kK, "");

-    using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
-    using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
-    static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
+  using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
+  using OriginalWarpArrangement =
+      typename OriginalThreadMap::Detail::WarpThreadArrangement;
+  static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");

-    using GmemIteratorShape
-        = MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
-    using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
-        layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
-        layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
-            OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
-        MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
+  using GmemIteratorShape =
+      MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved,
+                  MmaCore::Shape::kN / ColumnsInterleaved>;
+  using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
+      layout::PitchLinearShape<GmemIteratorShape::kRow,
+                               GmemIteratorShape::kColumn>,
+      OriginalThreadMap::kThreads,
+      layout::PitchLinearShape<
+          OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
+          OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
+      MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;

-public:
-    // Define iterators over tiles from the B operand
-    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
-    using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
-    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
-        layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
+ public:
+  // Define iterators over tiles from the B operand
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
+  using IteratorB =
+      cutlass::transform::threadblock::PredicatedTileAccessIterator<
+          GmemIteratorShape,
+          ElementB,
+          layout::ColumnMajor,
+          0,
+          GmemThreadMapB,
+          AccessTypeB>;

-    using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
-        OperatorInfo::QuantOp, kAlignmentScale>;
+  using ScaleIterators =
+      DefaultScaleIteratorsMultistage<typename MmaCore::Shape,
+                                      ElementScale,
+                                      LayoutScale,
+                                      OperatorInfo::QuantOp,
+                                      kAlignmentScale>;

-    // Define iterators over tiles from the scale operand
-    using IteratorScale = typename ScaleIterators::IteratorScale;
+  // Define iterators over tiles from the scale operand
+  using IteratorScale = typename ScaleIterators::IteratorScale;

-    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
+  using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

-    using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
-        MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
+  using Converter = FastInterleavedAndBiasedNumericArrayConverter<
+      ElementScale,
+      ElementB,
+      MmaCore::MmaPolicy::Operator::FragmentB::kElements>;

-    // Define the threadblock-scoped pipelined matrix multiply
-    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
-        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
-        MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
-        typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
+  // Define the threadblock-scoped pipelined matrix multiply
+  using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<
+      typename MmaCore::Shape,
+      IteratorA,
+      typename MmaCore::SmemIteratorA,
+      MmaCore::kCacheOpA,
+      IteratorB,
+      typename MmaCore::SmemIteratorB,
+      MmaCore::kCacheOpB,
+      IteratorScale,
+      SmemIteratorScale,
+      ElementAccumulator,
+      layout::RowMajor,
+      typename MmaCore::MmaPolicy,
+      kStages,
+      Converter,
+      OperatorInfo::QuantOp,
+      SharedMemoryClear>;
 };

-} // namespace threadblock
-} // namespace gemm
-} // namespace cutlass
+}  // namespace threadblock
+}  // namespace gemm
+}  // namespace cutlass
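Note: for the tile-interleaved B layout, the global-memory iterator walks a reshaped tile: rows grow by ColumnsInterleaved and columns shrink by the same factor (GmemIteratorShape above), and the warp arrangement is rescaled the same way so each thread still issues full-width contiguous accesses. The element count is unchanged; only the walk order differs. A quick constexpr check of that reshape, with hypothetical names:

    // Reshape rule used for the interleaved-B gmem iterator:
    //   rows = kK * ColumnsInterleaved, columns = kN / ColumnsInterleaved.
    template <int K, int N, int ColumnsInterleaved>
    struct GmemShapeSketch {
        static_assert(N % ColumnsInterleaved == 0, "N must divide evenly");
        static constexpr int kRow = K * ColumnsInterleaved;
        static constexpr int kColumn = N / ColumnsInterleaved;
    };

    // Example: a 64 (K) x 128 (N) threadblock tile of B with 2-way column
    // interleave is traversed as a 128 x 64 tile; the area is preserved.
    using Shape = GmemShapeSketch<64, 128, 2>;
    static_assert(Shape::kRow == 128 && Shape::kColumn == 64);
    static_assert(Shape::kRow * Shape::kColumn == 64 * 128);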
+290 -142
@@ -1,6 +1,6 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -27,58 +27,95 @@
 #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
 #include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"

-namespace cutlass
-{
-namespace gemm
-{
-namespace threadblock
-{
+namespace cutlass {
+namespace gemm {
+namespace threadblock {

 ////////////////////////////////////////////////////////////////////////////////

-template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
-    typename Enable = void>
+template <typename MmaShape,
+          typename Element,
+          typename Layout,
+          WeightOnlyQuantOp QuantOp,
+          int Alignment,
+          typename Enable = void>
 struct DefaultScaleIteratorsPipelined;

 // Fine grained iterators
-template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
-struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
-    std::enable_if_t<isFinegrained(QuantOp)>>
-{
-private:
-    using SmemScaleType = half_t;
+template <typename MmaShape,
+          typename Element,
+          typename Layout,
+          WeightOnlyQuantOp QuantOp,
+          int Alignment>
+struct DefaultScaleIteratorsPipelined<
+    MmaShape,
+    Element,
+    Layout,
+    QuantOp,
+    Alignment,
+    std::enable_if_t<isFinegrained(QuantOp)>> {
+ private:
+  using SmemScaleType = half_t;

-public:
-    using IteratorScale
-        = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
-            Layout, 0, Alignment>;
+ public:
+  using IteratorScale =
+      cutlass::transform::threadblock::FineGrainedScaleZeroIterator<
+          cutlass::MatrixShape<1, MmaShape::kN>,
+          Element,
+          Layout,
+          0,
+          Alignment>;

-    using SmemIteratorScale
-        = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
-            SmemScaleType, Layout, 0, Alignment>;
+  using SmemIteratorScale =
+      cutlass::transform::threadblock::FineGrainedScaleZeroIterator<
+          cutlass::MatrixShape<1, MmaShape::kN>,
+          SmemScaleType,
+          Layout,
+          0,
+          Alignment>;
 };

 // Per column iterators
-template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
-struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
-    std::enable_if_t<!isFinegrained(QuantOp)>>
-{
-    static_assert((MmaShape::kN % Alignment) == 0, "");
+template <typename MmaShape,
+          typename Element,
+          typename Layout,
+          WeightOnlyQuantOp QuantOp,
+          int Alignment>
+struct DefaultScaleIteratorsPipelined<
+    MmaShape,
+    Element,
+    Layout,
+    QuantOp,
+    Alignment,
+    std::enable_if_t<!isFinegrained(QuantOp)>> {
+  static_assert((MmaShape::kN % Alignment) == 0, "");

-private:
-    // ThreadMap for scale iterator
-    using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
-        MmaShape::kN / Alignment, Alignment>;
-    using SmemScaleType = half_t;
+ private:
+  // ThreadMap for scale iterator
+  using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
+      layout::PitchLinearShape<MmaShape::kN, 1>,
+      MmaShape::kN / Alignment,
+      Alignment>;
+  using SmemScaleType = half_t;

-public:
-    // Define iterators over tiles from the scale operand
-    using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
-        Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
+ public:
+  // Define iterators over tiles from the scale operand
+  using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
+      cutlass::MatrixShape<1, MmaShape::kN>,
+      Element,
+      Layout,
+      0,
+      IteratorScaleThreadMap,
+      Alignment>;

-    using SmemIteratorScale
-        = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType,
-            Layout, 0, IteratorScaleThreadMap, Alignment>;
+  using SmemIteratorScale =
+      cutlass::transform::threadblock::PredicatedTileIterator<
+          cutlass::MatrixShape<1, MmaShape::kN>,
+          SmemScaleType,
+          Layout,
+          0,
+          IteratorScaleThreadMap,
+          Alignment>;
 };

 ////////////////////////////////////////////////////////////////////////////////
@@ -116,57 +153,110 @@ template <
     typename InstructionShape,
     /// Operation performed by GEMM
     typename Operator_>
-struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
-    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
-    Operator_, SharedMemoryClearOption::kNone,
-    typename platform::enable_if<(
-        ArchTag::kMinComputeCapability < 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
-{
-    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
-        "Element A must be fp16 or bf16");
+struct DqMma<ElementA,
+             LayoutA,
+             kAlignmentA,
+             ElementB,
+             LayoutB,
+             kAlignmentB,
+             ElementScale,
+             LayoutScale,
+             kAlignmentScale,
+             ElementAccumulator,
+             layout::RowMajor,
+             OperatorClass,
+             ArchTag,
+             ThreadblockShape,
+             WarpShape,
+             InstructionShape,
+             2,
+             Operator_,
+             SharedMemoryClearOption::kNone,
+             typename platform::enable_if<(
+                 ArchTag::kMinComputeCapability < 80 &&
+                 !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type> {
+  static_assert(platform::is_same<ElementA, half_t>::value ||
+                    platform::is_same<ElementA, bfloat16_t>::value,
+                "Element A must be fp16 or bf16");

-    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
-        "Element B must be uint8 or uint4");
+  static_assert(platform::is_same<ElementB, uint8_t>::value ||
+                    platform::is_same<ElementB, uint4b_t>::value,
+                "Element B must be uint8 or uint4");

-    using OperatorInfo = arch::DetagOperator<Operator_>;
-    using Operator = typename OperatorInfo::Operator;
-    static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
+  using OperatorInfo = arch::DetagOperator<Operator_>;
+  using Operator = typename OperatorInfo::Operator;
+  static_assert(OperatorInfo::QuantOp ==
+                    WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY,
+                "");

-    static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
-    using MmaCoreElementA = half_t;
-    using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
+  static constexpr bool DqAfterLDG =
+      platform::is_same<arch::OpMultiplyAdd, Operator>::value;
+  using MmaCoreElementA = half_t;
+  using MmaCoreElementB = typename platform::
+      conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;

-    // Define the MmaCore components
-    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
-        MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2,
-        Operator>;
+  // Define the MmaCore components
+  using MmaCore =
+      typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
+                                                          WarpShape,
+                                                          InstructionShape,
+                                                          MmaCoreElementA,
+                                                          LayoutA,
+                                                          MmaCoreElementB,
+                                                          LayoutB,
+                                                          ElementAccumulator,
+                                                          layout::RowMajor,
+                                                          OperatorClass,
+                                                          2,
+                                                          Operator>;

-    // Define iterators over tiles from the A operand
-    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
-        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
-        typename MmaCore::IteratorThreadMapA, kAlignmentA>;
+  // Define iterators over tiles from the A operand
+  using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
+      cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
+      ElementA,
+      LayoutA,
+      1,
+      typename MmaCore::IteratorThreadMapA,
+      kAlignmentA>;

-    // Define iterators over tiles from the B operand
-    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
-        cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
-        typename MmaCore::IteratorThreadMapB, kAlignmentB>;
+  // Define iterators over tiles from the B operand
+  using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
+      cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
+      ElementB,
+      LayoutB,
+      0,
+      typename MmaCore::IteratorThreadMapB,
+      kAlignmentB>;

-    using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
-        OperatorInfo::QuantOp, kAlignmentScale>;
+  using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape,
+                                                        ElementScale,
+                                                        LayoutScale,
+                                                        OperatorInfo::QuantOp,
+                                                        kAlignmentScale>;

-    // Define iterators over tiles from the scale operand
-    using IteratorScale = typename ScaleIterators::IteratorScale;
+  // Define iterators over tiles from the scale operand
+  using IteratorScale = typename ScaleIterators::IteratorScale;

-    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
+  using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

-    using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
+  using Converters =
+      SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;

-    // Define the threadblock-scoped pipelined matrix multiply
-    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
-        typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
-        ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
-        typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
+  // Define the threadblock-scoped pipelined matrix multiply
+  using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<
+      typename MmaCore::Shape,
+      IteratorA,
+      typename MmaCore::SmemIteratorA,
+      IteratorB,
+      typename MmaCore::SmemIteratorB,
+      IteratorScale,
+      SmemIteratorScale,
+      ElementAccumulator,
+      layout::RowMajor,
+      typename MmaCore::MmaPolicy,
+      typename Converters::TransformAfterLDG,
+      typename Converters::TransformAfterLDS,
+      OperatorInfo::QuantOp>;
 };
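Note: unlike the multistage path, the pipelined (2-stage, pre-SM80) specializations support both converter placements. DqAfterLDG is true when the detagged operator is plain OpMultiplyAdd; B is then converted to half right after the global load, so shared memory (and the MmaCore) sees half data. Otherwise shared memory keeps the quantized ElementB and conversion happens after the LDS. A compile-time sketch of that type selection, with stand-in types:

    #include <type_traits>

    struct OpMultiplyAdd {};                           // dequantize after LDG
    struct OpMultiplyAddDequantizeInterleavedBToA {};  // dequantize after LDS
    struct HalfT {};    // stand-in for cutlass::half_t
    struct Uint4bT {};  // stand-in for cutlass::uint4b_t

    // Mirror of the selection: if dequantization happens right after the
    // global load, the MmaCore's B element is already half precision.
    template <typename Operator, typename ElementB>
    struct PickMmaCoreElementB {
        static constexpr bool DqAfterLDG =
            std::is_same_v<OpMultiplyAdd, Operator>;
        using type = std::conditional_t<DqAfterLDG, HalfT, ElementB>;
    };

    static_assert(std::is_same_v<
        PickMmaCoreElementB<OpMultiplyAdd, Uint4bT>::type, HalfT>);
    static_assert(std::is_same_v<
        PickMmaCoreElementB<OpMultiplyAddDequantizeInterleavedBToA, Uint4bT>::type,
        Uint4bT>);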
 // Specialization to handle column major interleave B
@@ -203,82 +293,140 @@ template <
     typename InstructionShape,
     /// Operation performed by GEMM
     typename Operator_>
-struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
-    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
-    Operator_, SharedMemoryClearOption::kNone,
-    typename platform::enable_if<(
-        ArchTag::kMinComputeCapability < 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
-{
-    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
-        "Element A must be fp16 or bf16");
+struct DqMma<ElementA,
+             LayoutA,
+             kAlignmentA,
+             ElementB,
+             LayoutB,
+             kAlignmentB,
+             ElementScale,
+             LayoutScale,
+             kAlignmentScale,
+             ElementAccumulator,
+             layout::RowMajor,
+             OperatorClass,
+             ArchTag,
+             ThreadblockShape,
+             WarpShape,
+             InstructionShape,
+             2,
+             Operator_,
+             SharedMemoryClearOption::kNone,
+             typename platform::enable_if<(
+                 ArchTag::kMinComputeCapability < 80 &&
+                 layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type> {
+  static_assert(platform::is_same<ElementA, half_t>::value ||
+                    platform::is_same<ElementA, bfloat16_t>::value,
+                "Element A must be fp16 or bf16");

-    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
-        "Element B must be uint8 or uint4");
+  static_assert(platform::is_same<ElementB, uint8_t>::value ||
+                    platform::is_same<ElementB, uint4b_t>::value,
+                "Element B must be uint8 or uint4");

-    using OperatorInfo = arch::DetagOperator<Operator_>;
-    using Operator = typename OperatorInfo::Operator;
+  using OperatorInfo = arch::DetagOperator<Operator_>;
+  using Operator = typename OperatorInfo::Operator;

-    static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
-    using MmaCoreElementA = half_t;
-    using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
+  static constexpr bool DqAfterLDG =
+      platform::is_same<arch::OpMultiplyAdd, Operator>::value;
+  using MmaCoreElementA = half_t;
+  using MmaCoreElementB = typename platform::
+      conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;

-    // Define the MmaCore components
-    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
-        MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor,
-        OperatorClass, 2, Operator>;
+  // Define the MmaCore components
+  using MmaCore =
+      typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
+                                                          WarpShape,
+                                                          InstructionShape,
+                                                          MmaCoreElementA,
+                                                          LayoutA,
+                                                          MmaCoreElementB,
+                                                          layout::ColumnMajor,
+                                                          ElementAccumulator,
+                                                          layout::RowMajor,
+                                                          OperatorClass,
+                                                          2,
+                                                          Operator>;

-    // Define iterators over tiles from the A operand
-    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
-        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
-        typename MmaCore::IteratorThreadMapA, kAlignmentA>;
+  // Define iterators over tiles from the A operand
+  using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
+      cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
+      ElementA,
+      LayoutA,
+      1,
+      typename MmaCore::IteratorThreadMapA,
+      kAlignmentA>;

-private:
-    static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
-    static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
-    static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
-    static_assert(RowsPerTile == MmaCore::Shape::kK, "");
+ private:
+  static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
+  static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
+  static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
+  static_assert(RowsPerTile == MmaCore::Shape::kK, "");

-    using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
-    using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
-    static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
+  using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
+  using OriginalWarpArrangement =
+      typename OriginalThreadMap::Detail::WarpThreadArrangement;
+  static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");

-    using GmemIteratorShape
-        = MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
-    using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
-        layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
-        layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
-            OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
-        MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
+  using GmemIteratorShape =
+      MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved,
+                  MmaCore::Shape::kN / ColumnsInterleaved>;
+  using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
+      layout::PitchLinearShape<GmemIteratorShape::kRow,
+                               GmemIteratorShape::kColumn>,
+      OriginalThreadMap::kThreads,
+      layout::PitchLinearShape<
+          OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
+          OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
+      MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;

-public:
-    // Define iterators over tiles from the B operand
-    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<GmemIteratorShape, ElementB,
-        layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
+ public:
+  // Define iterators over tiles from the B operand
+  using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
+      GmemIteratorShape,
+      ElementB,
+      layout::ColumnMajor,
+      0,
+      GmemThreadMapB,
+      kAlignmentB>;

-    // ThreadMap for scale iterator
-    static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
-    using IteratorScaleThreadMap
-        = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
-            MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
+  // ThreadMap for scale iterator
+  static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
+  using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
+      layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
+      MmaCore::Shape::kN / kAlignmentScale,
+      kAlignmentScale>;

-    using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
-        OperatorInfo::QuantOp, kAlignmentScale>;
+  using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape,
+                                                        ElementScale,
+                                                        LayoutScale,
+                                                        OperatorInfo::QuantOp,
+                                                        kAlignmentScale>;

-    // Define iterators over tiles from the scale operand
-    using IteratorScale = typename ScaleIterators::IteratorScale;
+  // Define iterators over tiles from the scale operand
+  using IteratorScale = typename ScaleIterators::IteratorScale;

-    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
+  using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

-    using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
+  using Converters =
+      SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;

-    // Define the threadblock-scoped pipelined matrix multiply
-    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
-        typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
-        ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
-        typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
+  // Define the threadblock-scoped pipelined matrix multiply
+  using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<
+      typename MmaCore::Shape,
+      IteratorA,
+      typename MmaCore::SmemIteratorA,
+      IteratorB,
+      typename MmaCore::SmemIteratorB,
+      IteratorScale,
+      SmemIteratorScale,
+      ElementAccumulator,
+      layout::RowMajor,
+      typename MmaCore::MmaPolicy,
+      typename Converters::TransformAfterLDG,
+      typename Converters::TransformAfterLDS,
+      OperatorInfo::QuantOp>;
 };

-} // namespace threadblock
-} // namespace gemm
-} // namespace cutlass
+}  // namespace threadblock
+}  // namespace gemm
+}  // namespace cutlass
@@ -1,6 +1,6 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -27,7 +27,8 @@ namespace threadblock {

 ////////////////////////////////////////////////////////////////////////////////

-/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
+/// Specialization for row-major output (OperatorClass TensorOp), fp16
+/// activation & int8 weight, mma pipelined (stage=2)
 template <
     /// Layout type for A matrix operand
     typename LayoutA,
@@ -49,34 +50,61 @@ template <
     typename InstructionShape,
     /// Operation performed by GEMM
     typename Operator>
-struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
-    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
-{
-private:
-    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
+struct DefaultMma<cutlass::half_t,
+                  LayoutA,
+                  kAlignmentA,
+                  uint8_t,
+                  LayoutB,
+                  kAlignmentB,
+                  ElementAccumulator,
+                  layout::RowMajor,
+                  arch::OpClassTensorOp,
+                  ArchTag,
+                  ThreadblockShape,
+                  WarpShape,
+                  InstructionShape,
+                  2,
+                  Operator> {
+ private:
+  static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

-    using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
-        kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
-        WarpShape, InstructionShape, 2, Operator>;
+  using Mma = DqMma<half_t,
+                    LayoutA,
+                    kAlignmentA,
+                    uint8_t,
+                    LayoutB,
+                    kAlignmentB,
+                    half_t,
+                    layout::RowMajor,
+                    kAlignmentScale,
+                    ElementAccumulator,
+                    layout::RowMajor,
+                    arch::OpClassTensorOp,
+                    ArchTag,
+                    ThreadblockShape,
+                    WarpShape,
+                    InstructionShape,
+                    2,
+                    Operator>;

-public:
-    // Define the MmaCore components
-    using MmaCore = typename Mma::MmaCore;
+ public:
+  // Define the MmaCore components
+  using MmaCore = typename Mma::MmaCore;

-    // Define iterators over tiles from the A operand
-    using IteratorA = typename Mma::IteratorA;
+  // Define iterators over tiles from the A operand
+  using IteratorA = typename Mma::IteratorA;

-    // Define iterators over tiles from the B operand
-    using IteratorB = typename Mma::IteratorB;
+  // Define iterators over tiles from the B operand
+  using IteratorB = typename Mma::IteratorB;

-    // Define the threadblock-scoped pipelined matrix multiply
-    using ThreadblockMma = typename Mma::ThreadblockMma;
+  // Define the threadblock-scoped pipelined matrix multiply
+  using ThreadblockMma = typename Mma::ThreadblockMma;
 };
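Note: each DefaultMma specialization here is a thin facade: it instantiates the matching DqMma with the scale type pinned to half_t and kAlignmentScale derived from a 128-bit access (128 / sizeof_bits<half_t>::value == 8), then simply re-exports the nested types the kernel builder consumes. A minimal sketch of that forwarding pattern, with hypothetical stand-in types:

    // Stand-in for the richer DqMma builder.
    template <typename ElementB>
    struct DqMmaSketch {
        struct MmaCore {};
        struct IteratorA {};
        struct IteratorB {};
        struct ThreadblockMma {};
    };

    // Facade keyed on the public DefaultMma-style signature; it only
    // republishes the nested types of the underlying builder.
    template <typename ElementB>
    struct DefaultMmaSketch {
     private:
        using Mma = DqMmaSketch<ElementB>;

     public:
        using MmaCore = typename Mma::MmaCore;
        using IteratorA = typename Mma::IteratorA;
        using IteratorB = typename Mma::IteratorB;
        using ThreadblockMma = typename Mma::ThreadblockMma;
    };

    // Usage: downstream code only ever names the facade's aliases.
    using TbMma = DefaultMmaSketch<unsigned char>::ThreadblockMma;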
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16
|
||||
/// activation & int4 weight, mma pipelined (stage=2)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
@@ -98,35 +126,61 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
struct DefaultMma<cutlass::half_t,
|
||||
LayoutA,
|
||||
kAlignmentA,
|
||||
uint4b_t,
|
||||
LayoutB,
|
||||
kAlignmentB,
|
||||
ElementAccumulator,
|
||||
layout::RowMajor,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
2,
|
||||
Operator> {
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
using Mma = DqMma<half_t,
|
||||
LayoutA,
|
||||
kAlignmentA,
|
||||
uint4b_t,
|
||||
LayoutB,
|
||||
kAlignmentB,
|
||||
half_t,
|
||||
layout::RowMajor,
|
||||
kAlignmentScale,
|
||||
ElementAccumulator,
|
||||
layout::RowMajor,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
2,
|
||||
Operator>;
|
||||
|
||||
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int8 weight, mma multistage (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -152,36 +206,64 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
half_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;

using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int4 weight, mma multistage (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -207,37 +289,65 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
half_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;

using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation
/// & int4 weight, mma multistage (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -263,36 +373,65 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
struct DefaultMma<cutlass::float_e4m3_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<cutlass::float_e4m3_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
half_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;

using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

#endif

// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
// large tile when not enough shared mem is present to do 3+ stage
// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps
// avoid reg spills on large tile when not enough shared mem is present to do 3+
// stage
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -318,39 +457,86 @@ template <
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
SharedMemoryClear, GatherA, GatherB>
{
struct DefaultMma<half_t,
LayoutA,
kAlignmentA,
half_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
half_t,
LayoutA,
half_t,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
3,
Operator>;

// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
half_t, LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
half_t,
LayoutA,
1,
ThreadMapA,
AccessTypeA,
GatherA>;

// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA, AccessTypeA,
GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
half_t,
LayoutB,
0,
ThreadMapB,
AccessTypeB,
GatherB>;

// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB, AccessTypeB,
GatherB>;

// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
2>;
};
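
// Design note, sketched (illustrative only): MmaCore above is instantiated
// with Stages = 3 solely to select the cp.async/multistage component path,
// while the MmaMultistage pipeline itself runs with 2 stages, halving the
// shared-memory footprint. Hypothetical arithmetic for a 128x128x64 half_t
// tile:
constexpr int kSketchM = 128, kSketchN = 128, kSketchK = 64;
constexpr int kSketchBytesPerStage =
    (kSketchM * kSketchK + kSketchK * kSketchN) * 2;  // A tile + B tile, 2 B/elem
static_assert(kSketchBytesPerStage == 32 * 1024, "32 KiB of shared memory per stage");
// Two stages need 64 KiB; a third stage on a tile this large can exceed the
// shared memory a configuration can reserve, which is what this 2-stage
// multistage specialization avoids.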

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int2 weight, mma multistage
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int2 weight, mma multistage

template <
/// Layout type for A matrix operand
@@ -373,26 +559,50 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
using Mma = DefaultWint2xMma<half_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

template <
@@ -420,29 +630,55 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
struct DefaultMma<half_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
using Mma = DefaultWint2xMma<half_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

@@ -1,6 +1,6 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -27,7 +27,8 @@ namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & bf16 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -55,40 +56,85 @@ template <
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
SharedMemoryClear, GatherA, GatherB>
{
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
bfloat16_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
private:
// Conversions only needed pre-ampere. This will trigger mma pipeline, so we
// convert before STS.
static constexpr bool arch_has_bf16_mma =
ArchTag::kMinComputeCapability >= 80;
using MmaElementA = typename platform::
conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
using MmaElementB = typename platform::
conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;

private:
// Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS.
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
MmaElementA,
LayoutA,
MmaElementB,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
2,
Operator>;

public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, MmaElementA,
LayoutA, MmaElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
bfloat16_t,
LayoutA,
1,
typename MmaCore::IteratorThreadMapA,
kAlignmentA,
GatherA>;

using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, bfloat16_t, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
bfloat16_t,
LayoutB,
0,
typename MmaCore::IteratorThreadMapB,
kAlignmentB,
GatherB>;

// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, bfloat16_t, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
layout::RowMajor, typename MmaCore::MmaPolicy>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
IteratorB,
typename MmaCore::SmemIteratorB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy>;
};
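
// Standalone sketch of the element fallback above (illustrative only; it uses
// std::conditional in place of cutlass::platform::conditional and tag types in
// place of the real element types, so it compiles on its own):
#include <type_traits>

struct example_bf16_tag {};  // stand-in for cutlass::bfloat16_t
struct example_fp16_tag {};  // stand-in for cutlass::half_t

template <int kMinComputeCapability>
using ExampleMmaElement =
    typename std::conditional<(kMinComputeCapability >= 80),
                              example_bf16_tag,         // SM80+: bf16 MMA exists
                              example_fp16_tag>::type;  // pre-Ampere: convert to fp16 before STS

static_assert(std::is_same<ExampleMmaElement<80>, example_bf16_tag>::value,
              "Ampere keeps bf16");
static_assert(std::is_same<ExampleMmaElement<75>, example_fp16_tag>::value,
              "Turing falls back to fp16");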

// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
// large tile when not enough shared mem is present to do 3+ stage
// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps
// avoid reg spills on large tile when not enough shared mem is present to do 3+
// stage
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -114,40 +160,86 @@ template <
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator,
false, SharedMemoryClear, GatherA, GatherB>
{
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
bfloat16_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
bfloat16_t,
LayoutA,
bfloat16_t,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
3,
Operator>;

// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
bfloat16_t,
LayoutA,
1,
ThreadMapA,
AccessTypeA,
GatherA>;

// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA, GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
bfloat16_t,
LayoutB,
0,
ThreadMapB,
AccessTypeB,
GatherB>;

// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB, GatherB>;

// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
2>;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -169,34 +261,61 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
bfloat16_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;

using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -218,34 +337,61 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
bfloat16_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;

using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -271,35 +417,64 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
bfloat16_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;

using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -325,35 +500,64 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
bfloat16_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;

using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int2 weight, mma multistage
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int2 weight, mma multistage

template <
/// Layout type for A matrix operand
@@ -376,26 +580,50 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
using Mma = DefaultWint2xMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

template <
@@ -423,29 +651,55 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
using Mma = DefaultWint2xMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

@@ -66,10 +67,21 @@ template <
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
layout::RowMajor, uint2b_t, layout::ColumnMajor,
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
Operator_, false, CacheOpA, CacheOpB> {
struct DefaultMmaCore<Shape_,
WarpShape_,
InstructionShape_,
ElementA_,
layout::RowMajor,
uint2b_t,
layout::ColumnMajor,
ElementC_,
LayoutC_,
arch::OpClassTensorOp,
Stages,
Operator_,
false,
CacheOpA,
CacheOpB> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
@@ -104,7 +116,8 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,

/// Size of a threadblock-scoped access of B
static constexpr int kMaxThreadsForB =
(Shape::kK * Shape::kN * sizeof_bits<ElementB>::value) / kAccessSizeInBits;
(Shape::kK * Shape::kN * sizeof_bits<ElementB>::value) /
kAccessSizeInBits;
static constexpr int kThreadsForB =
kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB;

@@ -129,11 +142,13 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
//

using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementA>::value, Shape::kK>;
sizeof_bits<ElementA>::value,
Shape::kK>;

// Shared memory layout
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementB>::value, Shape::kK>;
sizeof_bits<ElementB>::value,
Shape::kK>;

//
// Iterators to write to shared memory
@@ -141,26 +156,34 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,

/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
layout::PitchLinearShape<Shape::kK, Shape::kM>,
kThreads,
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
kWarpThreadArrangementStridedA>,
kAccessSizeInBits / sizeof_bits<ElementA>::value>;

/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
MatrixShape<Shape::kM, Shape::kK>,
ElementA,
SmemLayoutA,
0,
IteratorThreadMapA>;

/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreadsForB,
layout::PitchLinearShape<Shape::kK, Shape::kN>,
kThreadsForB,
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
kWarpThreadArrangementStridedB>,
kAccessSizeInBits / sizeof_bits<ElementB>::value>;

/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
MatrixShape<Shape::kK, Shape::kN>,
ElementB,
SmemLayoutB,
1,
IteratorThreadMapB>;

//
@@ -168,13 +191,23 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
//

// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
using MmaTensorOp =
typename cutlass::gemm::warp::DefaultMmaTensorOp<WarpShape,
InstructionShape,
ElementA,
SmemLayoutA,
ElementB,
SmemLayoutB,
ElementC,
LayoutC,
Operator,
WarpCount::kK>::Type;

/// Policy used to define MmaPipelined
using MmaPolicy = MmaPolicy<MmaTensorOp, MatrixShape<0, 0>,
MatrixShape<0, 0>, WarpCount::kK>;
using MmaPolicy = MmaPolicy<MmaTensorOp,
MatrixShape<0, 0>,
MatrixShape<0, 0>,
WarpCount::kK>;
};
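
// Worked example for the kThreadsForB clamp above (illustrative values only):
// for a hypothetical 64x64 B tile of 2-bit weights with 128-bit accesses and a
// 128-thread threadblock, only 64 threads can each be given a full access.
constexpr int kClampKN = 64 * 64;        // Shape::kK * Shape::kN
constexpr int kClampBitsB = 2;           // sizeof_bits<uint2b_t>::value
constexpr int kClampAccessBits = 128;
constexpr int kClampThreads = 128;       // e.g. a 4-warp threadblock
constexpr int kClampMaxThreadsForB = kClampKN * kClampBitsB / kClampAccessBits;
constexpr int kClampThreadsForB =
    kClampMaxThreadsForB > kClampThreads ? kClampThreads : kClampMaxThreadsForB;
static_assert(kClampThreadsForB == 64, "B loads engage a subset of the threadblock");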

} // namespace threadblock

@@ -1,6 +1,6 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -31,47 +31,61 @@ namespace threadblock {

template <typename ThreadblockShape, typename ElementT, int GroupSize>
struct DefaultQuantParamsIterators {
private:
static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
private:
static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");

static constexpr int kRows =
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize;
static constexpr int kColumns = ThreadblockShape::kN;
static constexpr int kRows =
(GroupSize == -1) ? 1
: (ThreadblockShape::kK + GroupSize - 1) / GroupSize;
static constexpr int kColumns = ThreadblockShape::kN;

using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment, kAlignment>;
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment,
kAlignment>;

public:
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor, 0,
IteratorThreadMap, kAlignment>;
using SmemIterator = Iterator;
public:
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
MatrixShape<kRows, kColumns>,
ElementT,
layout::RowMajor,
0,
IteratorThreadMap,
kAlignment>;
using SmemIterator = Iterator;
};
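
// Worked example for the generic case above (illustrative only; assumes bf16
// scales, ThreadblockShape::kN = 128, kK = 64 and GroupSize = 64): one 128-bit
// access covers 128 / 16 = 8 scale elements, and each threadblock tile needs
// ceil(kK / GroupSize) = 1 row of kN = 128 scales.
constexpr int kScaleBits = 16;  // sizeof_bits<bfloat16_t>::value
constexpr int kScaleN = 128, kScaleK = 64, kScaleGroup = 64;
constexpr int kScaleAlignment = 128 / kScaleBits;
constexpr int kScaleRows = (kScaleK + kScaleGroup - 1) / kScaleGroup;
static_assert(kScaleAlignment == 8 && kScaleRows == 1 &&
                  kScaleN % kScaleAlignment == 0,
              "generic quant-param layout example");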
|
||||
|
||||
template <typename ThreadblockShape, int GroupSize>
|
||||
struct DefaultQuantParamsIterators<ThreadblockShape, uint4b_t, GroupSize> {
|
||||
private:
|
||||
static constexpr int kAlignment = 32 / sizeof_bits<uint4b_t>::value;
|
||||
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
|
||||
private:
|
||||
static constexpr int kAlignment = 32 / sizeof_bits<uint4b_t>::value;
|
||||
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
|
||||
|
||||
static constexpr int kRows =
|
||||
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize);
|
||||
static constexpr int kColumns =
|
||||
(GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2;
|
||||
static constexpr int kRows =
|
||||
(GroupSize == -1)
|
||||
? 1
|
||||
: (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize);
|
||||
static constexpr int kColumns =
|
||||
(GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2;
|
||||
|
||||
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<kColumns, kRows>,
|
||||
kColumns / kAlignment, kAlignment>;
|
||||
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<kColumns, kRows>,
|
||||
kColumns / kAlignment,
|
||||
kAlignment>;
|
||||
|
||||
public:
|
||||
using AccessType = cutlass::Array<uint4b_t, kAlignment>;
|
||||
using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
MatrixShape<kRows, kColumns>, uint4b_t, layout::RowMajor,
|
||||
0, IteratorThreadMap, AccessType>;
|
||||
public:
|
||||
using AccessType = cutlass::Array<uint4b_t, kAlignment>;
|
||||
using Iterator =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
MatrixShape<kRows, kColumns>,
|
||||
uint4b_t,
|
||||
layout::RowMajor,
|
||||
0,
|
||||
IteratorThreadMap,
|
||||
AccessType>;
|
||||
|
||||
using SmemIterator = Iterator;
|
||||
using SmemIterator = Iterator;
|
||||
};
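To sanity-check the tile arithmetic above with concrete numbers, here is a small standalone C++ sketch; the 128x128x128 threadblock and group size of 64 are assumptions chosen only for illustration, not values from this diff:

// Worked example of DefaultQuantParamsIterators' tile arithmetic.
// All values below are illustrative assumptions.
constexpr int kTbK = 128, kTbN = 128, kGroup = 64;

// Generic case: one kN-wide row of quant params per K-group.
constexpr int kRowsGeneric = (kTbK + kGroup - 1) / kGroup;       // 2 rows
constexpr int kColsGeneric = kTbN;                               // 128 cols

// uint4b_t specialization: pairs of group rows are folded into doubled
// columns, so the element count per threadblock tile is unchanged.
constexpr int kRowsU4 = (kTbK + 2 * kGroup - 1) / (2 * kGroup);  // 1 row
constexpr int kColsU4 = kTbN * 2;                                // 256 cols

static_assert(kRowsGeneric * kColsGeneric == kRowsU4 * kColsU4,
              "same number of quant-param elements, different layout");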

template <
@@ -142,105 +156,174 @@ template <
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultWint2xMma<ElementA,
                        LayoutA,
                        kAlignmentA,
                        ElementB,
                        LayoutB,
                        kAlignmentB,
                        ElementAccumulator,
                        layout::RowMajor,
                        OperatorClass,
                        ArchTag,
                        ThreadblockShape,
                        WarpShape,
                        InstructionShape,
                        kStages,
                        Operator,
                        SharedMemoryClear> {
 public:
  static_assert(platform::is_same<ElementA, half_t>::value ||
                    platform::is_same<ElementA, bfloat16_t>::value,
                "Element A must be fp16 or bf16");

  static_assert(platform::is_same<ElementB, uint2b_t>::value,
                "Element B must be uint2b_t");

  static_assert(
      platform::is_same<Operator,
                        arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
      "Mma multistage must dequantize after ldsm");

  using ElementSuperScale = ElementA;
  using ElementLocalScale = uint4b_t;
  using ElementCodeScaleZp = float;

  static constexpr int kGroupSize = 64;

  static cutlass::arch::CacheOperation::Kind const CacheOpA =
      ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
          ? cutlass::arch::CacheOperation::Global
          : cutlass::arch::CacheOperation::Always;

  static cutlass::arch::CacheOperation::Kind const CacheOpB =
      ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
          ? cutlass::arch::CacheOperation::Global
          : cutlass::arch::CacheOperation::Always;

  // Define the MmaCore components.
  // The Mma core does not depend on stages, so pass in at least 3 here so
  // that the mma multistage pieces are created.
  using MmaCore =
      typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
                                                          WarpShape,
                                                          InstructionShape,
                                                          ElementA,
                                                          LayoutA,
                                                          ElementB,
                                                          layout::ColumnMajor,
                                                          ElementAccumulator,
                                                          layout::RowMajor,
                                                          OperatorClass,
                                                          std::max(kStages, 3),
                                                          Operator,
                                                          false,
                                                          CacheOpA,
                                                          CacheOpB>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
  using IteratorA =
      cutlass::transform::threadblock::PredicatedTileAccessIterator<
          cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
          ElementA,
          LayoutA,
          1,
          ThreadMapA,
          AccessTypeA>;

 private:
  static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved;
  static constexpr int kRowsPerTile = LayoutB::kRowsPerTile;
  static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved),
                "ThreadblockShape must be divisible by kColumnsInterleaved");
  static_assert(kRowsPerTile == MmaCore::Shape::kK, "");

  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement;
  static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), "");

  using IteratorShapeB = MatrixShape<MmaCore::Shape::kK * kColumnsInterleaved,
                                     MmaCore::Shape::kN / kColumnsInterleaved>;
  using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap<
      layout::PitchLinearShape<IteratorShapeB::kRow, IteratorShapeB::kColumn>,
      ThreadMapB::kThreads,
      layout::PitchLinearShape<WarpArrangement::kContiguous *
                                   kColumnsInterleaved,
                               WarpArrangement::kStrided / kColumnsInterleaved>,
      MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;

 public:
  // Define iterators over tiles from the B operand
  using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
  using IteratorB =
      cutlass::transform::threadblock::PredicatedTileAccessIterator<
          IteratorShapeB,
          ElementB,
          layout::ColumnMajor,
          0,
          InterleavedThreadMapB,
          AccessTypeB>;

 private:
  // Define iterators over tiles from extra quant params for B operand
  using IteratorSuperScale =
      typename DefaultQuantParamsIterators<ThreadblockShape,
                                           ElementSuperScale,
                                           -1>::Iterator;
  using SmemIteratorSuperScale =
      typename DefaultQuantParamsIterators<ThreadblockShape,
                                           ElementSuperScale,
                                           -1>::SmemIterator;

  using IteratorLocalScale =
      typename DefaultQuantParamsIterators<ThreadblockShape,
                                           ElementLocalScale,
                                           kGroupSize>::Iterator;
  using SmemIteratorLocalScale =
      typename DefaultQuantParamsIterators<ThreadblockShape,
                                           ElementLocalScale,
                                           kGroupSize>::SmemIterator;

  using IteratorCodeScaleZp =
      typename DefaultQuantParamsIterators<ThreadblockShape,
                                           ElementCodeScaleZp,
                                           -1>::Iterator;
  using SmemIteratorCodeScaleZp =
      typename DefaultQuantParamsIterators<ThreadblockShape,
                                           ElementCodeScaleZp,
                                           -1>::Iterator;

 public:
  using QuantParamsAccessor = Wint2ParamsAccessor<ElementA,
                                                  ThreadblockShape,
                                                  IteratorSuperScale,
                                                  SmemIteratorSuperScale,
                                                  IteratorLocalScale,
                                                  SmemIteratorLocalScale,
                                                  IteratorCodeScaleZp,
                                                  SmemIteratorCodeScaleZp,
                                                  kStages,
                                                  kGroupSize>;

  // Define the threadblock-scoped multistage matrix multiply
  using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
      typename MmaCore::Shape,
      IteratorA,
      typename MmaCore::SmemIteratorA,
      MmaCore::kCacheOpA,
      IteratorB,
      typename MmaCore::SmemIteratorB,
      MmaCore::kCacheOpB,
      ElementAccumulator,
      layout::RowMajor,
      typename MmaCore::MmaPolicy,
      kStages,
      QuantParamsAccessor,
      SharedMemoryClear>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
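The CacheOpA/CacheOpB selection above is a pure access-width rule. A minimal sketch of that rule with stand-in types (the enum and helper below are illustrative, not the CUTLASS definitions): 128-bit per-thread accesses can take the cp.async cache-global path, narrower ones fall back to cache-always.

enum class CacheOp { Always, Global };

// 128-bit per-thread accesses may use the cache-global cp.async variant;
// anything narrower uses the default cache-at-all-levels path.
template <int kSizeOfBits, int kAlignment>
constexpr CacheOp pick_cache_op() {
  return (kSizeOfBits * kAlignment == 128) ? CacheOp::Global : CacheOp::Always;
}

static_assert(pick_cache_op<16, 8>() == CacheOp::Global,  // fp16 x 8 elements
              "128-bit A accesses select cache-global");
static_assert(pick_cache_op<2, 16>() == CacheOp::Always,  // uint2b_t x 16
              "32-bit B accesses select cache-always");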
@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
@@ -47,30 +48,33 @@

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////
// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
// correct warp level mma. On volta, all data is stored to shared memory as
// FP16.
template <typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
                                 typename WarpMma::FragmentC& D,
                                 typename WarpMma::FragmentA const& A,
                                 typename WarpMma::FragmentB const& B,
                                 typename WarpMma::FragmentC const& C,
                                 int const warp_tileB_k_offset) {
  warp_mma(D, A, B, C);
}

template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(
    WarpMma& warp_mma,
    typename WarpMma::FragmentC& D,
    typename WarpMma::TransformedFragmentA const& A,
    typename WarpMma::TransformedFragmentB const& B,
    typename WarpMma::FragmentC const& C,
    int const warp_tileB_k_offset) {
  warp_mma(D, A, B, C, warp_tileB_k_offset);
}

////////////////////////////////////////////////////////////////////////////////
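The overload pair above relies on the default template argument WarpMma::kExpansionFactor failing substitution when the type lacks that member. A self-contained toy of the mechanism follows; note that the real overloads are told apart by their fragment parameter types, while this sketch adds an int/long tiebreaker only so the SFINAE effect is visible in isolation:

#include <iostream>

struct VoltaMma {  // no kExpansionFactor: only the fallback is viable
  void operator()() const { std::cout << "volta path\n"; }
};

struct TuringMma {  // has kExpansionFactor: the preferred overload wins
  static constexpr int kExpansionFactor = 2;
  void operator()() const { std::cout << "turing+ path\n"; }
};

template <typename WarpMma, int kExpansionFactor = 1>
void dispatch(WarpMma& mma, long) { mma(); }  // fallback (0 converts to long)

template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
void dispatch(WarpMma& mma, int) { mma(); }   // removed by SFINAE for VoltaMma

int main() {
  VoltaMma v;
  TuringMma t;
  dispatch(v, 0);  // prints "volta path"
  dispatch(t, 0);  // prints "turing+ path"
}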
@@ -90,168 +94,169 @@ template <
    WeightOnlyQuantOp DequantOp,
    /// Used for partial specialization,
    typename Enable = bool>
class DqMmaBase {
 public:
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

  ///< Policy describing tuning details
  using Policy = Policy_;

  ///< Type of the scale to be loaded
  using ElementScale = ElementScale_;

  static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, "");

  // Finegrained scales get streamed in via cp.async
  static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1;
  // We always have scales.
  static constexpr int ScaleElementsPerStage = Shape::kN;
  // We sometimes have a bias
  static constexpr int BiasElementsPerStage =
      hasZero(DequantOp) ? Shape::kN : 0;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
                              Shape::kN / WarpGemm::kN,
                              Shape::kK / WarpGemm::kK>;

  /// Number of warp-level GEMM operations
  static int const kWarpGemmIterations =
      (WarpGemm::kK / Operator::Policy::MmaShape::kK);

  static constexpr int kNumKIterationsPerWarpBLoad =
      Operator::IteratorB::InstructionShape::kRow /
      Operator::InstructionShape::kK;

  static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
  static constexpr int kWarpGemmIterationsForB =
      kWarpGemmIterations / kNumKIterationsPerWarpBLoad;

  /// Number of stages
  static int const kStages = Stages;

  /// Tensor reference to the A operand
  using TensorRefA =
      TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

  /// Tensor reference to the B operand
  using TensorRefB =
      TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM
  class SharedStorage {
   public:
    //
    // Type definitions
    //

    /// Shape of the A matrix operand in shared memory
    using ShapeA =
        MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
                    Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

    /// Shape of the B matrix operand in shared memory
    using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
                               Shape::kN + Policy::SmemPaddingB::kColumn>;

    /// Shape of the shared memory buffer for the scales for the B matrix.
    using ShapeScale = MatrixShape<ScalebiasStages, ScaleElementsPerStage>;
    /// Shape of the shared memory buffer for the biases of the B matrix.
    using ShapeZero = MatrixShape<ScalebiasStages, BiasElementsPerStage>;

   public:
    //
    // Data members
    //

    /// Buffer for A operand
    AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

    /// Buffer for B operand
    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

    /// Buffer to hold scales for threadblock
    AlignedBuffer<ElementScale, ShapeScale::kCount> operand_scale;

    /// Buffer to hold zeros for threadblock
    AlignedBuffer<ElementScale, ShapeZero::kCount> operand_zero;

   public:
    //
    // Methods
    //

    /// Returns a layout object for the A matrix
    CUTLASS_DEVICE
    static typename Operator::LayoutA LayoutA() {
      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
    }

    /// Returns a layout object for the B matrix
    CUTLASS_HOST_DEVICE
    static typename Operator::LayoutB LayoutB() {
      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
    }

    /// Returns a TensorRef to the A operand
    CUTLASS_HOST_DEVICE
    TensorRefA operand_A_ref() {
      return TensorRefA{operand_A.data(), LayoutA()};
    }

    /// Returns a TensorRef to the B operand
    CUTLASS_HOST_DEVICE
    TensorRefB operand_B_ref() {
      return TensorRefB{operand_B.data(), LayoutB()};
    }
  };

 protected:
  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A operand from shared memory
  typename Operator::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  DqMmaBase(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      SharedStorage& shared_storage,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
        warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
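For a feel of the shared-memory footprint implied by SharedStorage above, a rough back-of-the-envelope sketch; the tile shape, stage count, and element width below are assumed values, not taken from this diff:

#include <cstddef>

constexpr int kM = 128, kN = 128, kK = 64, kStages = 2;

// ShapeA::kCount is kM * (kK * kStages) with no padding; fp16 is 2 bytes.
constexpr std::size_t kBytesA = std::size_t(kM) * (kK * kStages) * 2;

// One kN-wide row of scales per stage when finegrained, else a single row.
constexpr std::size_t kScaleElems = std::size_t(kStages) * kN;

static_assert(kBytesA == 32 * 1024, "A operand occupies 32 KiB of smem");
static_assert(kScaleElems == 256, "two stages of 128 scale elements each");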
@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
@@ -48,12 +49,9 @@

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -102,9 +100,9 @@ template <
    typename Enable = void>
class DqMmaMultistage;

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h"
+628 -574  File diff suppressed because it is too large
+533 -480  File diff suppressed because it is too large
@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
@@ -53,27 +54,27 @@

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
@@ -98,9 +99,9 @@ template <
    typename Enable = void>
class DqMmaPipelined;

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"
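The two includes above supply the definitions for the forward declaration: each header provides one partial specialization, selected by the trailing enable_if parameter. A minimal sketch of the same pattern with placeholder names (the enum values below are illustrative, not the real WeightOnlyQuantOp enumerators):

#include <type_traits>

enum class QuantOpKind { PerColumn, Finegrained };
constexpr bool is_finegrained(QuantOpKind op) {
  return op == QuantOpKind::Finegrained;
}

// Forward declaration only, as in this header.
template <QuantOpKind kOp, typename Enable = void>
class PipelinedMma;

// What the "finegrained" header provides.
template <QuantOpKind kOp>
class PipelinedMma<kOp, std::enable_if_t<is_finegrained(kOp)>> {};

// What the "percol" header provides.
template <QuantOpKind kOp>
class PipelinedMma<kOp, std::enable_if_t<!is_finegrained(kOp)>> {};

PipelinedMma<QuantOpKind::Finegrained> fg;  // selects the first specialization
PipelinedMma<QuantOpKind::PerColumn> pc;    // selects the second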
+398
-348
@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
@@ -53,27 +54,27 @@

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
@@ -94,393 +95,442 @@ template <
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
|
||||
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
|
||||
std::enable_if_t<isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
|
||||
class DqMmaPipelined<Shape_,
|
||||
IteratorA_,
|
||||
SmemIteratorA_,
|
||||
IteratorB_,
|
||||
SmemIteratorB_,
|
||||
IteratorScale_,
|
||||
SmemIteratorScale_,
|
||||
ElementC_,
|
||||
LayoutC_,
|
||||
Policy_,
|
||||
TransformBAfterLDG_,
|
||||
TransformBAfterLDS_,
|
||||
QuantOp_,
|
||||
std::enable_if_t<isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_,
|
||||
Policy_,
|
||||
typename SmemIteratorScale_::Element,
|
||||
2,
|
||||
QuantOp_> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_,
|
||||
Policy_,
|
||||
typename SmemIteratorScale_::Element,
|
||||
2,
|
||||
QuantOp_>;
|
||||
|
||||
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
using Shape =
|
||||
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA =
|
||||
IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB =
|
||||
IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
using TransformBAfterLDG = TransformBAfterLDG_;
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
using TransformBAfterLDG = TransformBAfterLDG_;
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
|
||||
typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;
|
||||
using Dequantizer =
|
||||
warp::MmaTensorOpDequantizer<Operator,
|
||||
typename Base::WarpGemm,
|
||||
Operand::kB,
|
||||
typename SmemIteratorScale::Element,
|
||||
LayoutScale,
|
||||
32,
|
||||
QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
// staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
|
||||
// staticaly assert kStages for DqMmaPipelined is two (Double-buffered
|
||||
// pipeline)
|
||||
static_assert((Base::kStages == 2),
|
||||
"DqMmaPipelined requires kStages set to value 2");
|
||||
|
||||
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
|
||||
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
|
||||
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
|
||||
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using WarpFragmentScale = typename Dequantizer::FragmentScale;
|
||||
using WarpFragmentZero = typename Dequantizer::FragmentZero;
|
||||
using WarpFragmentScale = typename Dequantizer::FragmentScale;
|
||||
using WarpFragmentZero = typename Dequantizer::FragmentZero;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
static constexpr bool RequiresTileInterleave =
|
||||
layout::IsColumnMajorTileInterleave<
|
||||
typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave ||
|
||||
(RequiresTileInterleave &&
|
||||
(Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
protected:
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
protected:
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
/// Iterator to write threadblock-scoped tile of scale and zero operand to
|
||||
/// shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaPipelined(typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int const group_size, ///< The group size for quantization
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
|
||||
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
|
||||
{
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaPipelined(typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal
|
||||
///< use by threadblock-scoped GEMM
|
||||
int const group_size, ///< The group size for quantization
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_dequantizer_(
|
||||
{shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) /
|
||||
Base::WarpCount::kM,
|
||||
lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
|
||||
smem_iterator_scale_(LayoutScale(Shape::kN),
|
||||
shared_storage.operand_scale.data(),
|
||||
shared_storage.operand_zero.data(),
|
||||
{Base::kStages, Shape::kN},
|
||||
thread_idx,
|
||||
group_size) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
  }

  CUTLASS_DEVICE
  void copy_scales_and_advance(IteratorScale& iterator_scale) {
    using TransformScale =
        NumericArrayConverter<typename SmemIteratorScale::Element,
                              typename FragmentScale::Element,
                              FragmentScale::kElements>;

    FragmentScale tb_frag_scales;
    FragmentScale tb_frag_zeros;
    tb_frag_scales.clear();
    tb_frag_zeros.clear();

    TransformScale transformScale;

    using FragmentElement = typename FragmentScale::Element;

    auto gmem_scale_ptr = iterator_scale.get_scale();
    auto gmem_zero_ptr = iterator_scale.get_zero();

    arch::global_load<FragmentScale, sizeof(FragmentScale)>(
        tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());

    if (gmem_zero_ptr != nullptr) {
      arch::global_load<FragmentScale, sizeof(FragmentScale)>(
          tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
    }

    typename TransformScale::result_type tb_frag_scales_fp16 =
        transformScale(tb_frag_scales);
    typename TransformScale::result_type tb_frag_zeros_fp16;
    if (gmem_zero_ptr != nullptr)
      tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);

    auto frag_scale_ptr_fp16 =
        reinterpret_cast<typename SmemIteratorScale::Element*>(
            &tb_frag_scales_fp16);
    auto frag_zero_ptr_fp16 =
        reinterpret_cast<typename SmemIteratorScale::Element*>(
            &tb_frag_zeros_fp16);
    auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
    auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();

    if (iterator_scale.valid()) {
      auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
      arch::shared_store<sizeof(FragmentScale)>(smem_offset,
                                                frag_scale_ptr_fp16);

      if (gmem_zero_ptr != nullptr) {
        smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
        arch::shared_store<sizeof(FragmentScale)>(smem_offset,
                                                  frag_zero_ptr_fp16);
      }
    }

    if (iterator_scale.group_size_ == 64) {
      iterator_scale.add_tile_offset({1, 0});
    } else if (iterator_scale.group_size_ == 128) {
      if constexpr (Shape::kK == 128) {
        iterator_scale.add_tile_offset({1, 0});
      } else if constexpr (Shape::kK == 64) {
        if (iterator_scale.row_groupsize64_ & 0x1) {
          iterator_scale.add_tile_offset({1, 0});
        }
      } else {
        static_assert(Shape::kK == 0,
                      "Unsupported k tile shape, can only be 64 or 128");
      }
    }

    iterator_scale.row_groupsize64_++;

    this->smem_iterator_scale_.add_tile_offset({1, 0});
  }
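
  // A short worked trace of the advance logic above (illustration only,
  // assuming row_groupsize64_ starts at 0). With group_size_ == 128 and
  // Shape::kK == 64, one scale row spans two threadblock K-tiles, so the
  // global scale iterator only advances on every other call:
  //
  //   call #           : 1    2    3    4
  //   row_groupsize64_ : 0    1    2    3
  //   advances?        : no   yes  no   yes
  //
  // With group_size_ == 64, or with group_size_ == 128 and Shape::kK == 128,
  // every K-tile consumes one scale row and the iterator advances every call.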

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      int gemm_k_iterations,  ///< number of iterations of the mainloop
      FragmentC& accum,       ///< destination accumulator tile
      IteratorA iterator_A,   ///< iterator over A operand in global memory
      IteratorB iterator_B,   ///< iterator over B operand in global memory
      IteratorScale
          iterator_scale,  ///< iterator over scale operand in global memory
      FragmentC const& src_accum) {  ///< source accumulator tile
    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    //
    // Prologue
    //
    TransformBAfterLDG ldg_converter;
    TransformBAfterLDS lds_converter;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;

    using TransformA = NumericArrayConverter<typename WarpFragmentA::Element,
                                             typename FragmentA::Element,
                                             FragmentA::kElements>;

    tb_frag_A.clear();
    tb_frag_B.clear();

    // These transforms are mainly to handle when we have bfloat activations
    // and weights in GMEM and want to issue HMMA on architectures older than
    // Ampere. We will convert to FP16 before STS.
    TransformA transformA;

    // The last kblock is loaded in the prolog
    iterator_A.load(tb_frag_A);
    iterator_B.load(tb_frag_B);

    ++iterator_A;
    ++iterator_B;

    this->smem_iterator_A_.store(transformA(tb_frag_A));
    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B_;

    copy_scales_and_advance(iterator_scale);

    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentB warp_frag_B[2];
    WarpFragmentScale warp_frag_scales;
    WarpFragmentZero warp_frag_zero;

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_frag_B[0]);

    warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;
    warp_dequantizer_.add_pointer_offset(Shape::kN);

    Operator warp_mma;

    int smem_write_stage_idx = 1;

    // Avoid reading out of bounds
    iterator_A.clear_mask(gemm_k_iterations <= 1);
    iterator_B.clear_mask(gemm_k_iterations <= 1);
    iterator_scale.clear_mask(gemm_k_iterations <= 1);

    // Issue loads during the first warp-level matrix multiply-add *AFTER*
    // issuing shared memory loads (which have the tightest latency
    // requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
           ++warp_mma_k) {
        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.

        if (warp_mma_k == Base::kWarpGemmIterations - 1) {
          // Write fragments to shared memory
          this->smem_iterator_A_.store(transformA(tb_frag_A));

          this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

          __syncthreads();

          ++this->smem_iterator_A_;
          ++this->smem_iterator_B_;

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
            this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
          } else {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0,
                 -Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterationsForB,
                 0});
            warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
          }

          smem_write_stage_idx ^= 1;
        }

        this->warp_tile_iterator_A_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
        ++this->warp_tile_iterator_A_;

        int const warp_tileB_k_compute_offset =
            warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
        int const warp_tileB_k_load_offset =
            warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
        // We are just about to finish computing on a fragment of B, so
        // initiate the load for the next fragment.
        if (warp_tileB_k_compute_offset ==
            Base::kNumKIterationsPerWarpBLoad - 1) {
          this->warp_tile_iterator_B_.set_kgroup_index(
              (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
          this->warp_tile_iterator_B_.load(
              warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
          ++this->warp_tile_iterator_B_;
        }

        if (warp_mma_k == 0) {
          iterator_A.load(tb_frag_A);
          iterator_B.load(tb_frag_B);

          ++iterator_A;
          ++iterator_B;

          copy_scales_and_advance(iterator_scale);

          // Avoid reading out of bounds if this was the last loop iteration
          iterator_A.clear_mask(gemm_k_iterations <= 2);
          iterator_B.clear_mask(gemm_k_iterations <= 2);
          iterator_scale.clear_mask(gemm_k_iterations <= 2);
        }

        typename TransformBAfterLDS::result_type converted_frag_B =
            lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
        warp_dequantizer_.dequantize(
            converted_frag_B, warp_frag_scales, warp_frag_zero);
        run_warp_mma(warp_mma,
                     accum,
                     warp_frag_A[warp_mma_k % 2],
                     converted_frag_B,
                     accum,
                     warp_tileB_k_compute_offset);
      }

      // Load the scales needed for the next tile iteration
      warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
      // Update internal pointer to the set of scales in shared memory
      warp_dequantizer_.add_pointer_offset(Shape::kN);
    }
  }
};
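
// A compressed trace of one mainloop K-tile in the operator above
// (illustration only; 'k' is warp_mma_k):
//
//   k == 0                       -> load next A/B tiles from global memory,
//                                   copy_scales_and_advance(iterator_scale)
//   k == kWarpGemmIterations - 1 -> store staged tiles to shared memory,
//                                   __syncthreads(), toggle the write stage
//   every k                      -> load next warp fragments, dequantize B
//                                   with (scale, zero), then run_warp_mma
//
// The scale pipeline therefore stays one K-tile ahead of the warp-level math
// that consumes it.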

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
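
// A minimal usage sketch for the specialization above (illustration only:
// `ThreadblockMma` stands for a concrete instantiation produced by the
// default_dq_mma traits, which are defined elsewhere):
//
//   ThreadblockMma mma(shared_storage, group_size, thread_idx, warp_idx,
//                      lane_idx);
//   mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_scale,
//       accum);
//
// Scales (and optional zero points) stream through shared memory alongside
// the A and B tiles, one stage ahead of the math.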
@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
@@ -53,27 +54,27 @@

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
@@ -94,306 +95,361 @@ template <
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_,
                     IteratorA_,
                     SmemIteratorA_,
                     IteratorB_,
                     SmemIteratorB_,
                     IteratorScale_,
                     SmemIteratorScale_,
                     ElementC_,
                     LayoutC_,
                     Policy_,
                     TransformBAfterLDG_,
                     TransformBAfterLDS_,
                     QuantOp_,
                     std::enable_if_t<!isFinegrained(QuantOp_)>>
    : public DqMmaBase<Shape_,
                       Policy_,
                       typename SmemIteratorScale_::Element,
                       2,
                       QuantOp_> {
 public:
  ///< Base class
  using Base = DqMmaBase<Shape_,
                         Policy_,
                         typename SmemIteratorScale_::Element,
                         2,
                         QuantOp_>;

  using Shape =
      Shape_;  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA =
      IteratorA_;  ///< Iterates over tiles of A operand in global memory
  using IteratorB =
      IteratorB_;  ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_;  ///< Data type of accumulator matrix
  using LayoutC = LayoutC_;    ///< Layout of accumulator matrix
  using Policy = Policy_;      ///< Policy describing tuning details

  using IteratorScale = IteratorScale_;
  using ElementScale = typename IteratorScale::Element;
  using LayoutScale = typename IteratorScale::Layout;

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;
  using SmemIteratorScale = SmemIteratorScale_;

  using TransformBAfterLDG = TransformBAfterLDG_;
  using TransformBAfterLDS = TransformBAfterLDS_;

  static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA = typename IteratorA::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of operand Scale loaded from global memory;
  using FragmentScale = typename IteratorScale::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  using Dequantizer = warp::MmaTensorOpDequantizer<
      Operator,
      typename Base::WarpGemm,
      Operand::kB,
      typename SmemIteratorScale::Fragment::Element,
      LayoutScale,
      32,
      QuantOp>;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // statically assert kStages for DqMmaPipelined is two (Double-buffered
  // pipeline)
  static_assert((Base::kStages == 2),
                "DqMmaPipelined requires kStages set to value 2");

 private:
  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;
  Dequantizer warp_dequantizer_;

  using ElementA = typename IteratorA::Element;
  using ElementB = typename IteratorB::Element;
  using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

  static constexpr bool RequiresTileInterleave =
      layout::IsColumnMajorTileInterleave<
          typename LayoutDetailsForB::Layout>::value;
  static_assert(!RequiresTileInterleave ||
                    (RequiresTileInterleave &&
                     (Shape::kK == LayoutDetailsForB::ThreadblockK)),
                "Layout K must match threadblockK");

 protected:
  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

  /// Iterator to write threadblock-scoped tile of scale operand to shared
  /// memory
  SmemIteratorScale smem_iterator_scale_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  DqMmaPipelined(
      typename Base::SharedStorage&
          shared_storage,  ///< Shared storage needed for internal use by
                           ///< threadblock-scoped GEMM
      int const
          group_size,  ///< Not used; present only to match the finegrained
                       ///< interface. DqMmaPipelined is only enabled for
                       ///< sm<80, so this extra argument does not affect
                       ///< compilation for sm>=80.
      int thread_idx,  ///< ID within the threadblock
      int warp_idx,    ///< ID of warp
      int lane_idx     ///< ID of each thread within a warp
      )
      : Base(shared_storage, thread_idx, warp_idx, lane_idx),
        warp_dequantizer_(
            {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
            (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) /
                Base::WarpCount::kM,
            lane_idx),
        smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
        smem_iterator_scale_(LayoutScale(Shape::kN),
                             shared_storage.operand_scale.data(),
                             {1, Shape::kN},
                             thread_idx) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
  }
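
  // A worked example of the warp mapping above (illustration only): with
  // Base::WarpCount = {kM = 2, kN = 2, kK = 2} and warp_idx = 5,
  //
  //   warp_idx_mn = 5 % (2 * 2) = 1
  //   warp_idx_k  = 5 / (2 * 2) = 1
  //   warp_idx_m  = 1 % 2       = 1
  //   warp_idx_n  = 1 / 2       = 0
  //
  // so warp 5 owns the (m, n) = (1, 0) warp tile of the second K partition.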

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      int gemm_k_iterations,  ///< number of iterations of the mainloop
      FragmentC& accum,       ///< destination accumulator tile
      IteratorA iterator_A,   ///< iterator over A operand in global memory
      IteratorB iterator_B,   ///< iterator over B operand in global memory
      IteratorScale
          iterator_scale,  ///< iterator over scale operand in global memory
      FragmentC const& src_accum) {  ///< source accumulator tile
    //
    // Prologue
    //
    TransformBAfterLDG ldg_converter;
    TransformBAfterLDS lds_converter;

    using TransformA = NumericArrayConverter<typename WarpFragmentA::Element,
                                             typename FragmentA::Element,
                                             FragmentA::kElements>;

    using TransformScale =
        NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
                              typename FragmentScale::Element,
                              FragmentScale::kElements>;

    // These transforms are mainly to handle when we have bfloat activations
    // and weights in GMEM and want to issue HMMA on architectures older than
    // Ampere. We will convert to FP16 before STS.
    TransformA transformA;
    TransformScale transformScale;

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;
    FragmentScale tb_frag_scales;

    using WarpFragmentScale = typename Dequantizer::FragmentScale;
    WarpFragmentScale warp_frag_scales;

    tb_frag_A.clear();
    tb_frag_B.clear();
    tb_frag_scales.clear();

    // The last kblock is loaded in the prolog
    iterator_A.load(tb_frag_A);
    iterator_B.load(tb_frag_B);
    iterator_scale.load(tb_frag_scales);

    ++iterator_A;
    ++iterator_B;

    this->smem_iterator_A_.store(transformA(tb_frag_A));
    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
    this->smem_iterator_scale_.store(transformScale(tb_frag_scales));

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B_;

    __syncthreads();

    warp_dequantizer_.load(warp_frag_scales);

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentB warp_frag_B[2];

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    Operator warp_mma;

    int smem_write_stage_idx = 1;

    // Avoid reading out of bounds
    iterator_A.clear_mask(gemm_k_iterations <= 1);
    iterator_B.clear_mask(gemm_k_iterations <= 1);

    // Issue loads during the first warp-level matrix multiply-add *AFTER*
    // issuing shared memory loads (which have the tightest latency
    // requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
           ++warp_mma_k) {
        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.

        if (warp_mma_k == Base::kWarpGemmIterations - 1) {
          // Write fragments to shared memory
          this->smem_iterator_A_.store(transformA(tb_frag_A));

          this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

          __syncthreads();

          ++this->smem_iterator_A_;
          ++this->smem_iterator_B_;

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
          } else {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0,
                 -Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterationsForB,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }

        this->warp_tile_iterator_A_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
        ++this->warp_tile_iterator_A_;

        int const warp_tileB_k_compute_offset =
            warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
        int const warp_tileB_k_load_offset =
            warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
        // We are just about to finish computing on a fragment of B, so
        // initiate the load for the next fragment.
        if (warp_tileB_k_compute_offset ==
            Base::kNumKIterationsPerWarpBLoad - 1) {
          this->warp_tile_iterator_B_.set_kgroup_index(
              (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
          this->warp_tile_iterator_B_.load(
              warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
          ++this->warp_tile_iterator_B_;
        }

        if (warp_mma_k == 0) {
          iterator_A.load(tb_frag_A);
          iterator_B.load(tb_frag_B);

          ++iterator_A;
          ++iterator_B;

          // Avoid reading out of bounds if this was the last loop iteration
          iterator_A.clear_mask(gemm_k_iterations <= 2);
          iterator_B.clear_mask(gemm_k_iterations <= 2);
        }

        typename TransformBAfterLDS::result_type converted_frag_B =
            lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
        warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
        run_warp_mma(warp_mma,
                     accum,
                     warp_frag_A[warp_mma_k % 2],
                     converted_frag_B,
                     accum,
                     warp_tileB_k_compute_offset);
      }
    }
  }
};
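
// The double buffering above can be pictured as follows (illustration only):
// smem_write_stage_idx starts at 1 and is XOR-toggled after every K-tile, so
// writes and warp-level reads ping-pong between the two shared-memory stages:
//
//   K-tile      : 0    1    2    3
//   write stage : 1    0    1    0
//   read  stage : 0    1    0    1
//
// When smem_write_stage_idx == 1 the write iterators are rewound with
// negative tile offsets; otherwise the warp-level read iterators are rewound.
// This is why the class statically asserts kStages == 2.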

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -66,7 +66,7 @@ template <
    /// Size of extra quantized params
    typename QuantParamsShape>
class Wint2xMmaBase {
 public:
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

@@ -85,9 +85,9 @@ public:
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
                              Shape::kN / WarpGemm::kN,
                              Shape::kK / WarpGemm::kK>;

  /// Number of warp-level GEMM operations
  static int const kWarpGemmIterations =

@@ -95,7 +95,8 @@ public:

  /// Number of warp-level GEMM operations per load for B
  static constexpr int kWarpGemmIterationsPerLoadForB =
      Operator::IteratorB::InstructionShape::kRow /
      Operator::InstructionShape::kK;
  static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");

  static constexpr int kWarpLoadIterationsForB =

@@ -125,7 +126,7 @@ public:

  /// Shared storage object needed by threadblock-scoped GEMM
  class SharedStorage {
   public:
    //
    // Type definitions
    //

@@ -142,7 +143,7 @@ public:
    /// Shape of all quant params in shared memory
    using QuantParamsShapeB = QuantParamsShape;

   public:
    //
    // Data members
    //

@@ -156,7 +157,7 @@ public:
    /// Buffer for extra quant params of B operand
    AlignedBuffer<uint8_t, QuantParamsShapeB::kCount> operand_quant_params_B;

   public:
    //
    // Methods
    //

@@ -186,7 +187,7 @@ public:
  }
};

 protected:
  //
  // Data members
  //

@@ -197,7 +198,7 @@ protected:
  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  Wint2xMmaBase(

@@ -215,8 +216,8 @@ public:

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
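
// A numeric sketch of the B-load bookkeeping in Wint2xMmaBase above
// (illustration only, with assumed instruction shapes): if
// Operator::IteratorB::InstructionShape::kRow == 32 and
// Operator::InstructionShape::kK == 16, then
//
//   kWarpGemmIterationsPerLoadForB = 32 / 16 = 2
//
// i.e. one warp-level load of B feeds two warp-level MMA iterations, and the
// static_assert above requires kWarpGemmIterations to be a multiple of that.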
@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
@@ -91,11 +92,17 @@ template <
    typename QuantParamsAccessor_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
class Wint2xMmaMultistage
    : public Wint2xMmaBase<Shape_,
                           Policy_,
                           Stages,
                           typename QuantParamsAccessor_::QuantParamsShape> {
 public:
  ///< Base class
  using Base = Wint2xMmaBase<Shape_,
                             Policy_,
                             Stages,
                             typename QuantParamsAccessor_::QuantParamsShape>;
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;
  ///< Iterates over tiles of A operand in global memory

@@ -133,17 +140,19 @@ public:
  /// Minimum architecture is Sm80 to support cp.async
  using ArchTag = arch::Sm80;

  // using LayoutScale = typename
  // QuantParamsAccessor::IteratorSuperScale::Layout;
  using LayoutScale = layout::RowMajor;
  using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
  using WarpDequantizer = warp::MmaTensorOpWin2xDequantizer<
      Operator,
      typename Base::WarpGemm,
      Operand::kB,
      typename WarpTransformedFragmentB::Element,
      LayoutScale,
      QuantParamsAccessor::kGroupSize>;
  static_assert(sizeof(WarpDequantizer) > 0,
                "WarpDequantizer template instantiation failed");

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

@@ -153,7 +162,6 @@ public:

  /// Internal structure exposed for introspection.
  struct Detail {
    /// Number of cp.async instructions to load one stage of operand A
    static int const AsyncCopyIterationsPerStageA =
        IteratorA::ThreadMap::Iterations::kCount;

@@ -167,24 +175,25 @@ public:

    /// Number of cp.async instructions to load on group of operand A
    static int const kAccessesPerGroupA =
        (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
        Base::kWarpGemmIterations;

    /// Number of cp.async instructions to load on group of operand B
    static int const kAccessesPerGroupB =
        (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
        Base::kWarpGemmIterations;

    // Optional staged-accumulation (e.g., tf32x3 kernels) for improved
    // numerical accuracy, where each mainloop iteration first accumulates
    // into a temporary set of freshly-cleared accumulators, which are
    // subsequently added to the final accumulator set.
    static bool const kStagedAccumulation =
        arch::detail::UseStagedAccumulation<Operator>::value;
  };
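
  // A numeric sketch of the Detail constants above (illustration only, with
  // assumed values): if AsyncCopyIterationsPerStageA == 8 and
  // Base::kWarpGemmIterations == 3, then
  //
  //   kAccessesPerGroupA = (8 + 3 - 1) / 3 = 3
  //
  // so the cp.async traffic for one stage is ceil-divided across the
  // warp-level MMA iterations and no access is dropped.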

private:

// Structure encapsulating pipeline state live from one iteration to the next
struct PipeState {

using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
@@ -197,10 +206,12 @@ public:
/// Temporary accumulator to facilitate staged-accumulation
FragmentC tmp_accum_;

/// Pair of A fragments used to overlap shared memory loads and math instructions
/// Pair of A fragments used to overlap shared memory loads and math
/// instructions
WarpTransformedFragmentA warp_frag_A_[2];

/// Pair of B fragments used to overlap shared memory loads and math instructions
/// Pair of B fragments used to overlap shared memory loads and math
/// instructions
WarpLoadedFragmentB warp_loaded_frag_B_;
WarpTransformedFragmentB warp_frag_B_[2];

@@ -218,12 +229,14 @@ public:
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

static constexpr bool IsTileInterleaveLayout =
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
layout::IsColumnMajorTileInterleave<
typename LayoutDetailsForB::Layout>::value;
static_assert(!IsTileInterleaveLayout ||
(IsTileInterleaveLayout &&
(Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");

private:

//
// Data members
//
@@ -249,8 +262,7 @@ public:
/// Shared memory read stage index
int smem_read_stage_idx_;

public:

public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2xMmaMultistage(
@@ -261,19 +273,24 @@ public:
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
) : Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx),
warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(),
quant_params_accessor_B_.local_scale_ref(),
quant_params_accessor_B_.code_scale_ref(),
quant_params_accessor_B_.code_zp_ref(),
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0)
{
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(),
thread_idx,
warp_idx,
lane_idx),
warp_dequantizer_(
quant_params_accessor_B_.super_scale_ref(),
quant_params_accessor_B_.local_scale_ref(),
quant_params_accessor_B_.code_scale_ref(),
quant_params_accessor_B_.code_zp_ref(),
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) /
Base::WarpCount::kM,
lane_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
@@ -295,22 +312,26 @@ public:

/// Advance shared memory read-iterators to the next stage
CUTLASS_DEVICE
void advance_smem_read_stage()
{
void advance_smem_read_stage() {
++smem_read_stage_idx_;

if (smem_read_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0});
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpLoadIterationsForB,
0});
smem_read_stage_idx_ = 0;
}
}
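// Sketch: the wrap-around above in scalar form. After kStages advances of one
// stage each, the read iterator sits kStages * stage-extent tiles past the
// buffer origin, and the negative tile offset subtracts exactly that amount.
// kStages and kStageExtent below are illustrative stand-ins, not the kernel's
// actual policy constants.
CUTLASS_DEVICE
void advance_read_stage_sketch(int &tile_pos, int &stage_idx) {
  constexpr int kStages = 3;       // assumed pipeline depth
  constexpr int kStageExtent = 2;  // assumed warp-tiles per stage
  tile_pos += kStageExtent;        // advance one stage
  if (++stage_idx == kStages) {
    tile_pos -= kStages * kStageExtent;  // wrap to the buffer start
    stage_idx = 0;
  }
}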

/// Advance global memory read-iterators and shared memory write-iterators to the stage
/// Advance global memory read-iterators and shared memory write-iterators to
/// the stage
CUTLASS_DEVICE
void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B)
{
void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B) {
// Advance global iterators
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
@@ -395,7 +416,9 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) ? iterator_B.valid() : false;
bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads)
? iterator_B.valid()
: false;

if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
@@ -429,10 +452,9 @@ public:
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();

int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;

cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
@@ -464,10 +486,9 @@ public:
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();

int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;

cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
@@ -480,18 +501,22 @@ public:
}

/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
/// the global fragments needed by the first kStages-1 threadblock mainloop
/// iterations
CUTLASS_DEVICE
void prologue(
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
void prologue(IteratorA &iterator_A, ///< [in|out] iterator over A operand in
///< global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in
///< global memory
QuantArguments &
mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations) ///< [in|out] number of threadblock
///< mainloop iterations remaining
{
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {

for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {
// Disable global fetching if done with global fetch iterations
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
@@ -502,11 +527,14 @@ public:
// Async copy zipped B to shared memory.
copy_tiles_and_advance_per_stage_B(iterator_B);

// Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
// Async copy other quantized params to shared memory, local_scale,
// code_scale, code_zp, super_scale.
if (stage == 0) {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<true>(mma_quant_args, stage);
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<true>(
mma_quant_args, stage);
} else {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(
mma_quant_args, stage);
}

// Move to the next write stage
@@ -517,11 +545,12 @@ public:
cutlass::arch::cp_async_fence();
}

// Optionally clear the remaining stages of SMEM. This is a functional requirement for
// some kernels so that all accumulator elements outside the GEMM footprint are zero.
// Optionally clear the remaining stages of SMEM. This is a functional
// requirement for some kernels so that all accumulator elements outside the
// GEMM footprint are zero.
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {

/// Iterator to write threadblock-scoped tile of A operand to shared memory
/// Iterator to write threadblock-scoped tile of A operand to shared
/// memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;

@@ -531,7 +560,6 @@ public:
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {

typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
last_smem_iterator_A.get());
@@ -545,7 +573,8 @@ public:
return;
}

/// Iterator to write threadblock-scoped tile of B operand to shared memory
/// Iterator to write threadblock-scoped tile of B operand to shared
/// memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;

@@ -555,7 +584,6 @@ public:
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {

typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
last_smem_iterator_B.get());
@@ -569,9 +597,9 @@ public:

/// Wait until we have at least one completed global fetch stage
CUTLASS_DEVICE
void gmem_wait()
{
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
void gmem_wait() {
// Wait until we have at least one committed global fetch stage.
// (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
}
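// Sketch: the cp.async accounting behind gmem_wait(), using the real
// cutlass::arch primitives. The prologue commits kStages - 1 fetch groups;
// cp_async_wait<N>() blocks until at most N committed groups remain in
// flight, so waiting on kStages - 2 guarantees at least one stage has fully
// landed in shared memory before the warps read it:
//
//   issue_stage_copies();                               // cp.async per stage
//   cutlass::arch::cp_async_fence();                    // commit one group
//   cutlass::arch::cp_async_wait<Base::kStages - 2>();  // <= kStages-2 pending
//   __syncthreads();                                    // publish to the block
//
// issue_stage_copies() is a placeholder for the copy_tiles_and_advance_*
// calls above, not a function defined in this file.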
@@ -579,25 +607,31 @@ public:
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
CUTLASS_DEVICE
void mac_loop_iter(
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
FragmentC &accum, ///< [in|out] destination accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
int stage)
{
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
FragmentC &accum, ///< [in|out] destination accumulator tile
IteratorA
&iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB
&iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments
&mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop
///< iterations remaining
int stage) {
const int mma_stage = stage - Base::kStages + 1;

// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
int warp_k_compute_offset_B =
warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;

if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(((warp_mma_k + 1) % Base::kWarpGemmIterations) / Base::kWarpLoadIterationsForB);
this->warp_tile_iterator_B_.set_kgroup_index(
((warp_mma_k + 1) % Base::kWarpGemmIterations) /
Base::kWarpLoadIterationsForB);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
++this->warp_tile_iterator_B_;
}
@@ -608,28 +642,31 @@ public:
}

// Load the next warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(
pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;

// Dequantize the next warp-tile
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2],
((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) : mma_stage) * Shape::kK,
(warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB);
warp_dequantizer_.dequantize(
pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2],
((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1)
: mma_stage) *
Shape::kK,
(warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB);

// Execute the current warp-tile of MMA operations
if constexpr (Detail::kStagedAccumulation) {
warp_mma_(
pipe_state.tmp_accum_,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
pipe_state.tmp_accum_
);
warp_mma_(pipe_state.tmp_accum_,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
pipe_state.tmp_accum_);

if (warp_mma_k == 0) {
plus<FragmentC> plus_accum;
@@ -637,11 +674,10 @@ public:
pipe_state.tmp_accum_.clear();
}
} else {
warp_mma_(
accum,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
accum);
warp_mma_(accum,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
accum);
}

// Except for the last warp-tile, all warp-tiles issue their share of
@@ -654,22 +690,28 @@ public:
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);

if (warp_mma_k == 0) {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(
mma_quant_args, stage);
}
}

// The second-to-last warp-tile also:
// - performs the last warp-tile's share of global->shared fragment copies
// - performs the last warp-tile's share of global->shared fragment
// copies
// - moves to the next global fetch stage
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
// Performs the last warp-tile's share of global->shared fragment copies
if constexpr (Detail::AsyncCopyIterationsPerStageA >= Base::kWarpGemmIterations) {
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
if constexpr (Detail::AsyncCopyIterationsPerStageA >=
Base::kWarpGemmIterations) {
int group_start_iteration_A =
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
}

if constexpr (Detail::AsyncCopyIterationsPerStageB >= Base::kWarpGemmIterations) {
int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
if constexpr (Detail::AsyncCopyIterationsPerStageB >=
Base::kWarpGemmIterations) {
int group_start_iteration_B =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
}

@@ -691,7 +733,8 @@ public:
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args,
gemm_k_iterations == 0);
}
}
}
@@ -700,12 +743,13 @@ public:
/// multiply-accumulate. Assumes prologue has been initiated.
CUTLASS_DEVICE
void gemm_iters(
int gemm_k_iterations, ///< number of threadblock mainloop iterations
FragmentC &accum, ///< [in|out] accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args)
{
int gemm_k_iterations, ///< number of threadblock mainloop iterations
FragmentC &accum, ///< [in|out] accumulator tile
IteratorA
&iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB
&iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args) {
PipeState pipe_state;

// Disable global fetching if done with global fetch iterations
@@ -748,14 +792,13 @@ public:
// Mainloop
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);) {
mac_loop_iter(
pipe_state,
accum,
iterator_A,
iterator_B,
mma_quant_args,
gemm_k_iterations,
stage);
mac_loop_iter(pipe_state,
accum,
iterator_A,
iterator_B,
mma_quant_args,
gemm_k_iterations,
stage);
stage += 1;
}

@@ -764,7 +807,8 @@ public:
accum = plus_accum(accum, pipe_state.tmp_accum_);
}

// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
// Commit and drain all pending and predicated cp.async pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
@@ -772,15 +816,16 @@ public:

/// Prepares the class for another prologue.
CUTLASS_DEVICE
void wind_down()
{
// Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue)
void wind_down() {
// Catch-up the smem-read iterator to the smem-write iterator (so this class can
// be reused for another tile's prologue)

// First, increment remaining warp tiles to get to the next full stage. (Ideally we would
// just decrement one tile, but not all iterators implement --() decrement.)
#pragma unroll
for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// First, increment remaining warp tiles to get to the next full stage. (Ideally
// we would just decrement one tile, but not all iterators implement --()
// decrement.)
#pragma unroll
for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k);
this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k);

@@ -789,22 +834,24 @@ public:
}
smem_read_stage_idx_++;

// Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators)
static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations;
if (smem_read_stage_idx_ > 1)
{
// Then wrap back two full stages (one for the tile advancing we just did,
// and one to catch the write iterators)
static const int kStageIters =
Policy::kPartitionsK * Base::kWarpGemmIterations;
if (smem_read_stage_idx_ > 1) {
this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
} else {
this->warp_tile_iterator_A_.add_tile_offset(
{0, ((Base::kStages - 2) * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset(
{((Base::kStages - 2) * kStageIters), 0});
}
smem_read_stage_idx_ = smem_write_stage_idx_;
}

/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to
/// shared memory.
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
@@ -819,8 +866,8 @@ public:
QuantArguments mma_quant_args,
///< initial value of accumulator
FragmentC const &src_accum) {

// Prologue (start fetching iterations of global fragments into shared memory)
// Prologue (start fetching iterations of global fragments into shared
// memory)
prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations);

// Wait until we have at least one completed global fetch stage
@@ -830,7 +877,8 @@ public:
accum = src_accum;

// Perform the MAC-iterations
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args);
gemm_iters(
gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args);
}
};


@@ -46,9 +46,10 @@ template <
/// Group size for quantization
int GroupSize_>
class Wint2ParamsAccessor {
public:
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");
public:
static_assert(platform::is_same<T, half_t>::value ||
platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");

using ElementType = T;
using Shape = Shape_;
@@ -72,7 +73,7 @@ public:
using ElementLocalScale = typename IteratorLocalScale::Element;
using LayoutLocalScale = typename IteratorLocalScale::Layout;
static_assert(platform::is_same<ElementLocalScale, uint4b_t>::value,
"local_scale's type must be uint4b_t.");
"local_scale's type must be uint4b_t.");

using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element;
using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout;
@@ -80,24 +81,32 @@ public:
/// 2 uint4b_t values are stored in a single uint8_t
constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK;
constexpr static int kLocalScaleRows =
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * sizeof_bits<ElementLocalScale>::value / 8 / Shape::kN;
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn *
sizeof_bits<ElementLocalScale>::value / 8 / Shape::kN;
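// Sketch: the packing that kStagesPerLocalScaleLoad accounts for. One uint8_t
// stores two uint4b_t local scales, i.e. two consecutive groups of kGroupSize
// rows along K, so one byte covers 2 * kGroupSize rows and only needs
// reloading every 2 * kGroupSize / Shape::kK threadblock stages. Illustrative
// pack/unpack helpers (not part of this class):
__device__ inline uint8_t pack_two_uint4_sketch(int lo, int hi) {
  return static_cast<uint8_t>((lo & 0xF) | ((hi & 0xF) << 4));
}
__device__ inline int unpack_uint4_sketch(uint8_t packed, int idx) {
  return (packed >> (idx * 4)) & 0xF;  // idx 0 -> low nibble, idx 1 -> high
}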

using SmemElement = uint8_t;
constexpr static int kSmemRows =
kLocalScaleRows * kStages + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2;
constexpr static int kSmemRows = kLocalScaleRows * kStages +
sizeof(ElementSuperScale) +
sizeof(ElementCodeScaleZp) * 2;
constexpr static int kSmemColumns = Shape::kN;

using QuantParamsShape = MatrixShape<kSmemRows, kSmemColumns>;

constexpr static int kSuperScaleSmemOffset = 0;
constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale);
constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
constexpr static int kCodeScaleSmemOffset =
kSmemColumns * sizeof(ElementSuperScale);
constexpr static int kCodeZpSmemOffset =
kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
constexpr static int kLocalScaleSmemOffset =
kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
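// Worked example of the offsets above (illustrative values only): assuming
// kSmemColumns = 128, ElementSuperScale = half_t (2 B) and
// ElementCodeScaleZp = float (4 B), the shared-memory layout is
//   super_scale : bytes [0, 256)     kSuperScaleSmemOffset = 0
//   code_scale  : bytes [256, 768)   kCodeScaleSmemOffset  = 128 * 2
//   code_zp     : bytes [768, 1280)  kCodeZpSmemOffset     = 256 + 128 * 4
//   local_scale : bytes [1280, ...)  kLocalScaleSmemOffset = 768 + 128 * 4
// with the kStages staged copies of local_scale stored back-to-back last.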

/// TensorRef type for loading element from a tensor
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, LayoutSuperScale>;
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, LayoutLocalScale>;
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, LayoutCodeScaleZp>;
using SuperTensorRef =
cutlass::TensorRef<ElementSuperScale, LayoutSuperScale>;
using LocalTensorRef =
cutlass::TensorRef<ElementLocalScale, LayoutLocalScale>;
using CodeTensorRef =
cutlass::TensorRef<ElementCodeScaleZp, LayoutCodeScaleZp>;

struct Arguments {
IteratorSuperScale iterator_super_scale;
@@ -113,14 +122,14 @@ public:
IteratorCodeScaleZp iterator_code_scale,
IteratorCodeScaleZp iterator_code_zp,
int local_scale_pointer_offset)
: iterator_super_scale(iterator_super_scale),
iterator_local_scale(iterator_local_scale),
iterator_code_scale(iterator_code_scale),
iterator_code_zp(iterator_code_zp),
local_scale_pointer_offset(local_scale_pointer_offset) {}
: iterator_super_scale(iterator_super_scale),
iterator_local_scale(iterator_local_scale),
iterator_code_scale(iterator_code_scale),
iterator_code_zp(iterator_code_zp),
local_scale_pointer_offset(local_scale_pointer_offset) {}
};

private:
private:
//
// Data members
//
@@ -128,13 +137,17 @@ private:
/// Begin address of shared memory
uint8_t* smem_pointer_;

/// Iterator to write threadblock-scoped tile of super scale operand to shared memory
/// Iterator to write threadblock-scoped tile of super scale operand to shared
/// memory
SmemIteratorSuperScale smem_iterator_super_scale_;
/// Iterator to write threadblock-scoped tile of local scale operand to shared memory
/// Iterator to write threadblock-scoped tile of local scale operand to shared
/// memory
SmemIteratorLocalScale smem_iterator_local_scale_;
/// Iterator to write threadblock-scoped tile of code scale operand to shared memory
/// Iterator to write threadblock-scoped tile of code scale operand to shared
/// memory
SmemIteratorCodeScaleZp smem_iterator_code_scale_;
/// Iterator to write threadblock-scoped tile of code zp operand to shared memory
/// Iterator to write threadblock-scoped tile of code zp operand to shared
/// memory
SmemIteratorCodeScaleZp smem_iterator_code_zp_;

/// Shared memory write stage index
@@ -145,25 +158,29 @@ private:

CUTLASS_DEVICE
ElementSuperScale* get_super_scale_smem_ptr() {
return reinterpret_cast<ElementSuperScale*>(smem_pointer_ + kSuperScaleSmemOffset);
return reinterpret_cast<ElementSuperScale*>(smem_pointer_ +
kSuperScaleSmemOffset);
}

CUTLASS_DEVICE
ElementLocalScale* get_local_scale_smem_ptr() {
return reinterpret_cast<ElementLocalScale*>(smem_pointer_ + kLocalScaleSmemOffset);
return reinterpret_cast<ElementLocalScale*>(smem_pointer_ +
kLocalScaleSmemOffset);
}

CUTLASS_DEVICE
ElementCodeScaleZp* get_code_scale_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeScaleSmemOffset);
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ +
kCodeScaleSmemOffset);
}

CUTLASS_DEVICE
ElementCodeScaleZp* get_code_zp_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeZpSmemOffset);
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ +
kCodeZpSmemOffset);
}

public:
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2ParamsAccessor(
@@ -175,55 +192,74 @@ public:
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: smem_pointer_(smem_pointer),
smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn),
get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx),
smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn),
get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx),
smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0) {}
: smem_pointer_(smem_pointer),
smem_iterator_super_scale_(
LayoutSuperScale(IteratorSuperScale::Shape::kColumn),
get_super_scale_smem_ptr(),
{1, IteratorSuperScale::Shape::kColumn},
thread_idx),
smem_iterator_local_scale_(
LayoutLocalScale(IteratorLocalScale::Shape::kColumn),
get_local_scale_smem_ptr(),
{1, IteratorLocalScale::Shape::kColumn},
thread_idx),
smem_iterator_code_scale_(
LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_scale_smem_ptr(),
{1, IteratorCodeScaleZp::Shape::kColumn},
thread_idx),
smem_iterator_code_zp_(
LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_zp_smem_ptr(),
{1, IteratorCodeScaleZp::Shape::kColumn},
thread_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0) {}

CUTLASS_DEVICE
SuperTensorRef super_scale_ref() {
return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)};
return {get_super_scale_smem_ptr(),
LayoutSuperScale(IteratorSuperScale::Shape::kColumn)};
}

CUTLASS_DEVICE
LocalTensorRef local_scale_ref() {
return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)};
return {get_local_scale_smem_ptr(),
LayoutLocalScale(IteratorLocalScale::Shape::kColumn)};
}

CUTLASS_DEVICE
CodeTensorRef code_scale_ref() {
return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
return {get_code_scale_smem_ptr(),
LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}

CUTLASS_DEVICE
CodeTensorRef code_zp_ref() {
return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
return {get_code_zp_smem_ptr(),
LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}

template <bool IsFirstStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) {
CUTLASS_DEVICE void copy_tiles_and_advance_per_stage(Arguments& quant_args,
int stage) {
if constexpr (IsFirstStage) {
// Load channel-wise super_scale to shared memory, which only needs to be done once.
// Load channel-wise super_scale to shared memory, which only needs to be
// done once.
typename IteratorSuperScale::Fragment tb_frag_super_scale;
tb_frag_super_scale.clear();
quant_args.iterator_super_scale.load(tb_frag_super_scale);
this->smem_iterator_super_scale_.store(tb_frag_super_scale);

// Load channel-wise code_scale to shared memory, which only needs to be done once.
// Load channel-wise code_scale to shared memory, which only needs to be
// done once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_scale;
tb_frag_code_scale.clear();
quant_args.iterator_code_scale.load(tb_frag_code_scale);
this->smem_iterator_code_scale_.store(tb_frag_code_scale);

// Load channel-wise code_zp to shared memory, which only needs to be done once.
// Load channel-wise code_zp to shared memory, which only needs to be done
// once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_zp;
tb_frag_code_zp.clear();
quant_args.iterator_code_zp.load(tb_frag_code_zp);
@@ -231,20 +267,24 @@ public:
}

if ((stage % kStagesPerLocalScaleLoad) == 0) {
// Load group-wise local_scale to shared memory, which only needs to be done at each stage.
// Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages.
// Load group-wise local_scale to shared memory, which only needs to be
// done at each stage. Since 2 uint4b_t values of local_scale are saved in
// a single uint8_t, local_scale needs to be loaded once every two stages.
using AccessType = typename IteratorLocalScale::AccessType;
cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits<AccessType>::value == 128)
? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always;
cutlass::arch::CacheOperation::Kind const kCacheOp =
(sizeof_bits<AccessType>::value == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;

quant_args.iterator_local_scale.set_iteration_index(0);
this->smem_iterator_local_scale_.set_iteration_index(0);

// Async Copy for local_scale
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) {
AccessType *dst_ptr =
reinterpret_cast<AccessType *>(this->smem_iterator_local_scale_.get());
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount;
++j) {
AccessType* dst_ptr = reinterpret_cast<AccessType*>(
this->smem_iterator_local_scale_.get());

CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) {
@@ -255,8 +295,8 @@ public:
IteratorLocalScale::ThreadMap::kElementsPerAccess /
IteratorLocalScale::kAccessesPerVector / 8;

cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
}
++quant_args.iterator_local_scale;
}
@@ -265,13 +305,15 @@ public:
}

CUTLASS_DEVICE
void advance_smem_write_stage(Arguments &quant_args) {
void advance_smem_write_stage(Arguments& quant_args) {
if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
// Advance global iterators
quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset);
quant_args.iterator_local_scale.add_pointer_offset(
quant_args.local_scale_pointer_offset);

// Advance shared iterators
int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
int smem_pointer_offset =
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset);
}

@@ -280,7 +322,8 @@ public:

if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
int pointer_offset = -kStages * IteratorLocalScale::Shape::kRow *
IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(pointer_offset);
smem_write_stage_idx_ = 0;
}
@@ -298,14 +341,14 @@ public:

if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
smem_read_stage_idx_ = 0;
byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns;
byte_offset = -(kStages - 1) * kLocalScaleRows * kSmemColumns;
}

return byte_offset;
}

CUTLASS_DEVICE
void clear_mask(Arguments &quant_args, bool cond) {
void clear_mask(Arguments& quant_args, bool cond) {
quant_args.iterator_local_scale.clear_mask(cond);
}
};

@@ -29,18 +29,27 @@ namespace gemm {
namespace threadblock {

template <typename T, int N>
using UnzipArray = cutlass::AlignedArray<T, N, (N * cutlass::sizeof_bits<T>::value / 8)>;
using UnzipArray =
cutlass::AlignedArray<T, N, (N * cutlass::sizeof_bits<T>::value / 8)>;

template <typename T, WintQuantMethod QuantMethod, int TileRows,
int TileColumns, int NumThreads = 128>
template <typename T,
WintQuantMethod QuantMethod,
int TileRows,
int TileColumns,
int NumThreads = 128>
struct UnzipAndDequantFunctor {
__device__ void operator()(const T *in_ptr, const T *super_scale_ptr,
T *out_ptr, const int64_t in_stride) {}
__device__ void operator()(const T *in_ptr,
const T *super_scale_ptr,
T *out_ptr,
const int64_t in_stride) {}
};

template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
TileColumns, NumThreads> {
struct UnzipAndDequantFunctor<T,
WintQuantMethod::kWeightOnlyInt25,
TileRows,
TileColumns,
NumThreads> {
using ZippedT = uint16_t;
using ScaleComputeT = float;

@@ -52,7 +61,8 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
static constexpr int32_t kLocalScaleMask = 0x1FFF;
static constexpr int32_t kBZP = 4;

__device__ inline T Compute(int32_t zipped_value, int32_t shift_bit,
__device__ inline T Compute(int32_t zipped_value,
int32_t shift_bit,
ScaleComputeT scale) {
int32_t shifted_value = (zipped_value >> shift_bit) & kWeightMask;
int32_t value = shifted_value - kBZP;
@@ -61,8 +71,10 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
return static_cast<T>(scaled_value);
}
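// Sketch: the decode in Compute() above, in scalar form. One weight field is
// shifted out of the packed word, masked, re-centered by the zero-point kBZP,
// and scaled. The 3-bit field width (mask 0x7) and the values below are
// illustrative assumptions; kWeightMask itself is defined in an elided hunk.
//
//   int32_t w  = (zipped >> 6) & 0x7;   // one field, in [0, 7]
//   float   dq = 0.5f * float(w - 4);   // kBZP = 4 -> centered, then scaled
//
// The seven entries of shift_bits below slice seven such fields out of each
// 16-bit word, which is what gives the fractional "int2.5" bits per weight.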

__device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr,
T *out_ptr, const int64_t in_stride) {
__device__ void operator()(const uint16_t *in_ptr,
const T *super_scale_ptr,
T *out_ptr,
const int64_t in_stride) {
int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0};

int tid = threadIdx.x;
@@ -111,8 +123,11 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
};

template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
TileColumns, NumThreads> {
struct UnzipAndDequantFunctor<T,
WintQuantMethod::kWeightOnlyInt2,
TileRows,
TileColumns,
NumThreads> {
using ZippedT = uint8_t;
using ScaleComputeT = float;

@@ -129,9 +144,11 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
// super_scale [N] T

// code_scale, code_zp and super_scale
static constexpr int32_t kColumnWiseSmemBytes = (2 * sizeof(float) + sizeof(T)) * TileColumns;
static constexpr int32_t kColumnWiseSmemBytes =
(2 * sizeof(float) + sizeof(T)) * TileColumns;
// zipped weights and local_scale
static constexpr int32_t kZippedSmemBytes = (TileRows / 4 + (TileRows + 127) / 128) * TileColumns;
static constexpr int32_t kZippedSmemBytes =
(TileRows / 4 + (TileRows + 127) / 128) * TileColumns;

struct Arguments {
uint8_t *weight_ptr;
@@ -140,14 +157,20 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
float *code_zp_ptr;
T *super_scale_ptr;

__device__ Arguments() : weight_ptr(nullptr), local_scale_ptr(nullptr), code_scale_ptr(nullptr), code_zp_ptr(nullptr), super_scale_ptr(nullptr) {}
__device__ Arguments()
: weight_ptr(nullptr),
local_scale_ptr(nullptr),
code_scale_ptr(nullptr),
code_zp_ptr(nullptr),
super_scale_ptr(nullptr) {}

__device__ explicit Arguments(uint8_t *smem_ptr) {
SetZippedPtrs(smem_ptr);
SetColumnWisePtrs(smem_ptr + kZippedSmemBytes);
}

__device__ Arguments(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr) {
__device__ Arguments(uint8_t *zipped_smem_ptr,
uint8_t *column_wise_smem_ptr) {
SetZippedPtrs(zipped_smem_ptr);
SetColumnWisePtrs(column_wise_smem_ptr);
}
@@ -159,15 +182,21 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,

__device__ void SetColumnWisePtrs(uint8_t *column_wise_smem_ptr) {
code_scale_ptr = reinterpret_cast<float *>(column_wise_smem_ptr);
code_zp_ptr = reinterpret_cast<float *>(column_wise_smem_ptr + sizeof(float) * TileColumns);
super_scale_ptr = reinterpret_cast<T *>(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns);
code_zp_ptr = reinterpret_cast<float *>(column_wise_smem_ptr +
sizeof(float) * TileColumns);
super_scale_ptr = reinterpret_cast<T *>(column_wise_smem_ptr +
2 * sizeof(float) * TileColumns);
}
};

__device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr,
const float *g_code_scale_ptr, const float *g_code_zp_ptr,
__device__ void Load(const uint8_t *g_weight_ptr,
const uint8_t *g_local_scale_ptr,
const float *g_code_scale_ptr,
const float *g_code_zp_ptr,
const T *g_super_scale_ptr,
Arguments *args, const int64_t in_stride, bool need_preload) {
Arguments *args,
const int64_t in_stride,
bool need_preload) {
int tid = threadIdx.x;

#pragma unroll
@@ -186,7 +215,8 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
#pragma unroll
for (int ls_row_id = 0; ls_row_id < TileRows / 128; ++ls_row_id) {
int local_scale_offset = ls_row_id * in_stride + col;
args->local_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset];
args->local_scale_ptr[ls_row_id * TileColumns + col] =
g_local_scale_ptr[local_scale_offset];
}

#pragma unroll
@@ -205,10 +235,12 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
const float *g_code_scale_ptr,
const float *g_code_zp_ptr,
const T *g_super_scale_ptr,
Arguments *args, const int64_t in_stride, bool need_preload) {
Arguments *args,
const int64_t in_stride,
bool need_preload) {
int tid = threadIdx.x;

constexpr int kBytesPerThread = 16; // 16B per thread
constexpr int kBytesPerThread = 16;  // 16B per thread

constexpr int weight_size = TileRows / 4 * TileColumns;
constexpr int local_scale_size = (TileRows + 127) / 128 * TileColumns;
@@ -216,87 +248,130 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
constexpr int code_zp_size = sizeof(float) * TileColumns;
constexpr int super_scale_size = sizeof(T) * TileColumns;

constexpr int total_size = weight_size + local_scale_size + code_scale_size + code_zp_size + super_scale_size;
constexpr int total_size = weight_size + local_scale_size +
code_scale_size + code_zp_size +
super_scale_size;
constexpr int total_tasks = total_size / kBytesPerThread;

constexpr int cur_num_threads = total_tasks / ((total_tasks + NumThreads - 1) / NumThreads);
constexpr int cur_num_threads =
total_tasks / ((total_tasks + NumThreads - 1) / NumThreads);

constexpr int weight_threads = weight_size * cur_num_threads / total_size;
constexpr int local_scale_threads = local_scale_size * cur_num_threads / total_size;
constexpr int code_scale_threads = code_scale_size * cur_num_threads / total_size;
constexpr int local_scale_threads =
local_scale_size * cur_num_threads / total_size;
constexpr int code_scale_threads =
code_scale_size * cur_num_threads / total_size;
constexpr int code_zp_threads = code_zp_size * cur_num_threads / total_size;
constexpr int super_scale_threads = super_scale_size * cur_num_threads / total_size;
constexpr int super_scale_threads =
super_scale_size * cur_num_threads / total_size;

static_assert(TileColumns % weight_threads == 0,
"TileColumns must be divisible by weight_threads to ensure correct thread mapping.");
"TileColumns must be divisible by weight_threads to ensure "
"correct thread mapping.");

static_assert(TileColumns % local_scale_threads == 0,
"TileColumns must be divisible by local_scale_threads to ensure correct thread mapping.");
"TileColumns must be divisible by local_scale_threads to "
"ensure correct thread mapping.");

if (tid < weight_threads) {
constexpr int weight_per_thread_size = weight_size / weight_threads;
constexpr int kIterations = (weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int kIterations =
(weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread);
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->weight_ptr + z_offset, g_weight_ptr + g_offset, true);
int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread);
int g_offset =
z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->weight_ptr + z_offset, g_weight_ptr + g_offset, true);
}
} else if (tid < weight_threads + local_scale_threads) {
constexpr int start_thread_id = weight_threads;
constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads;
constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int local_scale_per_thread_size =
local_scale_size / local_scale_threads;
constexpr int kIterations =
(local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread;
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true);
int z_offset = (tid - start_thread_id) * local_scale_per_thread_size +
i * kBytesPerThread;
int g_offset =
z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->local_scale_ptr + z_offset,
g_local_scale_ptr + g_offset,
true);
}
} else if (need_preload) {
if (tid < weight_threads + local_scale_threads + code_scale_threads) {
constexpr int start_thread_id = weight_threads + local_scale_threads;
constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads;
constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int code_scale_per_thread_size =
code_scale_size / code_scale_threads;
constexpr int kIterations =
(code_scale_per_thread_size + kBytesPerThread - 1) /
kBytesPerThread;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
int offset = ((tid - start_thread_id) * code_scale_per_thread_size +
i * kBytesPerThread) /
sizeof(float);
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->code_scale_ptr + offset, g_code_scale_ptr + offset, true);
}
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) {
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads;
} else if (tid < weight_threads + local_scale_threads +
code_scale_threads + code_zp_threads) {
constexpr int start_thread_id =
weight_threads + local_scale_threads + code_scale_threads;
constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads;
constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int kIterations =
(code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
int offset = ((tid - start_thread_id) * code_zp_per_thread_size +
i * kBytesPerThread) /
sizeof(float);
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->code_zp_ptr + offset, g_code_zp_ptr + offset, true);
}
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) {
} else if (tid < weight_threads + local_scale_threads +
code_scale_threads + code_zp_threads +
super_scale_threads) {
if (g_super_scale_ptr) {
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads;
constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads;
constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int start_thread_id = weight_threads + local_scale_threads +
code_scale_threads + code_zp_threads;
constexpr int super_scale_per_thread_size =
super_scale_size / super_scale_threads;
constexpr int kIterations =
(super_scale_per_thread_size + kBytesPerThread - 1) /
kBytesPerThread;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->super_scale_ptr + offset, g_super_scale_ptr + offset, true);
int offset =
((tid - start_thread_id) * super_scale_per_thread_size +
i * kBytesPerThread) /
sizeof(T);
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->super_scale_ptr + offset,
g_super_scale_ptr + offset,
true);
}
}
}
}
}

__device__ void Compute(const Arguments &args, T *out_ptr,
__device__ void Compute(const Arguments &args,
T *out_ptr,
const int64_t block_start_row) {
int32_t shift_bits[4] = {9, 6, 3, 0};

@@ -333,9 +408,9 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,

#pragma unroll
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
int32_t decode_value =
static_cast<int32_t>(floor(zipped_value[zipped_row] * code_scale + code_zp +
static_cast<ScaleComputeT>(0.5)));
int32_t decode_value = static_cast<int32_t>(
floor(zipped_value[zipped_row] * code_scale + code_zp +
static_cast<ScaleComputeT>(0.5)));

int row = group_id * 64 + zipped_row * 4;

@@ -355,14 +430,17 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
__syncthreads();
}
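// Sketch: the Int2 decode above in scalar form. Each stored byte is first
// mapped through the per-column affine codebook (code_scale, code_zp) to an
// integer whose 3-bit fields carry four weights; the lowest field is the last
// row of the group, and successive 3-bit shifts expose the earlier rows.
// kWeightMask and kBZP come from elided hunks; 0x7 and 4 below are
// illustrative assumptions.
//
//   int32_t decoded = static_cast<int32_t>(
//       floor(float(byte) * code_scale + code_zp + 0.5f));
//   float w3 = float((decoded & 0x7) - 4) * scale;  decoded >>= 3;
//   float w2 = float((decoded & 0x7) - 4) * scale;  decoded >>= 3;
//   float w1 = float((decoded & 0x7) - 4) * scale;  decoded >>= 3;
//   float w0 = float((decoded & 0x7) - 4) * scale;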

__device__ void ComputeVectorized(const Arguments &args, T *out_ptr,
__device__ void ComputeVectorized(const Arguments &args,
T *out_ptr,
const int64_t block_start_row) {
constexpr int kNumWeightsPerThread = TileRows * TileColumns / (4 * NumThreads);
constexpr int kNumWeightsPerThread =
TileRows * TileColumns / (4 * NumThreads);
constexpr int N = (kNumWeightsPerThread >= 32) ? 4 : 2;
constexpr int RowStride = NumThreads * N / TileColumns;
constexpr int kNumIters = kNumWeightsPerThread / N;

static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns.");
static_assert(N * NumThreads >= TileColumns,
"N * NumThreads should be no less than TileColumns.");

constexpr ScaleComputeT decode_value_zp = static_cast<ScaleComputeT>(0.5);

@@ -373,19 +451,22 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
static_assert(TileRows <= 128, "TileRows is expected to be no more than 128.");

UnzipArray<uint8_t, N> local_scales =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.local_scale_ptr + begin_col_id);
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.local_scale_ptr +
begin_col_id);

UnzipArray<uint8_t, N> zipped_values[2];
int zipped_offset = begin_row_id * TileColumns + begin_col_id;
zipped_values[0] =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
zipped_values[0] = *reinterpret_cast<const UnzipArray<uint8_t, N> *>(
args.weight_ptr + zipped_offset);

UnzipArray<T, N> super_scales =
*reinterpret_cast<const UnzipArray<T, N> *>(args.super_scale_ptr + begin_col_id);
UnzipArray<T, N> super_scales = *reinterpret_cast<const UnzipArray<T, N> *>(
args.super_scale_ptr + begin_col_id);
UnzipArray<float, N> code_scales =
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_scale_ptr + begin_col_id);
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_scale_ptr +
begin_col_id);
UnzipArray<float, N> code_zps =
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_zp_ptr + begin_col_id);
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_zp_ptr +
begin_col_id);

// special for TileRows = 64
int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4;
@@ -394,9 +475,10 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
#pragma unroll
for (int i = 0; i < N; ++i) {
int32_t shifted_local_scale =
(static_cast<int32_t>(local_scales[i]) >> local_scale_shift) & kLocalScaleMask;
scales[i] =
static_cast<ScaleComputeT>(shifted_local_scale) * static_cast<ScaleComputeT>(super_scales[i]);
(static_cast<int32_t>(local_scales[i]) >> local_scale_shift) &
kLocalScaleMask;
scales[i] = static_cast<ScaleComputeT>(shifted_local_scale) *
static_cast<ScaleComputeT>(super_scales[i]);
}

#pragma unroll
@@ -405,26 +487,33 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
int row = zipped_row * 4;

if (iter_id < kNumIters - 1) {
int zipped_offset = (zipped_row + RowStride) * TileColumns + begin_col_id;
int zipped_offset =
(zipped_row + RowStride) * TileColumns + begin_col_id;
zipped_values[(iter_id + 1) & 1] =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr +
zipped_offset);
}

UnzipArray<T, N> outs[4];

#pragma unroll
for (int i = 0; i < N; ++i) {
int32_t decode_value =
static_cast<int32_t>(floor(static_cast<ScaleComputeT>(zipped_values[iter_id & 1][i]) * code_scales[i]
+ code_zps[i] + decode_value_zp));
int32_t decode_value = static_cast<int32_t>(
floor(static_cast<ScaleComputeT>(zipped_values[iter_id & 1][i]) *
code_scales[i] +
code_zps[i] + decode_value_zp));

ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
ScaleComputeT value_3 =
static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
ScaleComputeT value_2 =
static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
ScaleComputeT value_1 =
static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
ScaleComputeT value_0 =
static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
outs[0][i] = static_cast<T>(scales[i] * value_0);
outs[1][i] = static_cast<T>(scales[i] * value_1);
outs[2][i] = static_cast<T>(scales[i] * value_2);