mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 09:44:10 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -14,9 +14,11 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import paddle_use_triton_v2
|
||||
|
||||
from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import (
|
||||
paddle_use_triton_v2,
|
||||
)
|
||||
|
||||
|
||||
@paddle_use_triton_v2()
|
||||
@@ -30,7 +32,6 @@ def fused_moe_kernel_paddle(
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
num_tokens_post_padded_ptr,
|
||||
|
||||
# Matrix dimensions
|
||||
max_possible_num_post_padded,
|
||||
num_valid_tokens,
|
||||
@@ -109,16 +110,13 @@ def fused_moe_kernel_paddle(
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
|
||||
offs_k[None, :] * stride_ak)
|
||||
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m)
|
||||
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
|
||||
offs_bn[None, :] * stride_bn)
|
||||
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
if use_int8_w8a16:
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
|
||||
None, :] * stride_bsn
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
@@ -140,19 +138,14 @@ def fused_moe_kernel_paddle(
|
||||
mask=token_mask[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs,
|
||||
cache_modifier=".cv",
|
||||
eviction_policy='evict_first')
|
||||
b = tl.load(b_ptrs, cache_modifier=".cv", eviction_policy="evict_first")
|
||||
else:
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] &
|
||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
|
||||
# We accumulate along the K dimension.
|
||||
if use_int8_w8a16:
|
||||
@@ -161,13 +154,14 @@ def fused_moe_kernel_paddle(
|
||||
if group_k > 0 and group_n > 0:
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
|
||||
mask=token_mask,
|
||||
other=0.0)
|
||||
a_scale = tl.load(
|
||||
a_scale_ptrs + offs_ks * stride_ask,
|
||||
mask=token_mask,
|
||||
other=0.0,
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
|
||||
|
||||
accumulator += tl.dot(a, b) * a_scale[:,
|
||||
None] * b_scale[None, :]
|
||||
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
|
||||
else:
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
else:
|
||||
@@ -177,9 +171,7 @@ def fused_moe_kernel_paddle(
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
||||
mask=token_mask,
|
||||
other=0)
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
if use_int8_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
@@ -192,8 +184,7 @@ def fused_moe_kernel_paddle(
|
||||
accumulator = accumulator.to(compute_type)
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
|
||||
None, :]
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
Reference in New Issue
Block a user