mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
@@ -1,12 +1,12 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
@@ -18,18 +18,20 @@
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing linear combination with a maximum operation used by epilogues.
|
||||
\brief Functor performing linear combination with a maximum operation used by
|
||||
epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@@ -46,60 +48,53 @@
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace thread
|
||||
{
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
__forceinline__ __device__ float copysignf_pos(float a, float b)
|
||||
{
|
||||
float r;
|
||||
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
|
||||
return r;
|
||||
__forceinline__ __device__ float copysignf_pos(float a, float b) {
|
||||
float r;
|
||||
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
|
||||
return r;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float tanh_opt(float x)
|
||||
{
|
||||
__forceinline__ __device__ float tanh_opt(float x) {
|
||||
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
|
||||
float const exp_val = -1.f * fabs(2 * x);
|
||||
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
|
||||
float const exp_val = -1.f * fabs(2 * x);
|
||||
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
|
||||
#else
|
||||
return fast_tanh(x);
|
||||
return fast_tanh(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <>
|
||||
struct GELU_taylor<float>
|
||||
{
|
||||
static bool const kIsHeavy = true;
|
||||
struct GELU_taylor<float> {
|
||||
static bool const kIsHeavy = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& z) const
|
||||
{
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& z) const {
|
||||
float k0 = float(0.7978845608028654);
|
||||
float k1 = float(0.044715);
|
||||
|
||||
float k0 = float(0.7978845608028654);
|
||||
float k1 = float(0.044715);
|
||||
return float(
|
||||
cutlass::constants::half<float>() * z *
|
||||
(cutlass::constants::one<float>() +
|
||||
tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
|
||||
}
|
||||
|
||||
return float(cutlass::constants::half<float>() * z
|
||||
* (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
|
||||
}
|
||||
using Params = LinearCombinationGenericParams<float>;
|
||||
|
||||
using Params = LinearCombinationGenericParams<float>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& scalar, Params const& params_) const
|
||||
{
|
||||
return this->operator()(scalar);
|
||||
}
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& scalar, Params const& params_) const {
|
||||
return this->operator()(scalar);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user