diff --git a/custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v1.cc b/custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v1.cc
index 3a4a6eefcf..fb1a85179a 100644
--- a/custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v1.cc
+++ b/custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v1.cc
@@ -25,27 +25,38 @@ void LimitThinkingContentLengthV1(const paddle::Tensor& next_tokens,
                                   const paddle::Tensor& max_think_lens,
                                   const paddle::Tensor& step_idx,
                                   const paddle::Tensor& limit_think_status,
+                                  const paddle::Tensor& stop_flags,
+                                  const paddle::Tensor& eos_token_ids,
                                   const int64_t think_end_id) {
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
   auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
   const int batch_size = next_tokens.shape()[0];
+  const int eos_token_id_len = eos_token_ids.shape()[0];
   int r = baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1(
       xpu_ctx->x_context(),
       const_cast<int64_t*>(next_tokens.data<int64_t>()),
       max_think_lens.data<int>(),
       step_idx.data<int64_t>(),
+      eos_token_ids.data<int64_t>(),
       const_cast<int*>(limit_think_status.data<int>()),
+      const_cast<bool*>(stop_flags.data<bool>()),
       think_end_id,
-      batch_size);
+      batch_size,
+      eos_token_id_len);
   PD_CHECK(r == 0,
            "baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1 "
            "failed.");
 }
 
 PD_BUILD_STATIC_OP(limit_thinking_content_length_v1)
-    .Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
+    .Inputs({"next_tokens",
+             "max_think_lens",
+             "step_idx",
+             "limit_think_status",
+             "stop_flags",
+             "eos_token_ids"})
     .Attrs({"think_end_id: int64_t"})
     .Outputs({"next_tokens_out"})
     .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
diff --git a/custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v2.cc b/custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v2.cc
index fdf60b1dae..18be6033bf 100644
--- a/custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v2.cc
+++ b/custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v2.cc
@@ -25,6 +25,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
                                   const paddle::Tensor& max_think_lens,
                                   const paddle::Tensor& step_idx,
                                   const paddle::Tensor& limit_think_status,
+                                  const paddle::Tensor& stop_flags,
                                   const int64_t think_end_id,
                                   const int64_t line_break_id) {
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -38,6 +39,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
       max_think_lens.data<int>(),
       step_idx.data<int64_t>(),
       const_cast<int*>(limit_think_status.data<int>()),
+      stop_flags.data<bool>(),
       think_end_id,
       line_break_id,
       batch_size);
@@ -47,7 +49,11 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
 }
 
 PD_BUILD_STATIC_OP(limit_thinking_content_length_v2)
-    .Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
+    .Inputs({"next_tokens",
+             "max_think_lens",
+             "step_idx",
+             "limit_think_status",
+             "stop_flags"})
     .Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
     .Outputs({"next_tokens_out"})
     .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
index 971493d568..4e393b8685 100644
--- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
+++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
@@ -220,9 +220,12 @@ DLL_EXPORT int limit_thinking_content_length_kernel_v1(
     int64_t* next_tokens,
     const int* max_think_lens,
     const int64_t* step_idx,
+    const int64_t* eos_token_ids,
     int* limit_think_status,
+    bool* stop_flags,
     const int64_t think_end_id,
-    const int bs);
+    const int bs,
+    const int eos_token_id_len);
 
 DLL_EXPORT int limit_thinking_content_length_kernel_v2(
     api::Context* ctx,
@@ -230,6 +233,7 @@ DLL_EXPORT int limit_thinking_content_length_kernel_v2(
     const int* max_think_lens,
     const int64_t* step_idx,
     int* limit_think_status,
+    const bool* stop_flags,
     const int64_t think_end_id,
     const int64_t line_break_id,
     const int bs);
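Note for reviewers: the v1 entry point now threads `eos_token_ids` and a mutable `stop_flags` through every layer (op, wrapper, kernel). The state machine these signatures implement is small, and the following pure-Python sketch models it for a single sequence; the names are illustrative and this snippet is not part of the patch:

```python
# Reference model of the v1 state machine (status 0 = thinking,
# 1 = ending thinking, 2 = responding). The real kernel applies
# this per batch element on the XPU.
def step_v1(token, step, max_think_len, status, stop, eos_ids, think_end_id):
    if max_think_len < 0:                 # thinking-length control disabled
        return token, status, stop
    if status == 2 and stop:              # responding and already stopped
        return token, status, stop
    if status < 1 and (step >= max_think_len or token in eos_ids):
        token = think_end_id              # force the end-of-thinking token
        status = 1
        stop = False                      # an EOS mid-thinking must not stop the sequence
    if status < 2 and token == think_end_id:
        status = 2                        # thinking has ended
    return token, status, stop

# An EOS (id 2) sampled while thinking is replaced by the end-of-thinking
# token (id 7), and the stop flag set by the sampler is cleared:
assert step_v1(2, 5, 100, 0, True, {2}, 7) == (7, 2, False)
```

The key behavioral change is that last transition: an EOS sampled during thinking no longer terminates the sequence; it is rewritten to `think_end_id` and the stop flag is cleared so generation continues into the response phase.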
diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v1.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v1.xpu
index bcd61879e2..247a226006 100644
--- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v1.xpu
+++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v1.xpu
@@ -10,43 +10,68 @@ namespace xpu3 {
 namespace plugin {
 
+template <typename T>
+static inline __device__ bool is_in_end(const T id,
+                                        const T* end_ids,
+                                        const int length) {
+  for (int i = 0; i < length; i++) {
+    if (id == end_ids[i]) {
+      return true;
+    }
+  }
+  return false;
+}
+
 __global__ void limit_thinking_content_length_kernel_v1(
     int64_t* next_tokens,
     const int* max_think_lens,
     const int64_t* step_idx,
+    const int64_t* eos_token_ids,
     int* limit_think_status,
+    bool* stop_flags,
     const int64_t think_end_id,
-    const int bs) {
+    const int bs,
+    const int eos_token_id_len) {
   int cid = core_id();
   int ncores = core_num();
   int clusterid = cluster_id();
   int nclusters = cluster_num();
   if (clusterid != 0) return;
+  // Local-memory staging buffer for the EOS id list; assumes
+  // eos_token_id_len <= 256.
+  __simd__ __local__ int64_t eos_token_ids_lm[256];
   for (int i = cid; i < bs; i += ncores) {
     int max_think_len_lm;
     int limit_think_status_lm;
     int64_t next_token_lm;
     int64_t step_idx_lm;
+    bool stop_flags_lm;
     GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
     GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
     GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
+    GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool));
+    GM2LM_ASYNC(
+        eos_token_ids, eos_token_ids_lm, sizeof(int64_t) * eos_token_id_len);
     GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
     // If thinking is not enabled for this sequence, skip it; the default
     // value -1 means the thinking length is not limited.
     if (max_think_len_lm < 0) continue;
     // If the sequence is already in the response phase and its stop flag is
     // set, there is nothing left to do.
-    if (limit_think_status_lm == 2) continue;
+    if (limit_think_status_lm == 2 && stop_flags_lm) continue;
     // ======================= Thinking-phase control =======================
     // Phase 1: still thinking (status == 0); check whether thinking must be
     // force-ended.
     if (limit_think_status_lm < 1) {
       // With thinking-length control enabled, check whether the thinking
       // budget has been exhausted.
-      if (step_idx_lm >= max_think_len_lm) {
+      if ((step_idx_lm >= max_think_len_lm) ||
+          is_in_end(next_token_lm, eos_token_ids_lm, eos_token_id_len)) {
         // Force-replace the current token with the end-of-thinking token.
         next_token_lm = think_end_id;
         // Advance the status to 1, meaning "ending thinking".
         limit_think_status_lm = 1;
+        if (stop_flags_lm) {
+          stop_flags_lm = false;
+          LM2GM(&stop_flags_lm, stop_flags + i, sizeof(bool));
+        }
       }
     }
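One constraint worth noting: the kernel stages the EOS id list into a fixed 256-slot local-memory buffer before use, so the op silently assumes `eos_token_id_len <= 256`. A host-side guard one could add before launching (illustrative only, not in the patch):

```python
# The kernel copies eos_token_ids into a fixed 256-entry local buffer
# (eos_token_ids_lm[256]); longer lists would overflow it. A cheap
# pre-launch check, with the capacity mirrored from the kernel source:
def check_eos_len(eos_token_id_len, capacity=256):
    if eos_token_id_len > capacity:
        raise ValueError(
            f"eos_token_id_len={eos_token_id_len} exceeds the kernel's "
            f"local buffer capacity ({capacity})"
        )

check_eos_len(2)  # typical EOS lists hold only a handful of ids
```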
diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v2.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v2.xpu
index 239f8dd929..367dd71cec 100644
--- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v2.xpu
+++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v2.xpu
@@ -15,6 +15,7 @@ __global__ void limit_thinking_content_length_kernel_v2(
     const int* max_think_lens,
     const int64_t* step_idx,
     int* limit_think_status,
+    const bool* stop_flags,
     const int64_t think_end_id,
     const int64_t line_break_id,
     const int bs) {
@@ -29,15 +30,17 @@ __global__ void limit_thinking_content_length_kernel_v2(
     int limit_think_status_lm;
     int64_t next_token_lm;
     int64_t step_idx_lm;
+    bool stop_flags_lm;
     GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
     GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
+    GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool));
     GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
     GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
     // If thinking is not enabled for this sequence, skip it; the default
     // value -1 means the thinking length is not limited.
     if (max_think_len_lm < 0) continue;
     // If the sequence is already in the response phase and its stop flag is
     // set, there is nothing left to do.
-    if (limit_think_status_lm == 3) continue;
+    if (limit_think_status_lm == 3 && stop_flags_lm) continue;
 
     // ======================= Thinking-phase control =======================
     // Phase 1: still thinking (status == 0); check whether thinking must be
     // force-ended.
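Unlike v1, the v2 kernel takes `stop_flags` as read-only: the flag only gates the early-exit check, because v2 does not rewrite EOS tokens; instead it force-schedules a fixed four-token epilogue once the thinking budget is reached. The schedule, mirrored here from the `cpu_wrapper` added in `limit_thinking_content_length_v2.cpp` below (token ids are placeholders):

```python
# Sketch of the v2 forcing schedule: once step reaches max_think_len, four
# consecutive steps are overridden so the detokenized output reads
# "\n</think>\n\n" one token at a time.
def forced_token_v2(step, max_think_len, think_end_id, line_break_id):
    schedule = {
        max_think_len:     (line_break_id, 1),  # "\n"
        max_think_len + 1: (think_end_id, 1),   # "</think>"
        max_think_len + 2: (line_break_id, 1),  # "\n"
        max_think_len + 3: (line_break_id, 2),  # "\n", epilogue complete
    }
    return schedule.get(step)  # None -> no override at this step

assert forced_token_v2(10, 10, 7, 8) == (8, 1)
assert forced_token_v2(13, 10, 7, 8) == (8, 2)
```

After the fourth forced step, the wrapper logic promotes status 2 to 3, the terminal "responding" state, which is why the early-exit above tests `== 3` rather than `== 2`.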
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/limit_thinking_content_length_v1.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/limit_thinking_content_length_v1.cpp
index 57eb02efb1..a34a2972fc 100644
--- a/custom_ops/xpu_ops/src/plugin/src/wrapper/limit_thinking_content_length_v1.cpp
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/limit_thinking_content_length_v1.cpp
@@ -24,10 +24,12 @@ __attribute__((global)) void limit_thinking_content_length_kernel_v1(
     int64_t* next_tokens,
     const int* max_think_lens,
     const int64_t* step_idx,
+    const int64_t* eos_token_ids,
     int* limit_think_status,
+    bool* stop_flags,
     const int64_t think_end_id,
-    const int bs);
-
+    const int bs,
+    const int eos_token_id_len);
 }  // namespace plugin
 }  // namespace xpu3
 
@@ -36,13 +38,58 @@ namespace xpu {
 namespace api {
 namespace plugin {
 
+static int cpu_wrapper(Context* ctx,
+                       int64_t* next_tokens,
+                       const int* max_think_lens,
+                       const int64_t* step_idx,
+                       const int64_t* eos_token_ids,
+                       int* limit_think_status,
+                       bool* stop_flags,
+                       const int64_t think_end_id,
+                       const int bs,
+                       const int eos_token_id_len) {
+  auto is_in_end = [](int64_t token_id, const int64_t* end_ids, int length) {
+    for (int i = 0; i < length; i++) {
+      if (token_id == end_ids[i]) {
+        return true;
+      }
+    }
+    return false;
+  };
+  for (int bid = 0; bid < bs; bid++) {
+    const int max_think_len = max_think_lens[bid];
+    if (max_think_len < 0) continue;
+    int current_limit_think_status = limit_think_status[bid];
+    if (current_limit_think_status == 2 && stop_flags[bid]) continue;
+    int64_t next_token = next_tokens[bid];
+    const int64_t step = step_idx[bid];
+    if (current_limit_think_status < 1) {
+      if (step >= max_think_len ||
+          is_in_end(next_token, eos_token_ids, eos_token_id_len)) {
+        next_token = think_end_id;
+        current_limit_think_status = 1;
+        // Mirror the XPU kernel: an EOS emitted mid-thinking is replaced,
+        // so the sequence must keep generating.
+        if (stop_flags[bid]) {
+          stop_flags[bid] = false;
+        }
+      }
+    }
+    if (current_limit_think_status < 2) {
+      if (next_token == think_end_id) {
+        current_limit_think_status = 2;
+      }
+    }
+    next_tokens[bid] = next_token;
+    limit_think_status[bid] = current_limit_think_status;
+  }
+  return api::SUCCESS;
+}
+
 static int xpu3_wrapper(Context* ctx,
                         int64_t* next_tokens,
                         const int* max_think_lens,
                         const int64_t* step_idx,
+                        const int64_t* eos_token_ids,
                         int* limit_think_status,
+                        bool* stop_flags,
                         const int64_t think_end_id,
-                        const int bs) {
+                        const int bs,
+                        const int eos_token_id_len) {
   using XPU_INT64 = typename XPUIndexType<int64_t>::type;
   auto limit_thinking_content_length_kernel_v1 =
       xpu3::plugin::limit_thinking_content_length_kernel_v1;
@@ -50,9 +97,12 @@ static int xpu3_wrapper(Context* ctx,
   limit_thinking_content_length_kernel_v1<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
       reinterpret_cast<XPU_INT64*>(next_tokens),
       max_think_lens,
       reinterpret_cast<const XPU_INT64*>(step_idx),
+      reinterpret_cast<const XPU_INT64*>(eos_token_ids),
       limit_think_status,
+      stop_flags,
       think_end_id,
-      bs);
+      bs,
+      eos_token_id_len);
   return api::SUCCESS;
 }
 
@@ -60,31 +110,46 @@ int limit_thinking_content_length_kernel_v1(Context* ctx,
                                             int64_t* next_tokens,
                                             const int* max_think_lens,
                                             const int64_t* step_idx,
+                                            const int64_t* eos_token_ids,
                                             int* limit_think_status,
+                                            bool* stop_flags,
                                             const int64_t think_end_id,
-                                            const int bs) {
+                                            const int bs,
+                                            const int eos_token_id_len) {
   WRAPPER_CHECK_CTX(ctx);
   WRAPPER_DUMP_FUNCTION_T1(ctx, "limit_thinking_content_length_kernel_v1", int);
   WRAPPER_DUMP_PARAM5(ctx,
                       next_tokens,
                       max_think_lens,
                       step_idx,
-                      limit_think_status,
-                      think_end_id);
-  WRAPPER_DUMP_PARAM1(ctx, bs);
+                      eos_token_ids,
+                      limit_think_status);
+  WRAPPER_DUMP_PARAM4(ctx, stop_flags, think_end_id, bs, eos_token_id_len);
   WRAPPER_DUMP(ctx);
   if (ctx->dev().type() == api::kCPU) {
-    assert(false);
+    return cpu_wrapper(ctx,
+                       next_tokens,
+                       max_think_lens,
+                       step_idx,
+                       eos_token_ids,
+                       limit_think_status,
+                       stop_flags,
+                       think_end_id,
+                       bs,
+                       eos_token_id_len);
   }
   if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
     return xpu3_wrapper(ctx,
                         next_tokens,
                         max_think_lens,
                         step_idx,
+                        eos_token_ids,
                         limit_think_status,
+                        stop_flags,
                         think_end_id,
-                        bs);
+                        bs,
+                        eos_token_id_len);
   }
   WRAPPER_UNIMPLEMENTED(ctx);
 }
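With the CPU branch now running a real fallback instead of `assert(false)`, the v1 semantics can be checked off-device. A batch-level sanity sketch of the same loop (a hypothetical test, not part of the patch or its test suite):

```python
# Sequence 0 emits an EOS mid-thinking and gets rewritten; sequence 1 has
# the feature disabled (max_think_len == -1) and is left untouched.
next_tokens    = [2, 42]
max_think_lens = [100, -1]
step_idx       = [5, 5]
eos_ids        = {2}
status         = [0, 0]
stop_flags     = [True, False]
THINK_END_ID   = 7

for bid in range(2):
    if max_think_lens[bid] < 0:
        continue
    if status[bid] == 2 and stop_flags[bid]:
        continue
    tok = next_tokens[bid]
    if status[bid] < 1 and (step_idx[bid] >= max_think_lens[bid] or tok in eos_ids):
        tok = THINK_END_ID
        status[bid] = 1
        stop_flags[bid] = False  # mirror the XPU kernel: keep generating
    if status[bid] < 2 and tok == THINK_END_ID:
        status[bid] = 2
    next_tokens[bid] = tok

assert next_tokens == [7, 42] and status == [2, 0] and stop_flags == [False, False]
```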
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/limit_thinking_content_length_v2.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/limit_thinking_content_length_v2.cpp
index 242e0ec241..70b69e3803 100644
--- a/custom_ops/xpu_ops/src/plugin/src/wrapper/limit_thinking_content_length_v2.cpp
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/limit_thinking_content_length_v2.cpp
@@ -25,6 +25,7 @@ __attribute__((global)) void limit_thinking_content_length_kernel_v2(
     const int* max_think_lens,
     const int64_t* step_idx,
     int* limit_think_status,
+    const bool* stop_flags,
     const int64_t think_end_id,
     const int64_t line_break_id,
     const int bs);
@@ -37,11 +38,60 @@ namespace xpu {
 namespace api {
 namespace plugin {
 
+static int cpu_wrapper(Context* ctx,
+                       int64_t* next_tokens,
+                       const int* max_think_lens,
+                       const int64_t* step_idx,
+                       int* limit_think_status,
+                       const bool* stop_flags,
+                       const int64_t think_end_id,
+                       const int64_t line_break_id,
+                       const int bs) {
+  for (int bid = 0; bid < bs; bid++) {
+    const int max_think_len = max_think_lens[bid];
+    if (max_think_len < 0) continue;
+    int current_limit_think_status = limit_think_status[bid];
+    if (current_limit_think_status == 3 && stop_flags[bid]) {
+      continue;
+    }
+
+    int64_t next_token = next_tokens[bid];
+    const int64_t step = step_idx[bid];
+
+    if (current_limit_think_status <= 1) {
+      if (step == max_think_len) {
+        // Epilogue step 1: "\n".
+        next_token = line_break_id;
+        current_limit_think_status = 1;
+      } else if (step == max_think_len + 1) {
+        // Epilogue step 2: "</think>".
+        next_token = think_end_id;
+        current_limit_think_status = 1;
+      } else if (step == max_think_len + 2) {
+        // Epilogue step 3: "\n".
+        next_token = line_break_id;
+        current_limit_think_status = 1;
+      } else if (step == max_think_len + 3) {
+        // Epilogue step 4: "\n"; the forced epilogue is complete.
+        next_token = line_break_id;
+        current_limit_think_status = 2;
+      }
+    }
+    if (current_limit_think_status == 0) {
+      if (next_token == think_end_id) {
+        // The model closed the thinking section on its own.
+        current_limit_think_status = 3;
+      }
+    }
+    if (current_limit_think_status == 2) {
+      current_limit_think_status = 3;
+    }
+    next_tokens[bid] = next_token;
+    limit_think_status[bid] = current_limit_think_status;
+  }
+  return api::SUCCESS;
+}
+
 static int xpu3_wrapper(Context* ctx,
                         int64_t* next_tokens,
                         const int* max_think_lens,
                         const int64_t* step_idx,
                         int* limit_think_status,
+                        const bool* stop_flags,
                         const int64_t think_end_id,
                         const int64_t line_break_id,
                         const int bs) {
@@ -53,6 +103,7 @@ static int xpu3_wrapper(Context* ctx,
       max_think_lens,
       reinterpret_cast<const XPU_INT64*>(step_idx),
       limit_think_status,
+      stop_flags,
       think_end_id,
       line_break_id,
       bs);
@@ -64,6 +115,7 @@ int limit_thinking_content_length_kernel_v2(Context* ctx,
                                             const int* max_think_lens,
                                             const int64_t* step_idx,
                                             int* limit_think_status,
+                                            const bool* stop_flags,
                                             const int64_t think_end_id,
                                             const int64_t line_break_id,
                                             const int bs) {
@@ -74,11 +126,19 @@ int limit_thinking_content_length_kernel_v2(Context* ctx,
                       max_think_lens,
                       step_idx,
                       limit_think_status,
-                      think_end_id);
-  WRAPPER_DUMP_PARAM2(ctx, line_break_id, bs);
+                      stop_flags);
+  WRAPPER_DUMP_PARAM3(ctx, think_end_id, line_break_id, bs);
   WRAPPER_DUMP(ctx);
   if (ctx->dev().type() == api::kCPU) {
-    assert(false);
+    return cpu_wrapper(ctx,
+                       next_tokens,
+                       max_think_lens,
+                       step_idx,
+                       limit_think_status,
+                       stop_flags,
+                       think_end_id,
+                       line_break_id,
+                       bs);
   }
   if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
     return xpu3_wrapper(ctx,
@@ -86,6 +146,7 @@ int limit_thinking_content_length_kernel_v2(Context* ctx,
                         max_think_lens,
                         step_idx,
                         limit_think_status,
+                        stop_flags,
                         think_end_id,
                         line_break_id,
                         bs);
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index 58d7ff8a45..03011433c9 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -202,6 +202,8 @@ def xpu_post_process(
         max_think_lens = share_inputs["max_think_lens"]
         step_idx = share_inputs["step_idx"]
         limit_think_status = share_inputs["limit_think_status"]
+        stop_flags = share_inputs["stop_flags"]
+        eos_token_ids = share_inputs["eos_token_id"]
         if limit_strategy == "</think>":
            # for ernie-45-vl
            limit_thinking_content_length_v1(
@@ -209,6 +211,8 @@
                 next_tokens,
                 max_think_lens,
                 step_idx,
                 limit_think_status,
+                stop_flags,
+                eos_token_ids,  # handle EOS tokens emitted mid-thinking due to model-behavior issues
                 think_end_id,
             )
         elif limit_strategy == "\n</think>\n\n":
@@ -219,6 +223,7 @@
                 next_tokens,
                 max_think_lens,
                 step_idx,
                 limit_think_status,
+                stop_flags,
                 think_end_id,
                 line_break_id,
             )
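For context, the runner picks between the two ops purely on the `limit_strategy` string, which names the text each strategy injects. A sketch of the mapping (the op handles are stand-ins for the compiled custom ops):

```python
# v1: overwrite a single token with </think>; v2: emit the four-token
# epilogue "\n</think>\n\n" over successive steps.
def dispatch(limit_strategy, op_v1, op_v2):
    if limit_strategy == "</think>":
        return op_v1          # ernie-45-vl path
    elif limit_strategy == "\n</think>\n\n":
        return op_v2
    raise ValueError(f"unknown limit_strategy: {limit_strategy!r}")

assert dispatch("</think>", "v1", "v2") == "v1"
assert dispatch("\n</think>\n\n", "v1", "v2") == "v2"
```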