refactor rl get_name_mappings_to_training (#2847)

* refactor rl get_name_mappings_to_training

* fix tp>1

* change variable names (ffn1 -> up_gate_proj, ffn2 -> down_proj)

* change variable names (linear_weight -> weight, linear_bias -> bias)

* add rl name mappings for vl

* fix ernie 0.3B error

* fix develop code

* fix
Author: Yuanle Liu
Date: 2025-07-15 22:31:42 +08:00
Committed by: GitHub
Parent: e7bcbbab52
Commit: 61b3997b85
47 changed files with 1591 additions and 1629 deletions
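
For context, the renames in this commit boil down to a key-substitution table between the old inference-side parameter names and the training-side conventions. A minimal sketch of such a mapping follows; the `rename_key` helper and the sample key are illustrative assumptions, not the actual `get_name_mappings_to_training` implementation:

```python
# Illustrative sketch only -- the real mapping lives in FastDeploy's
# get_name_mappings_to_training; this table and helper are assumptions.
OLD_TO_NEW = {
    "ffn1": "up_gate_proj",    # fused gate + up projection of the expert FFN
    "ffn2": "down_proj",       # down projection of the expert FFN
    "linear_weight": "weight",
    "linear_bias": "bias",
}

def rename_key(name: str) -> str:
    """Rewrite one inference-side parameter name to the training-side name."""
    for old, new in OLD_TO_NEW.items():
        name = name.replace(old, new)
    return name

# Hypothetical example key:
print(rename_key("moe.experts.0.ffn1.linear_weight"))
# -> moe.experts.0.up_gate_proj.weight
```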
@@ -15,9 +15,10 @@
 """
 from typing import Optional
 
 import paddle
-from paddle.nn.quant import weight_only_linear
 from paddle.incubate.nn.functional import swiglu
+from paddle.nn.quant import weight_only_linear
+
 
 
 def group_gemm(
@@ -71,31 +72,31 @@ def group_gemm(
 def iluvatar_moe_expert_ffn(
     permute_input: paddle.Tensor,
     tokens_expert_prefix_sum: paddle.Tensor,
-    ffn1_weight: paddle.Tensor,
-    ffn2_weight: paddle.Tensor,
-    ffn1_bias: Optional[paddle.Tensor],
-    ffn1_scale: Optional[paddle.Tensor],
-    ffn2_scale: Optional[paddle.Tensor],
-    ffn2_in_scale: Optional[paddle.Tensor],
+    up_gate_proj_weight: paddle.Tensor,
+    down_proj_weight: paddle.Tensor,
+    up_gate_proj_bias: Optional[paddle.Tensor],
+    up_gate_proj_scale: Optional[paddle.Tensor],
+    down_proj_scale: Optional[paddle.Tensor],
+    down_proj_in_scale: Optional[paddle.Tensor],
     expert_idx_per_token: Optional[paddle.Tensor],
     quant_method: str,
     used_in_ep_low_latency: bool,
 ):
-    assert ffn1_bias is None
-    assert ffn1_scale is not None
-    assert ffn2_scale is not None
-    assert ffn2_in_scale is None
+    assert up_gate_proj_bias is None
+    assert up_gate_proj_scale is not None
+    assert down_proj_scale is not None
+    assert down_proj_in_scale is None
     assert expert_idx_per_token is None
     assert quant_method in ("weight_only_int8")
     assert not used_in_ep_low_latency
     tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
-    ffn1_output = paddle.empty([permute_input.shape[0], ffn1_weight.shape[1]],
+    up_gate_proj_output = paddle.empty([permute_input.shape[0], up_gate_proj_weight.shape[1]],
                                dtype=permute_input.dtype)
-    group_gemm(permute_input, tokens_expert_prefix_sum_cpu, ffn1_weight,
-               ffn1_scale, ffn1_output)
-    act_out = swiglu(ffn1_output)
-    output = paddle.empty([act_out.shape[0], ffn2_weight.shape[1]],
+    group_gemm(permute_input, tokens_expert_prefix_sum_cpu, up_gate_proj_weight,
+               up_gate_proj_scale, up_gate_proj_output)
+    act_out = swiglu(up_gate_proj_output)
+    output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]],
                           dtype=act_out.dtype)
-    group_gemm(act_out, tokens_expert_prefix_sum_cpu, ffn2_weight, ffn2_scale,
+    group_gemm(act_out, tokens_expert_prefix_sum_cpu, down_proj_weight, down_proj_scale,
                output)
     return output
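
For readers unfamiliar with the kernel, the renamed arguments map onto the standard two-matmul SwiGLU expert FFN: a fused gate+up projection (`up_gate_proj`), a swiglu activation, then a down projection (`down_proj`). Below is a dequantized, pure-Paddle reference of the same computation; the shapes and the explicit per-expert loop are illustrative assumptions standing in for the fused weight-only `group_gemm` path above:

```python
import paddle
from paddle.incubate.nn.functional import swiglu

def moe_expert_ffn_reference(permute_input, tokens_expert_prefix_sum,
                             up_gate_proj_weight, down_proj_weight):
    """Unquantized reference: per-expert GEMM -> swiglu -> GEMM.

    Assumed shapes (illustrative, not the kernel's contract):
      permute_input            [total_tokens, hidden], tokens sorted by expert
      tokens_expert_prefix_sum [num_experts + 1], starting at 0
      up_gate_proj_weight      [num_experts, hidden, 2 * intermediate]
      down_proj_weight         [num_experts, intermediate, hidden]
    """
    bounds = tokens_expert_prefix_sum.tolist()
    outputs = []
    for e in range(up_gate_proj_weight.shape[0]):
        start, end = bounds[e], bounds[e + 1]
        if start == end:  # this expert received no tokens
            continue
        x = permute_input[start:end]
        # fused gate+up projection; swiglu splits the 2*intermediate halves
        # and computes silu(gate) * up
        h = swiglu(paddle.matmul(x, up_gate_proj_weight[e]))
        outputs.append(paddle.matmul(h, down_proj_weight[e]))
    return paddle.concat(outputs, axis=0)
```

The fused kernel performs the same computation as two grouped GEMMs over weight-only int8 matrices, with `up_gate_proj_scale` and `down_proj_scale` supplying the dequantization scales.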