[MetaxGPU] Support FastDeploy on metax gpu (#3241)

* [MetaxGPU] Support FastDeploy on metax gpu

* Update metax_worker.py

1. change worker log;
2. remove custom allreduce, adapt it later;
3. remove cuda graph;

* Update __init__.py

1. remove metax's key word comment

* Update __init__.py

1. remove metax's key word comment;
2. add fused_moe_kernel_paddle import

---------

Co-authored-by: yongqiangma <xing.wo@163.com>
Authored by Kane2011 on 2025-08-13 11:11:54 +08:00; committed by GitHub.
parent ed6bff215a
commit b4fef2cf29
29 changed files with 3224 additions and 11 deletions
@@ -0,0 +1,19 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .triton_moe_kernels import fused_moe_kernel_paddle

__all__ = [
    "fused_moe_kernel_paddle",
]
@@ -0,0 +1,276 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn

import fastdeploy
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.utils import ceil_div

from .triton_moe_kernels import fused_moe_kernel_paddle

class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
    """
    Use Triton Group GEMM to compute Fused MoE.
    """

    def __init__(self, quant_config=None):
        """
        Triton Group GEMM to compute Fused MoE.
        """
        self.quant_config = quant_config
        self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
        self.added_scale_attrs = [
            "up_gate_proj_weight_scale",
            "down_proj_weight_scale",
        ]

    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
        """process_prequanted_weights"""
        pass

    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Triton MoE create weight process.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        assert len(up_gate_proj_weights) == layer.num_local_experts
        assert len(down_proj_weights) == layer.num_local_experts
        algo = None
        if layer.quant_method.quant_config:
            algo = layer.quant_method.quant_config.name()

        assert up_gate_proj_weights[0].shape == [
            layer.hidden_size,
            layer.moe_intermediate_size * 2,
        ]
        assert down_proj_weights[0].shape == [
            layer.moe_intermediate_size,
            layer.hidden_size,
        ]

        up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
        down_proj_tensor = paddle.stack(down_proj_weights, axis=0)

        if algo == "wint8":
            max_bound = 127
        elif algo == "wint4":
            max_bound = 7

        for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
            weight_name = self.added_weight_attrs[idx]
            scale_name = self.added_scale_attrs[idx]

            # One scale per (expert, output channel): reduce |w| over the K axis.
            quanted_weight_scale = weight_tensor.abs().max(axis=1)
            if self.quant_config is not None:
                # Symmetric round-to-nearest weight-only quantization.
                quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
                quanted_weight = paddle.round(quanted_weight).astype("int8")
                quanted_weight_scale = quanted_weight_scale / max_bound
            else:
                # No quant config: keep the weights unquantized.
                quanted_weight = weight_tensor

            setattr(
                layer,
                weight_name,
                layer.create_parameter(
                    shape=quanted_weight.shape,
                    dtype=quanted_weight.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )
            getattr(layer, weight_name).set_value(quanted_weight)

            setattr(
                layer,
                scale_name,
                layer.create_parameter(
                    shape=quanted_weight_scale.shape,
                    dtype=quanted_weight_scale.dtype,
                ),
            )
            getattr(layer, scale_name).set_value(quanted_weight_scale)

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.
        """
        token_num = x.shape[0]
        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size
        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
            gate_out,
            layer.gate_correction_bias,
            top_k,
            True,  # apply_norm_weight
            False,
        )
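        # topk_ids / topk_weights: [token_num, top_k] expert indices and
        # (normalized) routing weights for each token.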
        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,
        )

        if self.quant_config is not None:
            config = {
                "BLOCK_SIZE_M": 32,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 1,
            }
        else:
            config = {
                "BLOCK_SIZE_M": 32,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 1,
            }
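        # Fixed tile sizes (no autotuning in this port); the quantized path
        # uses larger N/K tiles.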
        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
            topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
        )
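        # sorted_token_ids: the token_num * top_k token slots grouped by expert
        # and padded per expert to a multiple of BLOCK_SIZE_M;
        # expert_ids: the expert index for each row-block;
        # num_tokens_post_padded: the padded row count actually in use.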
        max_possible_num_post_padded = sorted_token_ids.shape[0]
        grid = (
            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
            * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
        )
        fused_moe_kernel_paddle[grid](
            x,
            layer.up_gate_proj_weight,
            up_gate_proj_out,
            None,
            layer.up_gate_proj_weight_scale,
            None,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_possible_num_post_padded,
            token_num * top_k,
            N=moe_intermediate_size * 2,
            K=hidden_size,
            stride_am=x.strides[0],
            stride_ak=x.strides[1],
            stride_be=layer.up_gate_proj_weight.strides[0],
            stride_bk=layer.up_gate_proj_weight.strides[1],
            stride_bn=layer.up_gate_proj_weight.strides[2],
            stride_cm=up_gate_proj_out.strides[0],
            stride_cn=up_gate_proj_out.strides[1],
            # Quantization scale strides (-1 marks entries unused by the
            # per-channel int8 path).
            stride_asm=-1,
            stride_ask=-1,
            stride_bse=layer.up_gate_proj_weight_scale.strides[0],
            stride_bsk=-1,
            stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
            group_n=-1,
            group_k=-1,
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=top_k,
            compute_type_enum=1,
            use_fp8_w8a8=False,
            use_int8_w8a16=True,
            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
        )
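        # SwiGLU halves the last dimension: [*, 2 * inter] -> [*, inter].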
        down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)

        down_proj_out = paddle.empty(
            (token_num * top_k, hidden_size),
            dtype=x.dtype,
        )

        grid = (
            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
            * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
        )
        fused_moe_kernel_paddle[grid](
            down_proj_input,
            layer.down_proj_weight,
            down_proj_out,
            None,
            layer.down_proj_weight_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_possible_num_post_padded,
            token_num * top_k,
            N=hidden_size,
            K=moe_intermediate_size,
            stride_am=down_proj_input.strides[0],
            stride_ak=down_proj_input.strides[1],
            stride_be=layer.down_proj_weight.strides[0],
            stride_bk=layer.down_proj_weight.strides[1],
            stride_bn=layer.down_proj_weight.strides[2],
            stride_cm=down_proj_out.strides[0],
            stride_cn=down_proj_out.strides[1],
            stride_asm=-1,
            stride_ask=-1,
            stride_bse=layer.down_proj_weight_scale.strides[0],
            stride_bsk=-1,
            stride_bsn=layer.down_proj_weight_scale.strides[1],
            group_n=-1,
            group_k=-1,
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,
            top_k=1,
            compute_type_enum=1,
            use_fp8_w8a8=False,
            use_int8_w8a16=True,
            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
        )
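        # Routed weights were already applied in the down projection
        # (MUL_ROUTED_WEIGHT=True, top_k=1), so combining the top_k expert
        # outputs per token is a plain sum.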
        down_proj_out.reshape_([token_num, top_k, hidden_size])
        out = down_proj_out.sum(axis=1)
        return out
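For intuition, here is a minimal, self-contained NumPy sketch of the per-channel symmetric wint8 scheme that `create_weights` applies above (illustrative names and toy shapes, not part of the commit):

import numpy as np

def quantize_wint8(weight):
    # weight: [num_experts, K, N]; one scale per (expert, output channel),
    # mirroring weight_tensor.abs().max(axis=1) above.
    max_bound = 127
    scale = np.abs(weight).max(axis=1)                       # [E, N]
    w_q = np.round(weight / scale[:, None, :] * max_bound).astype(np.int8)
    return w_q, scale / max_bound                            # dequant: w ~= w_q * scale

w = np.random.randn(2, 8, 4).astype(np.float32)              # toy [E, K, N]
w_q, s = quantize_wint8(w)
assert np.allclose(w, w_q.astype(np.float32) * s[:, None, :], atol=float(s.max()))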
@@ -0,0 +1,187 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import triton
import triton.language as tl


@triton.jit
def fused_moe_kernel_paddle(
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    max_possible_num_post_padded,
    num_valid_tokens,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    # Block size for block-wise fp8 quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type_enum: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    even_Ks: tl.constexpr,
):
"""
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
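    # The 1-D program id is swizzled onto (pid_m, pid_n) in groups of
    # GROUP_SIZE_M row-blocks, the standard Triton matmul trick for better
    # L2 reuse of B tiles.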
    assert compute_type_enum == 1
    compute_type = tl.bfloat16  # enum value 1 is the only compute type this port supports

    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return

    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    if use_int8_w8a16:
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)
    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
        else:
            # (Zkk): every expert has one activation scale and weight scale.
            a_scale = tl.load(a_scale_ptr + off_experts)
            b_scale = tl.load(b_scale_ptr + off_experts)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if even_Ks:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None],
                other=0.0,
            )
            b = tl.load(b_ptrs, cache_modifier=".ca", eviction_policy="evict_first")
        else:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                other=0.0,
            )
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            if group_k > 0 and group_n > 0:
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask,
                    mask=token_mask,
                    other=0.0,
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)

    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
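For orientation, a hedged host-side sketch (plain Python, toy shapes) of the launch-grid arithmetic that `apply()` uses for this kernel. The padded-row bound below is an assumption for illustration; the exact output size of `tritonmoe_preprocess` is not shown in this diff:

def ceil_div(a: int, b: int) -> int:
    # Same helper as fastdeploy.utils.ceil_div.
    return (a + b - 1) // b

# Toy shapes (illustrative, not from the commit).
token_num, top_k = 16, 4
moe_intermediate_size = 1024
BLOCK_SIZE_M, BLOCK_SIZE_N = 32, 128

# Assumption: sorted_token_ids.shape[0] is a multiple of BLOCK_SIZE_M that
# upper-bounds the per-expert padded row count.
max_possible_num_post_padded = ceil_div(token_num * top_k, BLOCK_SIZE_M) * BLOCK_SIZE_M

# 1-D launch grid: row-blocks x column-blocks, flattened exactly as in apply();
# the kernel recovers (pid_m, pid_n) from tl.program_id(0) via the swizzle above.
grid = (
    ceil_div(max_possible_num_post_padded, BLOCK_SIZE_M)
    * ceil_div(moe_intermediate_size * 2, BLOCK_SIZE_N),
)
print(grid)  # (32,) for these toy shapes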