mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
refactor rl get_name_mappings_to_training (#2847)
Deploy GitHub Pages / deploy (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* refactor rl get_name_mappings_to_training * fix tp>1 * change variable name(ffn1->up_gate_proj/ffn2->down_proj) * change variable name(linear_weight->weight/linear_bias->bias) * add rl names mapping for vl * fix ernie 0.3B error * fix develop code * fix
This commit is contained in:
@@ -161,7 +161,7 @@ __global__ void combine_prmt_back_kernel(
|
||||
expanded_permuted_rows + expanded_permuted_row * cols; // prmt后的位置对应的值
|
||||
Load<T, VEC_SIZE>(expanded_permuted_rows_row_ptr + tid * VEC_SIZE, &load_vec);
|
||||
const int expert_idx = expert_for_source_row[k_offset]; // 当前位置对应的专家
|
||||
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的ffn2的bias
|
||||
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的down_proj的bias
|
||||
if (bias_ptr) {
|
||||
Load<T, VEC_SIZE>(bias_ptr + tid * VEC_SIZE, &bias_vec);
|
||||
#pragma unroll
|
||||
@@ -188,7 +188,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
|
||||
const paddle::Tensor& expert_scales_float,
|
||||
const paddle::Tensor& permute_indices_per_token,
|
||||
const paddle::Tensor& top_k_indices,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const int num_rows,
|
||||
@@ -206,7 +206,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
|
||||
combine_prmt_back_kernel<<<gridx, threads, 0, stream>>>(
|
||||
ffn_out.data<data_t>(),
|
||||
output->data<data_t>(),
|
||||
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
|
||||
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||
expert_scales_float.data<float>(),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
top_k_indices.data<int>(),
|
||||
@@ -223,7 +223,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
|
||||
const paddle::Tensor& expert_scales_float, // dst_weights
|
||||
const paddle::Tensor& permute_indices_per_token, // permute_indices_per_token
|
||||
const paddle::Tensor& top_k_indices, // dst_indices
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor) {
|
||||
|
||||
@@ -242,7 +242,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
|
||||
expert_scales_float,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
ffn2_bias,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
@@ -255,7 +255,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
|
||||
expert_scales_float,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
ffn2_bias,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
@@ -274,7 +274,7 @@ __global__ void permute_x_kernel(const T *src_x,
|
||||
const int64_t *topk_idx,
|
||||
const float *topk_weights,
|
||||
const int *token_nums_per_expert,
|
||||
const float *ffn1_in_scale,
|
||||
const float *up_gate_proj_in_scale,
|
||||
const int moe_topk,
|
||||
const int num_rows,
|
||||
const int token_nums_this_rank,
|
||||
@@ -327,9 +327,9 @@ __global__ void permute_x_kernel(const T *src_x,
|
||||
// cp x
|
||||
for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) {
|
||||
Load<T, vec_size>(&src_x[s_token_idx * hidden_size + v_id * vec_size], &src_vec);
|
||||
if (ffn1_in_scale) {
|
||||
if (up_gate_proj_in_scale) {
|
||||
for (int i = 0; i < vec_size; i++) {
|
||||
float quant_value = max_bound * ffn1_in_scale[expert_now] * static_cast<float>(src_vec[i]);
|
||||
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
|
||||
if (RoundType == 0) {
|
||||
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
|
||||
} else {
|
||||
@@ -353,7 +353,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const paddle::Tensor& token_nums_per_expert,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
|
||||
const std::string& moe_quant_type,
|
||||
const int moe_topk,
|
||||
const int num_rows,
|
||||
@@ -383,7 +383,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
@@ -404,7 +404,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
@@ -427,7 +427,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
@@ -448,7 +448,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
@@ -472,7 +472,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string& moe_quant_type) {
|
||||
@@ -516,7 +516,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
num_experts_per_rank_tensor,
|
||||
ffn1_in_scale,
|
||||
up_gate_proj_in_scale,
|
||||
moe_quant_type,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
@@ -536,7 +536,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
num_experts_per_rank_tensor,
|
||||
ffn1_in_scale,
|
||||
up_gate_proj_in_scale,
|
||||
moe_quant_type,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
@@ -568,7 +568,7 @@ std::vector<std::vector<int64_t>> EPMoeExpertDispatchInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& topk_ids_shape,
|
||||
const std::vector<int64_t>& topk_weights_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_in_scale_dtype,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_in_scale_dtype,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int token_nums_this_rank) {
|
||||
int token_rows = -1;
|
||||
@@ -610,7 +610,7 @@ std::vector<paddle::DataType> EPMoeExpertDispatchInferDtype(
|
||||
|
||||
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
|
||||
.Inputs({"input", "topk_ids", "topk_weights",
|
||||
paddle::Optional("ffn1_in_scale")})
|
||||
paddle::Optional("up_gate_proj_in_scale")})
|
||||
.Outputs({"permute_input",
|
||||
"permute_indices_per_token",
|
||||
"token_nums_per_expert_cumsum",
|
||||
|
||||
Reference in New Issue
Block a user