mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
refactor rl get_name_mappings_to_training (#2847)
Deploy GitHub Pages / deploy (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* refactor rl get_name_mappings_to_training * fix tp>1 * change variable name(ffn1->up_gate_proj/ffn2->down_proj) * change variable name(linear_weight->weight/linear_bias->bias) * add rl names mapping for vl * fix ernie 0.3B error * fix develop code * fix
This commit is contained in:
@@ -15,9 +15,10 @@
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle.nn.quant import weight_only_linear
|
||||
from paddle.incubate.nn.functional import swiglu
|
||||
from paddle.nn.quant import weight_only_linear
|
||||
|
||||
|
||||
def group_gemm(
|
||||
@@ -71,31 +72,31 @@ def group_gemm(
|
||||
def iluvatar_moe_expert_ffn(
|
||||
permute_input: paddle.Tensor,
|
||||
tokens_expert_prefix_sum: paddle.Tensor,
|
||||
ffn1_weight: paddle.Tensor,
|
||||
ffn2_weight: paddle.Tensor,
|
||||
ffn1_bias: Optional[paddle.Tensor],
|
||||
ffn1_scale: Optional[paddle.Tensor],
|
||||
ffn2_scale: Optional[paddle.Tensor],
|
||||
ffn2_in_scale: Optional[paddle.Tensor],
|
||||
up_gate_proj_weight: paddle.Tensor,
|
||||
down_proj_weight: paddle.Tensor,
|
||||
up_gate_proj_bias: Optional[paddle.Tensor],
|
||||
up_gate_proj_scale: Optional[paddle.Tensor],
|
||||
down_proj_scale: Optional[paddle.Tensor],
|
||||
down_proj_in_scale: Optional[paddle.Tensor],
|
||||
expert_idx_per_token: Optional[paddle.Tensor],
|
||||
quant_method: str,
|
||||
used_in_ep_low_latency: bool,
|
||||
):
|
||||
assert ffn1_bias is None
|
||||
assert ffn1_scale is not None
|
||||
assert ffn2_scale is not None
|
||||
assert ffn2_in_scale is None
|
||||
assert up_gate_proj_bias is None
|
||||
assert up_gate_proj_scale is not None
|
||||
assert down_proj_scale is not None
|
||||
assert down_proj_in_scale is None
|
||||
assert expert_idx_per_token is None
|
||||
assert quant_method in ("weight_only_int8")
|
||||
assert not used_in_ep_low_latency
|
||||
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
|
||||
ffn1_output = paddle.empty([permute_input.shape[0], ffn1_weight.shape[1]],
|
||||
up_gate_proj_output = paddle.empty([permute_input.shape[0], up_gate_proj_weight.shape[1]],
|
||||
dtype=permute_input.dtype)
|
||||
group_gemm(permute_input, tokens_expert_prefix_sum_cpu, ffn1_weight,
|
||||
ffn1_scale, ffn1_output)
|
||||
act_out = swiglu(ffn1_output)
|
||||
output = paddle.empty([act_out.shape[0], ffn2_weight.shape[1]],
|
||||
group_gemm(permute_input, tokens_expert_prefix_sum_cpu, up_gate_proj_weight,
|
||||
up_gate_proj_scale, up_gate_proj_output)
|
||||
act_out = swiglu(up_gate_proj_output)
|
||||
output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]],
|
||||
dtype=act_out.dtype)
|
||||
group_gemm(act_out, tokens_expert_prefix_sum_cpu, ffn2_weight, ffn2_scale,
|
||||
group_gemm(act_out, tokens_expert_prefix_sum_cpu, down_proj_weight, down_proj_scale,
|
||||
output)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user