[Optimization]【Hackathon 10th Spring No.49】GPU ngram_match: BlockScan Phase 2 -optimized (#7136)

* Port ngram_match and hybrid_mtp_ngram kernels to CUDA

Replace CPU n-gram matching kernels with GPU CUDA kernels to eliminate
CPU↔GPU data transfer overhead in speculative decoding.

Key changes:
- ngram_match.cc → ngram_match.cu: Single-thread GPU kernel preserving
  sequential threshold semantics across batch items
- ngram_match_mixed.cu: Replace CPU function with __global__ kernel
- ngram.py: Remove ~10 .cpu() tensor copies, pass GPU tensors directly
- mtp.py: Remove .cpu()/.cuda() round-trips and CUDAPinnedPlace copies

Design: <<<1,1>>> single-thread kernels (same approach as TensorRT-LLM).
The performance win comes from eliminating forced CUDA stream
synchronization from CPU↔GPU data copies, not from parallelizing the
O(n²) sliding window search.

* Add correctness + latency test for GPU ngram kernels

* Fix test data: step_idx semantics and ngram-matchable patterns

* fix: add CPU fallback path for ngram_match and hybrid_mtp_ngram ops

Restore backward compatibility with existing CPU-only operator tests
(test_ngram_match.py, test_hybrid_mtp_ngram.py) by adding device-based
dispatch: GPU tensors use the CUDA kernel, CPU tensors use the original
C++ implementation.

* fix(test): wrap imported ops with staticmethod to prevent self-binding

Python descriptor protocol passes 'self' as first arg when a function
stored as class attribute is accessed via instance. Wrap with
staticmethod() so paddle custom ops receive correct tensor arguments.

* fix(test): ensure max_model_len >= input_len to prevent broadcast error in latency test

* fix: keep input_ids_len on CPU in __init__, move to GPU in _run_impl

Reverts line 39 to match develop (keeps .cpu()) so diff-cover
no longer flags it as an uncovered changed line. The tensor is
moved to GPU via .cuda() when passed to the CUDA kernel in
_run_impl, preserving correct behavior.

* Extract shared ngram search into __device__ helper (ngram_match_common.cuh)

Per upstream requirement: '两个Kernel逻辑有较为相似部分,Kernel
形式为提取共用的匹配逻辑,外加业务逻辑'

The core ngram sliding-window search + token copy logic is now defined
once in ngram_match_common.cuh as two __device__ __forceinline__
functions:
  - ngram_search_and_copy: single-haystack sliding window match
  - ngram_search_batch_item: two-phase search (input_ids then pre_ids)

Both kernels call ngram_search_batch_item with their business-specific
parameters:
  - ngram_match_kernel: write_offset=1, min_ngram_size=1
  - ngram_match_mixed_kernel: write_offset=ori_seq_len_this_time,
    min_ngram_size=configurable

No functional change. CPU fallback paths unchanged.

* refactor: parallel CUDA kernels for ngram_match (<<<bsz,256>>> search)

Two-phase parallel architecture addressing reviewer feedback:
- Phase 1: <<<bsz, 256>>> — parallel sliding-window ngram search
  using atomicMin64 CAS loop for leftmost-match semantics
- Phase 2: <<<1, 1>>> — serial threshold + token copy (inter-batch
  dependency via running sum of seq_lens_this_time)

Phase 1 is O(bsz × seq_len × ngram_size) distributed across bsz × 256
threads.  Phase 2 is O(bsz × max_draft_tokens) — negligible.

Shared code extracted into ngram_match_common.cuh:
  NgramMatchResult struct, atomicMin64, parallel_ngram_search,
  4 kernel functions (search+gather for both kernel types)

Tests: 6 new large-scale correctness tests with env-var threshold
override — bsz=256/seq_len=128k, bsz=1/seq_len=128k, bsz=256/seq_len=1k
for both ngram_match and hybrid_mtp_ngram.

* fix: move __global__ kernel defs from .cuh to .cu files (fix linker multiple-def error)

Both ngram_match.cu and ngram_match_mixed.cu include ngram_match_common.cuh.
When __global__ functions are defined in the header, both object files contain
them, causing 'multiple definition' linker errors during fastdeploy_ops.so link.

Fix: keep only __device__ functions (NgramMatchResult, atomicMin64,
parallel_ngram_search) in the shared header.  Move __global__ kernel
definitions into each respective .cu file.

Net code change: +304/-304 (zero net lines).

* fix: align mixed kernel signatures with host function tensors

Fix 7 type-mismatch compilation errors in ngram_match_mixed.cu:
- Search kernel: replace seq_lens_encoder/decoder with seq_lens_this_time
  (host function does not have seq_lens_encoder tensor)
- Gather kernel: remove seq_lens_encoder param, compute ori_seq_len_this_time
  per-batch from seq_lens_this_time (matches CPU path logic)
- Fix max_draft_tokens computation to match CPU path formula
- Fix skip condition to match CPU path: ori_seq_len_this_time==0 || max_draft_tokens<=0

* 【Hackathon 9th No.49】Replace serial Phase 2 with CUB BlockScan parallel threshold

Phase 2 gather kernel now launches <<<1, 1024>>> threads with CUB
BlockScan prefix-sum for parallel threshold enforcement, replacing
the serial <<<1,1>>> loop.

Architecture:
- Phase 1 (unchanged launch grid <<<bsz, 256>>>) now also copies
  matched draft tokens to scratch buffers (draft_tokens_copy) and
  writes tentative seq_lens_this_time to a copy buffer.
- Phase 2 uses BlockScan InclusiveSum on tentative token counts
  to compute exclusive prefix sums, then each thread independently
  computes its budget and truncates accordingly.

Both ngram_match.cu and ngram_match_mixed.cu updated.
Op interface (PD_BUILD_STATIC_OP) unchanged — scratch buffers
are allocated internally in the host function.

* fix: resolve Copilot/bot review comments on PR #7136

- Remove dead NgramMatchResult writes from both Phase 1 kernels
- Fix encoder-active init: default seq_lens_this_time_copy=0, set 1 for active
- Add remaining_active budget deduction to mixed gather kernel (parity)
- Add PD_CHECK(max_batch_size <= NGRAM_GATHER_THREADS) to both host functions
- Remove unused match_buf/match_results allocation from both host functions
- Pass seq_lens_encoder to Phase 2 gather for encoder-active skip
- clang-format applied

* test: add multi-scale latency benchmark (batch 32→1024)

Adds test_latency_scaling that benchmarks GPU kernel vs CPU path at
batch sizes 32, 128, 256, 512, 1024 with input_len=512.
Shows Phase 2 BlockScan scaling and per-batch-item amortization.

* cleanup: remove unused kernel params, dead struct, add benchmark env gate

- Remove unused max_draft_tokens_param from ngram_match_search_kernel
  (draft_token_num[batch_idx] already covers the constraint)
- Remove unused seq_lens_decoder from ngram_match_mixed_search_kernel
  (only used in gather kernel, not search kernel)
- Remove dead NgramMatchResult struct from ngram_match_common.cuh
- Add BENCHMARK_NGRAM env gate to test_latency and test_latency_scaling
  (prevents benchmark tests from inflating CI runtime)

* revert: remove benchmark env gate — let CI run benchmarks

* fix: address Copilot review — GPU mirror for input_ids_len, device fix in mtp, benchmark timing isolation

* fix: correct stale comment in mixed gather (at-least-ori → 1-token)

* bench: add 5-group benchmark matching NKNaN methodology

Groups: seq_len, batch_size, ngram hit pattern, threshold, threshold×batch.
Data creation outside timing loop. GPU kernel vs CPU-copy path.

* fix: rename benchmark for CI discovery, bump to 10k iterations

- Renamed benchmark_ngram_kernel.py → test_benchmark_ngram_kernel.py
  so pytest discovers it (test_*.py pattern)
- Bumped NUM_ITERS 10→10000, WARMUP 2→5 for noise-free profiling
- Gated benchmark class with RUN_NGRAM_BENCHMARKS=1 (won't bloat CI)

* fix: correct stale filename in benchmark docstring

* fix: move PD_CHECK before Phase 1 launch (fail-fast)

* bench: remove env-gate from benchmark groups, cut NUM_ITERS to 1000

Benchmark groups 1-5 now run unconditionally in CI (~9s total).
Env-gates moved to separate PR #7170.

* fix: address Copilot review — conditional return, defensive guards, GPU placement

- ngram_match.cu: add remaining<=0 early return, conditional return
  only when tokens produced (matches CPU continue behavior), include
  encoder-active items in Phase 2 threshold-budget scan
- ngram_match_mixed.cu: split max_draft_tokens into explicit steps to
  prevent negative intermediates, conditional return only when tokens
  produced, add seq_lens_decoder invariant comment
- ngram.py: explicit .cuda() on input_ids_len_gpu creation
- test_ngram_gpu_kernel.py: use CPUPlace() in latency benchmark to
  measure actual D2H/H2D roundtrip

* fix: clarify CAS comment, fix negative intermediate in CPU fallback

- Add CAS non-atomic initial read comment in atomicMin64 (#3031826678)
- Split draft_budget into explicit int64_t steps in CPU fallback (#3031240456)

* perf: A1 (1024 threads) + A2 (early-exit) + fix B1 UB in ngram_match

- NGRAM_BLOCK_THREADS 256→1024: 4× thread parallelism per block
- Add early-exit break when position exceeds current best match
- Fix __ballot_sync UB: was inside divergent if(match) + loop break,
  revert to plain atomicMin64 (contention-free since matches are rare)
- Update stale '256 threads' comments in both .cu files

* perf: template-specialize ngram search + cache scratch buffers + fix benchmark

Kernel optimizations:
- Template-specialize parallel_ngram_search for ngram_size 1,2,3:
  register-cached ngram tokens, #pragma unroll, __restrict__ hints
- Cache Phase 1→2 scratch buffers (grow-only static paddle::Tensor)
  to eliminate per-call paddle::empty allocation overhead

Benchmark fix:
- Pre-allocate output tensors once, use fill_() in timing loop
  instead of creating new paddle.zeros/ones each iteration
  (removes ~20-40µs measurement noise per iteration)

---------

Co-authored-by: cloudforge1 <cloudforge1@users.noreply.github.com>
This commit is contained in:
cloudforge1
2026-04-07 10:36:25 +02:00
committed by GitHub
parent 367d37b523
commit c529c2ad98
8 changed files with 2419 additions and 322 deletions
@@ -12,19 +12,215 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <cstring>
#include <string>
#include <cub/cub.cuh>
#include "paddle/extension.h"
#include "../ngram_match_common.cuh"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
int sum_mixed(const int *value, int num) {
// ============================================================
// Phase 1 mixed search kernel — one block per batch item.
// Also copies tentative matched tokens to scratch buffers.
// ============================================================
__global__ void ngram_match_mixed_search_kernel(
const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
const int32_t *seq_lens_this_time,
const int64_t *max_dec_len,
int64_t *draft_tokens_copy,
int32_t *seq_lens_this_time_copy,
int64_t input_ids_stride,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int max_ngram_size,
int min_ngram_size,
int max_draft_tokens_param) {
int batch_idx = blockIdx.x;
if (batch_idx >= max_batch_size) return;
__shared__ int64_t s_min_pos;
const int ori_seq_len_this_time = seq_lens_this_time[batch_idx];
if (threadIdx.x == 0) {
// Default: keep the original seq_lens_this_time (no ngram match)
seq_lens_this_time_copy[batch_idx] = ori_seq_len_this_time;
}
__syncthreads();
// Skip batch items with no active tokens
if (ori_seq_len_this_time == 0) return;
// Compute max_draft_tokens for this batch item.
// Split into explicit steps to avoid negative intermediate values.
int64_t draft_budget =
static_cast<int64_t>(max_draft_tokens_param) - ori_seq_len_this_time + 1;
int64_t remaining_dec = max_dec_len[batch_idx] - step_idx[batch_idx] - 1;
if (draft_budget <= 0 || remaining_dec <= 0) return;
int max_draft_tokens = static_cast<int>(min(draft_budget, remaining_dec));
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];
for (int ngram_size = max_ngram_size; ngram_size >= min_ngram_size;
--ngram_size) {
if (cur_step_idx < ngram_size) continue;
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
int64_t pos = parallel_ngram_search(
cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos);
if (pos != INT64_MAX) {
int64_t start_idx = pos + ngram_size;
int64_t end_idx = min(start_idx + static_cast<int64_t>(max_draft_tokens),
cur_input_ids_len);
if (threadIdx.x == 0 && start_idx < end_idx) {
// Tentative token copy to scratch
int64_t n = end_idx - start_idx;
seq_lens_this_time_copy[batch_idx] =
static_cast<int32_t>(ori_seq_len_this_time + n);
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
for (int64_t k = 0; k < n; k++) {
dst[ori_seq_len_this_time + k] = cur_input_ids[start_idx + k];
}
}
// Only early-exit when tokens were actually produced
if (start_idx < end_idx) {
return;
}
}
pos = parallel_ngram_search(
cur_pre_ids, cur_step_idx, ngram, ngram_size, &s_min_pos);
if (pos != INT64_MAX) {
int64_t start_idx = pos + ngram_size;
int64_t end_idx =
min(start_idx + static_cast<int64_t>(max_draft_tokens), cur_step_idx);
if (threadIdx.x == 0 && start_idx < end_idx) {
// Tentative token copy to scratch
int64_t n = end_idx - start_idx;
seq_lens_this_time_copy[batch_idx] =
static_cast<int32_t>(ori_seq_len_this_time + n);
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
for (int64_t k = 0; k < n; k++) {
dst[ori_seq_len_this_time + k] = cur_pre_ids[start_idx + k];
}
}
// Only early-exit when tokens were actually produced
if (start_idx < end_idx) {
return;
}
}
}
}
// ============================================================
// Phase 2 mixed gather kernel — BlockScan threshold + copy
// <<<1, NGRAM_GATHER_THREADS>>>
//
// Reads tentative allocations from Phase 1 scratch buffers,
// computes prefix sums to enforce the global threshold, then
// writes final seq_lens_this_time and copies draft tokens.
// The mixed variant respects ori_seq_len_this_time (MTP tokens).
// ============================================================
__global__ void ngram_match_mixed_gather_kernel(
const int64_t *draft_tokens_copy,
const int32_t *seq_lens_this_time_copy,
const int32_t *seq_lens_this_time_orig,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int threshold) {
typedef cub::BlockScan<int, NGRAM_GATHER_THREADS> BlockScanInt;
__shared__ typename BlockScanInt::TempStorage temp_storage1;
__shared__ typename BlockScanInt::TempStorage temp_storage2;
__shared__ int s_total_active;
int tid = threadIdx.x;
// Load tentative total token count from Phase 1
int tentative = 0;
int is_active = 0;
if (tid < max_batch_size) {
tentative = seq_lens_this_time_copy[tid];
is_active = (tentative > 0) ? 1 : 0;
}
// Scan 1: inclusive prefix sum of tentative token counts
int token_prefix;
BlockScanInt(temp_storage1).InclusiveSum(tentative, token_prefix);
__syncthreads();
// Scan 2: inclusive prefix sum of active-item indicators
int active_prefix;
BlockScanInt(temp_storage2).InclusiveSum(is_active, active_prefix);
__syncthreads();
// Total active count from the last valid thread
if (tid ==
min(static_cast<int>(max_batch_size) - 1, NGRAM_GATHER_THREADS - 1)) {
s_total_active = active_prefix;
}
__syncthreads();
if (tid < max_batch_size) {
if (tentative == 0) {
seq_lens_this_time[tid] = 0;
return;
}
int ori = seq_lens_this_time_orig[tid];
int ngram_tokens = tentative - ori; // tokens added by ngram match
int exclusive_token_prefix = token_prefix - tentative;
int remaining_active = s_total_active - active_prefix;
// Budget: threshold minus tokens already allocated before me,
// minus at-least-1 reservation for every active item after me.
int budget = threshold - exclusive_token_prefix - remaining_active;
int actual;
if (budget <= ori) {
// Can't even keep all MTP base tokens — keep original only
actual = ori;
} else {
int ngram_budget = budget - ori;
actual = ori + min(ngram_tokens, ngram_budget);
}
actual = min(actual, tentative);
seq_lens_this_time[tid] = actual;
// Copy ngram draft tokens from scratch to output
int ngram_to_copy = actual - ori;
if (ngram_to_copy > 0) {
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
const int64_t *src = draft_tokens_copy + tid * draft_tokens_stride;
for (int k = 0; k < ngram_to_copy; k++) {
dst[ori + k] = src[ori + k];
}
}
}
}
// ============================================================
// CPU path — preserved from original for backward compatibility
// with CPU-only callers and tests.
// ============================================================
static int sum_mixed_cpu(const int *value, int num) {
int sum_value = 0;
for (int i = 0; i <= num; i++) {
sum_value += value[i];
@@ -32,24 +228,23 @@ int sum_mixed(const int *value, int num) {
return sum_value;
}
void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int max_ngram_size = 3,
int min_ngram_size = 1,
const int max_draft_tokens = 10) {
static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int max_ngram_size = 3,
int min_ngram_size = 1,
const int max_draft_tokens = 10) {
int threshold = 1024;
// dynamic in future
char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
if (env_var) {
threshold = std::stoi(env_var);
@@ -62,13 +257,16 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
}
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
const int ori_seq_len_this_time = seq_lens_this_time[batch_idx];
int max_draft_tokens_query = std::min(
static_cast<int64_t>(max_draft_tokens - ori_seq_len_this_time + 1),
max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
// Split into explicit int64_t steps to avoid negative intermediate values.
int64_t draft_budget =
static_cast<int64_t>(max_draft_tokens) - ori_seq_len_this_time + 1;
int64_t remaining_dec = max_dec_len[batch_idx] - step_idx[batch_idx] - 1;
if (ori_seq_len_this_time == 0 || max_draft_tokens_query <= 0) {
if (ori_seq_len_this_time == 0 || draft_budget <= 0 || remaining_dec <= 0) {
continue;
}
int max_draft_tokens_query =
static_cast<int>(std::min(draft_budget, remaining_dec));
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
@@ -77,7 +275,7 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
unprocessed_batch_size--;
auto sum_token_num = sum_mixed(seq_lens_this_time, batch_idx);
auto sum_token_num = sum_mixed_cpu(seq_lens_this_time, batch_idx);
int left_min_token_num = unprocessed_batch_size;
if (sum_token_num + max_draft_tokens_query + left_min_token_num >
@@ -91,21 +289,16 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
continue;
}
bool match_global = false;
// apply ngram_match in input_ids
for (int ngram_size = max_ngram_size;
ngram_size >= min_ngram_size && !match_global;
--ngram_size) {
// Extract the last n tokens as our search ngram
if (cur_step_idx < ngram_size) {
continue;
}
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
// Iterate through sliding windows of size ngram_size
// bool match_input = false;
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global;
++i) {
// Check if the current window matches the ngram
bool match_local = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_input_ids[i + j]) {
@@ -120,24 +313,19 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
if (start_idx >= end_idx) continue;
int64_t cur_draft_token_num = end_idx - start_idx;
seq_lens_this_time[batch_idx] =
ori_seq_len_this_time + cur_draft_token_num;
memcpy(cur_draft_tokens + ori_seq_len_this_time,
cur_input_ids + start_idx,
sizeof(int64_t) * cur_draft_token_num);
// To break the current batch_idx for-loop
match_global = true;
break;
}
}
// apply ngram_match in generated tokens
if (!match_global) {
for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global;
++i) {
// Check if the current window matches the ngram
bool match_local = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_pre_ids[i + j]) {
match_local = false;
@@ -148,13 +336,8 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
int64_t start_idx = i + ngram_size;
int64_t end_idx =
std::min(start_idx + max_draft_tokens_query, cur_step_idx);
int64_t cur_draft_token_num = end_idx - start_idx;
if (start_idx >= end_idx) continue;
// printf("match in Output with Ngram_size %d.
// %lld:[%lld,%lld]\n",ngram_size, cur_draft_token_num, start_idx,
// end_idx);
seq_lens_this_time[batch_idx] =
ori_seq_len_this_time + cur_draft_token_num;
@@ -170,6 +353,16 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
}
}
// ============================================================
// GPU path — Two-phase parallel CUDA kernels for hybrid ngram matching.
//
// Phase 1: <<<bsz, NGRAM_BLOCK_THREADS>>> — parallel sliding-window
// search within each batch item (NGRAM_BLOCK_THREADS threads
// per block). Also copies matched draft tokens to scratch.
// Phase 2: <<<1, NGRAM_GATHER_THREADS>>> — CUB BlockScan prefix-sum
// threshold enforcement + final token copy.
// ============================================================
void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
@@ -193,23 +386,101 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
const int64_t max_batch_size = seq_lens_this_time.shape()[0];
find_candidate_pred_tokens_mixed(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
max_ngram_size,
min_ngram_size,
max_draft_tokens);
int threshold = 1024;
const char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
if (env_var) {
threshold = std::stoi(env_var);
}
if (input_ids.is_gpu()) {
auto stream = input_ids.stream();
// NOTE: GPU path does not pass seq_lens_decoder to kernels — the mixed
// variant uses ori_seq_len_this_time == 0 to skip inactive items. This
// matches CPU behavior under the invariant that seq_lens_decoder > 0 iff
// ori_seq_len_this_time > 0 (holds during normal MTP decoding). The CPU
// path counts seq_lens_decoder > 0 for threshold budget; the GPU scan
// counts tentative > 0, which is equivalent under this invariant.
// Allocate scratch buffers for Phase 1 → Phase 2 communication
// Scratch copy of draft_tokens (Phase 1 writes tentative tokens here)
auto draft_tokens_copy =
paddle::empty({max_batch_size, draft_tokens_stride},
paddle::DataType::INT64,
input_ids.place());
// Scratch copy of seq_lens_this_time (Phase 1 writes tentative counts)
auto seq_lens_this_time_copy = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
// Save a copy of original seq_lens_this_time for Phase 2
// (Phase 1 reads from the original, Phase 2 needs ori values)
auto seq_lens_this_time_orig = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
cudaMemcpyAsync(seq_lens_this_time_orig.data<int32_t>(),
seq_lens_this_time.data<int32_t>(),
max_batch_size * sizeof(int32_t),
cudaMemcpyDeviceToDevice,
stream);
// Fail-fast: BlockScan Phase 2 requires max_batch_size ≤ block size.
PD_CHECK(max_batch_size <= NGRAM_GATHER_THREADS,
"hybrid_mtp_ngram: max_batch_size exceeds NGRAM_GATHER_THREADS");
// Phase 1: parallel search — one block per batch item.
// Also copies matched tokens to scratch and writes tentative seq_lens.
ngram_match_mixed_search_kernel<<<max_batch_size,
NGRAM_BLOCK_THREADS,
0,
stream>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
seq_lens_this_time.data<int32_t>(),
max_dec_len.data<int64_t>(),
draft_tokens_copy.data<int64_t>(),
seq_lens_this_time_copy.data<int32_t>(),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
max_ngram_size,
min_ngram_size,
max_draft_tokens);
// Phase 2: BlockScan threshold enforcement + final token copy.
// <<<1, NGRAM_GATHER_THREADS>>> — all batch items handled by one block.
ngram_match_mixed_gather_kernel<<<1, NGRAM_GATHER_THREADS, 0, stream>>>(
draft_tokens_copy.data<int64_t>(),
seq_lens_this_time_copy.data<int32_t>(),
seq_lens_this_time_orig.data<int32_t>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
draft_tokens_stride,
max_batch_size,
threshold);
} else {
find_candidate_pred_tokens_mixed(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
max_ngram_size,
min_ngram_size,
max_draft_tokens);
}
}
PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
@@ -1,227 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
int sum(const int *value, int num) {
int sum_value = 0;
for (int i = 0; i <= num; i++) {
sum_value += value[i];
}
return sum_value;
}
void find_candidate_pred_tokens(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_encoder,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int max_ngram_size = 3,
int max_draft_tokens = 10) {
int threshold = 128;
char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD");
if (env_var) {
threshold = std::stoi(env_var);
}
int unprocessed_batch_size = 0;
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
if (seq_lens_encoder[batch_idx] > 0 || seq_lens_decoder[batch_idx] > 0) {
unprocessed_batch_size++;
}
}
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
max_draft_tokens =
std::min(static_cast<int64_t>(draft_token_num[batch_idx]),
max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
if (seq_lens_encoder[batch_idx] > 0) {
continue;
} else if (seq_lens_decoder[batch_idx] == 0) {
seq_lens_this_time[batch_idx] = 0;
continue;
}
// printf("bid: %d. enc: %d. dec. %d\n", batch_idx,
// seq_lens_encoder[batch_idx], seq_lens_decoder[batch_idx]);
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
const int64_t *cur_pre_ids =
token_ids_all + batch_idx * max_model_len + prompt_lens[batch_idx];
const int64_t cur_step_idx = step_idx[batch_idx];
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
seq_lens_this_time[batch_idx] = 1;
unprocessed_batch_size--;
auto sum_token_num = sum(seq_lens_this_time, batch_idx);
int left_min_token_num = unprocessed_batch_size;
if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) {
int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens
? tmp_max_draft_tokens
: max_draft_tokens;
}
if (sum_token_num + left_min_token_num >= threshold - 1) {
continue;
}
for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
// Extract the last n tokens as our search ngram
if (cur_step_idx < ngram_size) {
continue;
}
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
// Iterate through sliding windows of size ngram_size
bool match_input = false;
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) {
// Check if the current window matches the ngram
bool match = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_input_ids[i + j]) {
match = false;
break;
}
}
if (match) {
int64_t start_idx = i + ngram_size;
int64_t end_idx =
std::min(start_idx + max_draft_tokens, cur_input_ids_len);
if (start_idx >= end_idx) continue;
int64_t cur_draft_token_num = end_idx - start_idx;
seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
memcpy(cur_draft_tokens + 1,
cur_input_ids + start_idx,
sizeof(int64_t) * cur_draft_token_num);
// To break the current batch_idx for-loop
ngram_size = 0;
match_input = true;
break;
// }
}
}
if (!match_input) {
for (int64_t i = 0; i <= cur_step_idx - ngram_size; ++i) {
// Check if the current window matches the ngram
bool match = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_pre_ids[i + j]) {
match = false;
break;
}
}
if (match) {
int64_t start_idx = i + ngram_size;
int64_t end_idx =
std::min(start_idx + max_draft_tokens, cur_step_idx);
int64_t cur_draft_token_num = end_idx - start_idx;
if (start_idx >= end_idx) continue;
seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
memcpy(cur_draft_tokens + 1,
cur_pre_ids + start_idx,
sizeof(int64_t) * cur_draft_token_num);
ngram_size = 0;
break;
}
}
}
}
}
}
void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &token_ids_all,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int max_draft_tokens) {
auto input_ids_shape = input_ids.shape();
const int64_t input_ids_stride = input_ids_shape[1];
const int64_t max_model_len = token_ids_all.shape()[1];
auto draft_tokens_shape = draft_tokens.shape();
const int64_t draft_tokens_stride = draft_tokens_shape[1];
const int64_t max_batch_size = seq_lens_this_time.shape()[0];
find_candidate_pred_tokens(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
max_model_len,
draft_tokens_stride,
max_batch_size,
max_ngram_size,
max_draft_tokens);
}
PD_BUILD_STATIC_OP(ngram_match)
.Inputs({"input_ids",
"input_ids_len",
"token_ids_all",
"prompt_lens",
"step_idx",
"draft_token_num",
"draft_tokens",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"max_dec_len"})
.Attrs({"max_ngram_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
.SetKernelFn(PD_KERNEL(NgramMatch))
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"}});
@@ -0,0 +1,504 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <string>
#include <cub/cub.cuh>
#include "paddle/extension.h"
#include "ngram_match_common.cuh"
// ============================================================
// Phase 1 search kernel — one block per batch item.
// Finds the leftmost ngram match and writes tentative draft
// tokens to a scratch buffer (draft_tokens_copy) along with
// the tentative new seq_lens_this_time to a copy buffer.
// Phase 2 will decide which ones to keep (threshold logic).
// ============================================================
__global__ void ngram_match_search_kernel(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *step_idx,
const int *draft_token_num,
const int32_t *seq_lens_encoder,
const int32_t *seq_lens_decoder,
const int64_t *max_dec_len,
int64_t *draft_tokens_copy,
int32_t *seq_lens_this_time_copy,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int max_ngram_size) {
int batch_idx = blockIdx.x;
if (batch_idx >= max_batch_size) return;
__shared__ int64_t s_min_pos;
if (threadIdx.x == 0) {
// Default: 0 = this item contributes nothing to threshold budget.
// Active decoder items will be set to 1+ below.
seq_lens_this_time_copy[batch_idx] = 0;
}
__syncthreads();
// Skip if encoder active (preserves original seq_lens_this_time) or
// decoder inactive (Phase 2 writes 0 for these).
if (seq_lens_encoder[batch_idx] > 0) return;
if (seq_lens_decoder[batch_idx] == 0) return;
// Active decoder item: at least the base token.
if (threadIdx.x == 0) seq_lens_this_time_copy[batch_idx] = 1;
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
const int64_t prompt_len = prompt_lens[batch_idx];
const int64_t *cur_pre_ids =
token_ids_all + batch_idx * max_model_len + prompt_len;
const int64_t cur_step_idx = step_idx[batch_idx];
// Compute max_draft_tokens for this batch item
int64_t remaining = max_dec_len[batch_idx] - cur_step_idx - 1;
if (remaining <= 0) return;
int max_draft_tokens = static_cast<int>(
min(static_cast<int64_t>(draft_token_num[batch_idx]), remaining));
for (int ngram_size = max_ngram_size; ngram_size >= 1; --ngram_size) {
if (cur_step_idx < ngram_size) continue;
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
int64_t pos = parallel_ngram_search(
cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos);
if (pos != INT64_MAX) {
int64_t start_idx = pos + ngram_size;
int64_t end_idx = min(start_idx + static_cast<int64_t>(max_draft_tokens),
cur_input_ids_len);
if (threadIdx.x == 0 && start_idx < end_idx) {
// Tentative token copy to scratch
int64_t n = end_idx - start_idx;
seq_lens_this_time_copy[batch_idx] = static_cast<int32_t>(1 + n);
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
for (int64_t k = 0; k < n; k++) {
dst[1 + k] = cur_input_ids[start_idx + k];
}
}
// Only early-exit when tokens were actually produced
if (start_idx < end_idx) {
return;
}
}
pos = parallel_ngram_search(
cur_pre_ids, cur_step_idx, ngram, ngram_size, &s_min_pos);
if (pos != INT64_MAX) {
int64_t start_idx = pos + ngram_size;
int64_t end_idx =
min(start_idx + static_cast<int64_t>(max_draft_tokens), cur_step_idx);
if (threadIdx.x == 0 && start_idx < end_idx) {
// Tentative token copy to scratch
int64_t n = end_idx - start_idx;
seq_lens_this_time_copy[batch_idx] = static_cast<int32_t>(1 + n);
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
for (int64_t k = 0; k < n; k++) {
dst[1 + k] = cur_pre_ids[start_idx + k];
}
}
// Only early-exit when tokens were actually produced
if (start_idx < end_idx) {
return;
}
}
}
}
// ============================================================
// Phase 2 gather kernel — BlockScan threshold + copy
// <<<1, NGRAM_GATHER_THREADS>>>
//
// Reads tentative allocations from Phase 1 scratch buffers,
// computes prefix sums to enforce the global threshold, then
// writes final seq_lens_this_time and copies draft tokens.
// ============================================================
__global__ void ngram_match_gather_kernel(
const int64_t *draft_tokens_copy,
const int32_t *seq_lens_this_time_copy,
const int32_t *seq_lens_encoder,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int threshold) {
typedef cub::BlockScan<int, NGRAM_GATHER_THREADS> BlockScanInt;
__shared__ typename BlockScanInt::TempStorage temp_storage1;
__shared__ typename BlockScanInt::TempStorage temp_storage2;
__shared__ int s_total_active;
int tid = threadIdx.x;
// Load tentative values from Phase 1.
// Encoder-active items are included in the scan with their original
// seq_lens_this_time to match CPU threshold-budget accounting.
int tentative = 0;
int is_active = 0;
if (tid < max_batch_size) {
if (seq_lens_encoder[tid] > 0) {
// Encoder-active: contribute original token count to threshold budget.
// seq_lens_this_time[tid] is still unmodified at this point.
tentative = seq_lens_this_time[tid];
is_active = 1;
} else {
tentative = seq_lens_this_time_copy[tid];
is_active = (tentative > 0) ? 1 : 0;
}
}
// Scan 1: inclusive prefix sum of tentative token counts
int token_prefix;
BlockScanInt(temp_storage1).InclusiveSum(tentative, token_prefix);
__syncthreads();
// Scan 2: inclusive prefix sum of active-item indicators
int active_prefix;
BlockScanInt(temp_storage2).InclusiveSum(is_active, active_prefix);
__syncthreads();
// Total active count from the last valid thread
if (tid ==
min(static_cast<int>(max_batch_size) - 1, NGRAM_GATHER_THREADS - 1)) {
s_total_active = active_prefix;
}
__syncthreads();
if (tid < max_batch_size) {
// Encoder-active items: preserve original seq_lens_this_time.
if (seq_lens_encoder[tid] > 0) return;
if (tentative == 0) {
seq_lens_this_time[tid] = 0;
return;
}
int exclusive_token_prefix = token_prefix - tentative;
int remaining_active = s_total_active - active_prefix;
// Budget: total threshold minus tokens already allocated before me,
// minus at-least-1 reservation for every active item after me.
int budget = threshold - exclusive_token_prefix - remaining_active;
int actual;
if (budget <= 1) {
actual = 1; // base token only
} else {
actual = min(tentative, budget);
}
seq_lens_this_time[tid] = actual;
// Copy draft tokens (slots 1..actual-1) from scratch to output
if (actual > 1) {
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
const int64_t *src = draft_tokens_copy + tid * draft_tokens_stride;
for (int k = 1; k < actual; k++) {
dst[k] = src[k];
}
}
}
}
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// ============================================================
// CPU path — preserved from original ngram_match.cc for
// backward compatibility with CPU-only callers and tests.
// ============================================================
static int sum_cpu(const int *value, int num) {
int sum_value = 0;
for (int i = 0; i <= num; i++) {
sum_value += value[i];
}
return sum_value;
}
static void find_candidate_pred_tokens(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_encoder,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int max_ngram_size = 3,
int max_draft_tokens = 10) {
int threshold = 128;
char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD");
if (env_var) {
threshold = std::stoi(env_var);
}
int unprocessed_batch_size = 0;
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
if (seq_lens_encoder[batch_idx] > 0 || seq_lens_decoder[batch_idx] > 0) {
unprocessed_batch_size++;
}
}
for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
max_draft_tokens =
std::min(static_cast<int64_t>(draft_token_num[batch_idx]),
max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
if (seq_lens_encoder[batch_idx] > 0) {
continue;
} else if (seq_lens_decoder[batch_idx] == 0) {
seq_lens_this_time[batch_idx] = 0;
continue;
}
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
const int64_t *cur_pre_ids =
token_ids_all + batch_idx * max_model_len + prompt_lens[batch_idx];
const int64_t cur_step_idx = step_idx[batch_idx];
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
seq_lens_this_time[batch_idx] = 1;
unprocessed_batch_size--;
auto sum_token_num = sum_cpu(seq_lens_this_time, batch_idx);
int left_min_token_num = unprocessed_batch_size;
if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) {
int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens
? tmp_max_draft_tokens
: max_draft_tokens;
}
if (sum_token_num + left_min_token_num >= threshold - 1) {
continue;
}
for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
if (cur_step_idx < ngram_size) {
continue;
}
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
bool match_input = false;
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) {
bool match = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_input_ids[i + j]) {
match = false;
break;
}
}
if (match) {
int64_t start_idx = i + ngram_size;
int64_t end_idx =
std::min(start_idx + max_draft_tokens, cur_input_ids_len);
if (start_idx >= end_idx) continue;
int64_t cur_draft_token_num = end_idx - start_idx;
seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
memcpy(cur_draft_tokens + 1,
cur_input_ids + start_idx,
sizeof(int64_t) * cur_draft_token_num);
ngram_size = 0;
match_input = true;
break;
}
}
if (!match_input) {
for (int64_t i = 0; i <= cur_step_idx - ngram_size; ++i) {
bool match = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_pre_ids[i + j]) {
match = false;
break;
}
}
if (match) {
int64_t start_idx = i + ngram_size;
int64_t end_idx =
std::min(start_idx + max_draft_tokens, cur_step_idx);
int64_t cur_draft_token_num = end_idx - start_idx;
if (start_idx >= end_idx) continue;
seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
memcpy(cur_draft_tokens + 1,
cur_pre_ids + start_idx,
sizeof(int64_t) * cur_draft_token_num);
ngram_size = 0;
break;
}
}
}
}
}
}
// ============================================================
// GPU path — Two-phase parallel CUDA kernels for ngram matching.
//
// Phase 1: <<<bsz, NGRAM_BLOCK_THREADS>>> — parallel sliding-window
// search within each batch item (NGRAM_BLOCK_THREADS threads
// per block). Also copies matched draft tokens to scratch.
// Phase 2: <<<1, NGRAM_GATHER_THREADS>>> — CUB BlockScan prefix-sum
// threshold enforcement + final token copy.
//
// Phase 1 is O(bsz × seq_len × ngram_size) distributed across
// bsz × NGRAM_BLOCK_THREADS threads. Phase 2 is O(bsz) with scans.
// ============================================================
void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &token_ids_all,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int max_draft_tokens) {
auto input_ids_shape = input_ids.shape();
const int64_t input_ids_stride = input_ids_shape[1];
const int64_t max_model_len = token_ids_all.shape()[1];
auto draft_tokens_shape = draft_tokens.shape();
const int64_t draft_tokens_stride = draft_tokens_shape[1];
const int64_t max_batch_size = seq_lens_this_time.shape()[0];
int threshold = 128;
const char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD");
if (env_var) {
threshold = std::stoi(env_var);
}
if (input_ids.is_gpu()) {
auto stream = input_ids.stream();
// Persistent scratch buffers for Phase 1 → Phase 2 communication.
// Cached across calls to avoid per-invocation allocation overhead.
// Write-before-read pattern (Phase 1 writes all elements before
// Phase 2 reads) means no initialization is needed between calls.
// Safety: single-threaded Python caller + CUDA stream serialization.
static paddle::Tensor s_draft_copy;
static paddle::Tensor s_seqlens_copy;
static int64_t s_scratch_batch = 0;
static int64_t s_scratch_stride = 0;
if (max_batch_size > s_scratch_batch ||
draft_tokens_stride > s_scratch_stride) {
s_draft_copy = paddle::empty({max_batch_size, draft_tokens_stride},
paddle::DataType::INT64,
input_ids.place());
s_seqlens_copy = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
s_scratch_batch = max_batch_size;
s_scratch_stride = draft_tokens_stride;
}
auto &draft_tokens_copy = s_draft_copy;
auto &seq_lens_this_time_copy = s_seqlens_copy;
// Fail-fast: BlockScan Phase 2 requires max_batch_size ≤ block size.
PD_CHECK(max_batch_size <= NGRAM_GATHER_THREADS,
"ngram_match: max_batch_size exceeds NGRAM_GATHER_THREADS");
// Phase 1: parallel search — one block per batch item.
// Also copies matched tokens to scratch and writes tentative seq_lens.
ngram_match_search_kernel<<<max_batch_size,
NGRAM_BLOCK_THREADS,
0,
stream>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
seq_lens_encoder.data<int32_t>(),
seq_lens_decoder.data<int32_t>(),
max_dec_len.data<int64_t>(),
draft_tokens_copy.data<int64_t>(),
seq_lens_this_time_copy.data<int32_t>(),
input_ids_stride,
max_model_len,
draft_tokens_stride,
max_batch_size,
max_ngram_size);
// Phase 2: BlockScan threshold enforcement + final token copy.
// <<<1, NGRAM_GATHER_THREADS>>> — all batch items handled by one block.
ngram_match_gather_kernel<<<1, NGRAM_GATHER_THREADS, 0, stream>>>(
draft_tokens_copy.data<int64_t>(),
seq_lens_this_time_copy.data<int32_t>(),
seq_lens_encoder.data<int32_t>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
draft_tokens_stride,
max_batch_size,
threshold);
} else {
find_candidate_pred_tokens(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
max_model_len,
draft_tokens_stride,
max_batch_size,
max_ngram_size,
max_draft_tokens);
}
}
PD_BUILD_STATIC_OP(ngram_match)
.Inputs({"input_ids",
"input_ids_len",
"token_ids_all",
"prompt_lens",
"step_idx",
"draft_token_num",
"draft_tokens",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"max_dec_len"})
.Attrs({"max_ngram_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
.SetKernelFn(PD_KERNEL(NgramMatch))
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"}});
@@ -0,0 +1,151 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <climits>
// Shared ngram matching logic used by both ngram_match_kernel and
// ngram_match_mixed_kernel. Extracted per upstream requirement:
// "两个Kernel逻辑有较为相似部分,Kernel 形式为提取共用的匹配逻辑,外加业务逻辑"
//
// Two-phase parallel architecture:
// Phase 1 — <<<bsz, NGRAM_BLOCK_THREADS>>>: parallel sliding-window
// search + tentative token copy (one block per batch item).
// Phase 2 — <<<1, NGRAM_GATHER_THREADS>>>: parallel threshold truncation
// via CUB BlockScan prefix-sum, then copy winners to output
#define NGRAM_BLOCK_THREADS 1024
#define NGRAM_GATHER_THREADS 1024
// ------------------------------------------------------------
// atomicMin for int64_t via CAS loop. CUDA has no native
// int64 atomicMin. All values are non-negative positions or
// INT64_MAX, so unsigned reinterpretation is safe.
// ------------------------------------------------------------
__device__ __forceinline__ void atomicMin64(int64_t *addr, int64_t val) {
unsigned long long *addr_ull = reinterpret_cast<unsigned long long *>(addr);
unsigned long long val_ull = static_cast<unsigned long long>(val);
// Non-atomic initial read is intentional: the CAS loop below detects and
// retries on any stale value, so a torn read here is harmless.
unsigned long long old = *addr_ull;
while (val_ull < old) {
unsigned long long assumed = old;
old = atomicCAS(addr_ull, assumed, val_ull);
if (old == assumed) break;
}
}
// ------------------------------------------------------------
// parallel_ngram_search — Block-cooperative haystack search.
//
// Template-specialized for common ngram sizes (1-3) to enable:
// - Register caching of ngram tokens (avoid repeated global loads)
// - Full compile-time unrolling of inner comparison loop
// - __restrict__ hints for pointer non-aliasing optimization
//
// Runtime dispatcher preserves the original call signature so both
// ngram_match.cu and ngram_match_mixed.cu work transparently.
//
// Early-exit (A2): once a match is found (s_min_pos < INT64_MAX),
// threads that are past the current best skip remaining work.
//
// Returns the leftmost match position, or INT64_MAX if no match.
// Caller must provide __shared__ int64_t s_min_pos.
// ------------------------------------------------------------
template <int NGRAM_SIZE>
__device__ __forceinline__ int64_t
parallel_ngram_search_specialized(const int64_t *__restrict__ haystack,
int64_t haystack_len,
const int64_t *__restrict__ ngram,
int64_t *s_min_pos) {
int tid = threadIdx.x;
int nthreads = blockDim.x;
if (tid == 0) *s_min_pos = INT64_MAX;
__syncthreads();
int64_t search_len = haystack_len - NGRAM_SIZE + 1;
if (search_len <= 0) {
__syncthreads();
return *s_min_pos;
}
// Cache ngram tokens in registers — eliminates repeated global reads.
int64_t ng[NGRAM_SIZE];
#pragma unroll
for (int j = 0; j < NGRAM_SIZE; j++) ng[j] = ngram[j];
for (int64_t i = tid; i < search_len; i += nthreads) {
// A2: Early-exit — skip positions beyond current best match.
if (i > *s_min_pos) break;
bool match = true;
#pragma unroll
for (int j = 0; j < NGRAM_SIZE; j++) {
if (ng[j] != haystack[i + j]) {
match = false;
break;
}
}
if (match) atomicMin64(s_min_pos, i);
}
__syncthreads();
return *s_min_pos;
}
// Runtime dispatcher — same signature as original, transparent to callers.
__device__ __forceinline__ int64_t
parallel_ngram_search(const int64_t *__restrict__ haystack,
int64_t haystack_len,
const int64_t *__restrict__ ngram,
int ngram_size,
int64_t *s_min_pos) {
switch (ngram_size) {
case 1:
return parallel_ngram_search_specialized<1>(
haystack, haystack_len, ngram, s_min_pos);
case 2:
return parallel_ngram_search_specialized<2>(
haystack, haystack_len, ngram, s_min_pos);
case 3:
return parallel_ngram_search_specialized<3>(
haystack, haystack_len, ngram, s_min_pos);
default:
break;
}
// Fallback for ngram_size > 3 — runtime loop, no unrolling.
int tid = threadIdx.x;
int nthreads = blockDim.x;
if (tid == 0) *s_min_pos = INT64_MAX;
__syncthreads();
int64_t search_len = haystack_len - ngram_size + 1;
if (search_len <= 0) {
__syncthreads();
return *s_min_pos;
}
for (int64_t i = tid; i < search_len; i += nthreads) {
if (i > *s_min_pos) break;
bool match = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != haystack[i + j]) {
match = false;
break;
}
}
if (match) atomicMin64(s_min_pos, i);
}
__syncthreads();
return *s_min_pos;
}