/* * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /** * @file epilogue_helpers.h * * This file includes types for the epilogues. The empty structs exist so we can * signal to template code the type of epilogue we want to run, and let the * underlying code specify the details such as element types, accumulator type * and elements per vector access. * */ #pragma once #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/thread/linear_combination_generic.h" #include "cutlass/epilogue/thread/linear_combination_relu.h" #include "cutlass/epilogue/thread/linear_combination_silu.h" #include "cutlass_extensions/epilogue/thread/fused_activations.h" // #include "cutlass/epilogue/fusion/operations.hpp" namespace cutlass_extensions { struct EpilogueOpBiasSilu {}; struct EpilogueOpBiasReLU {}; struct EpilogueOpBiasFtGelu {}; struct EpilogueOpBias {}; struct EpilogueOpDefaultSilu {}; struct EpilogueOpDefaultReLU {}; struct EpilogueOpDefaultFtGelu {}; struct EpilogueOpDefault {}; template struct Epilogue { static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); }; constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationSilu; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationRelu; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationGeneric< cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombination; }; constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationSilu; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationRelu; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationGeneric< cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>; }; template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombination; }; } // namespace cutlass_extensions