mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
7a6c28781b
* optimize attn_mask_offset and optimize mtp usage * delete useless branch * fix kernel format * fix kernel runner
692 lines
24 KiB
Python
692 lines
24 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Unit tests for verify_draft_tokens kernel.
|
|
|
|
Verification strategies:
|
|
- TOPP (0): Verify draft token is in top-p candidate set
|
|
- GREEDY (1): Verify draft token matches target model's argmax
|
|
- TARGET_MATCH (2): Verify draft token matches target model's sampled token
|
|
"""
|
|
|
|
import random
|
|
import unittest
|
|
from typing import Any, Dict
|
|
|
|
import numpy as np
|
|
import paddle
|
|
|
|
from fastdeploy.model_executor.ops.gpu import verify_draft_tokens
|
|
from fastdeploy.spec_decode import VerifyStrategy
|
|
|
|
# Device placement for kernel inputs; fall back to CPU when this paddle build has no CUDA.
CUDA_PLACE = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
# NOTE(review): CPU_PLACE is not referenced anywhere in this file — presumably kept for
# parity with sibling kernel tests; confirm before removing.
CPU_PLACE = paddle.CPUPlace()
|
|
|
|
|
|
# ============================================================
|
|
# Helpers: tensor creation / kernel invocation / comparison
|
|
# ============================================================
|
|
|
|
|
|
def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a numpy input dict to paddle tensors placed on the test device.

    Python scalars/strings are passed through untouched, ``None`` stays
    ``None``, and everything else (numpy arrays) becomes a paddle tensor
    on ``CUDA_PLACE``.
    """
    converted: Dict[str, Any] = {}
    for name, value in inputs.items():
        if isinstance(value, (int, bool, float, str)):
            converted[name] = value
        elif value is None:
            converted[name] = None
        else:
            converted[name] = paddle.to_tensor(value, place=CUDA_PLACE)
    return converted
|
|
|
|
|
|
def run_kernel(paddle_inputs: Dict[str, Any], inputs: Dict[str, Any]):
    """Call the verify_draft_tokens custom op.

    Tensor arguments come from ``paddle_inputs``; trailing scalar attributes
    come from the raw ``inputs`` dict. Argument order must match the op
    signature exactly.
    """
    tensor_arg_names = (
        "step_output_ids",
        "step_output_len",
        "step_input_ids",
        "target_tokens",
        "candidate_ids",
        "candidate_scores",
        "candidate_lens",
        "topp",
        "stop_flags",
        "seq_lens_encoder",
        "seq_lens_this_time",
        "end_tokens",
        "is_block_step",
        "cu_seqlens_q_output",
        "reasoning_status",
        "max_dec_len",
        "step_idx",
    )
    scalar_arg_names = (
        "max_seq_len",
        "verify_window",
        "verify_strategy",
        "reject_all",
        "accept_all",
    )
    args = [paddle_inputs[name] for name in tensor_arg_names]
    args.extend(inputs[name] for name in scalar_arg_names)
    verify_draft_tokens(*args)
|
|
|
|
|
|
def run_ref(inputs: Dict[str, Any]):
    """Run the Python reference on copied inputs, return (output_ids, output_len).

    Arrays are shallow-copied so the reference's in-place writes never leak
    back into the dict shared with the GPU kernel run.
    """
    protected: Dict[str, Any] = {}
    for key, value in inputs.items():
        protected[key] = value.copy() if isinstance(value, np.ndarray) else value

    # Positional order of verify_draft_tokens_ref's signature.
    arg_order = (
        "step_output_ids",
        "step_output_len",
        "step_input_ids",
        "target_tokens",
        "candidate_ids",
        "candidate_scores",
        "candidate_lens",
        "topp",
        "stop_flags",
        "seq_lens_encoder",
        "seq_lens_this_time",
        "end_tokens",
        "is_block_step",
        "cu_seqlens_q_output",
        "reasoning_status",
        "max_dec_len",
        "step_idx",
        "max_seq_len",
        "verify_window",
        "verify_strategy",
        "reject_all",
        "accept_all",
    )
    return verify_draft_tokens_ref(*(protected[name] for name in arg_order))
|
|
|
|
|
|
def compare_results(
    paddle_inputs: Dict[str, Any],
    step_output_ids_ref: np.ndarray,
    step_output_len_ref: np.ndarray,
    inputs: Dict[str, Any],
    label: str = "unknown",
):
    """Compare GPU kernel output against the Python reference.

    Lengths must always match exactly. For deterministic strategies
    (GREEDY / TARGET_MATCH) the token ids must match everywhere; for TOPP
    only the accepted prefix is checked, because the Phase 2 resampled token
    is stochastic on the GPU side.
    """
    kernel_ids = paddle_inputs["step_output_ids"].numpy()
    kernel_len = paddle_inputs["step_output_len"].numpy()

    np.testing.assert_array_equal(
        kernel_len,
        step_output_len_ref,
        err_msg=f"step_output_len mismatch ({label})",
    )

    if inputs["verify_strategy"] != 0:
        # Deterministic strategies: every slot must agree.
        np.testing.assert_array_equal(
            kernel_ids,
            step_output_ids_ref,
            err_msg=f"step_output_ids mismatch ({label})",
        )
        return

    # TOPP — Phase 2 is stochastic; only the accepted draft prefix is comparable.
    for bid in range(inputs["seq_lens_this_time"].shape[0]):
        accepted = int(step_output_len_ref[bid]) - 1
        if accepted > 0:
            np.testing.assert_array_equal(
                kernel_ids[bid, :accepted],
                step_output_ids_ref[bid, :accepted],
                err_msg=f"step_output_ids (accepted) mismatch at bid={bid} ({label})",
            )
|
|
|
|
|
|
# ============================================================
|
|
# Reference helpers
|
|
# ============================================================
|
|
|
|
|
|
def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0):
    """Mirror of the CUDA top-p sampling loop.

    Walks the candidate list accumulating scores until the cumulative mass
    reaches ``curand_value * topp`` and returns that candidate. Falls back to
    the first candidate if the threshold is never reached.
    """
    threshold = curand_value * topp
    cumulative = 0.0
    for idx in range(candidate_len):
        cumulative += candidate_scores[idx]
        if threshold <= cumulative:
            return int(candidate_ids[idx])
    # Only reachable when the first candidate_len scores sum to < threshold.
    return int(candidate_ids[0])
|
|
|
|
|
|
def is_in_end(token, end_tokens, end_length):
    """Return True if ``token`` is among the first ``end_length`` end tokens."""
    return any(candidate == token for candidate in end_tokens[:end_length])
|
|
|
|
|
|
def is_in(candidate_list, token, length):
    """Return True if ``token`` appears in the first ``length`` candidates."""
    for candidate in candidate_list[:length]:
        if candidate == token:
            return True
    return False
|
|
|
|
|
|
class _VerifyContext:
    """Python mirror of the CUDA VerifyContext struct for reference testing."""

    def __init__(
        self,
        bid,
        max_step_tokens,
        end_length,
        end_tokens,
        max_dec_len,
        step_input_ids_now,
        step_output_ids_flat,
        cur_step_idx,
    ):
        # Request identity and per-request limits.
        self.bid = bid
        self.max_step_tokens = max_step_tokens
        self.end_length = end_length
        self.end_tokens = end_tokens
        self.max_dec_len = max_dec_len
        # Flat views over the draft-input and verified-output buffers.
        self.step_input_ids_now = step_input_ids_now
        self.step_output_ids_flat = step_output_ids_flat
        # Running decode-step counter for this request.
        self.cur_step_idx = cur_step_idx
        # Number of tokens emitted so far.
        self.output_len_now = 0
        # Set once an EOS or max_dec_len stop condition is observed.
        self.stopped = False

    def _advance(self, token):
        """Bump the step counter and classify ``token``.

        Returns (token, eos, max_hit); the token is swapped for end_tokens[0]
        when the max-decode-length limit fires on a non-EOS token.
        """
        self.cur_step_idx += 1
        eos = any(t == token for t in self.end_tokens[: self.end_length])
        max_hit = self.cur_step_idx >= int(self.max_dec_len[self.bid])
        if max_hit and not eos:
            token = int(self.end_tokens[0])
        return token, eos, max_hit

    def emit_token(self, pos, token):
        """Emit a token to output. Returns True if sequence should stop."""
        token, eos, max_hit = self._advance(token)
        self.step_output_ids_flat[self.bid * self.max_step_tokens + pos] = token
        self.output_len_now += 1
        if eos or max_hit:
            self.stopped = True
            return True
        return False

    def emit_final_token(self, pos, token):
        """Emit the Phase 2 final token. Increments output_len_now."""
        token, _eos, _max_hit = self._advance(token)
        self.step_output_ids_flat[self.bid * self.max_step_tokens + pos] = token
        self.output_len_now += 1
|
|
|
|
|
# NOTE: try_verify_window_fallback was removed from the CUDA kernel.
|
|
# TOPP strategy now rejects immediately when draft token is not in candidate set.
|
|
|
|
|
|
def verify_draft_tokens_ref(
    step_output_ids,
    step_output_len,
    step_input_ids,
    target_tokens,
    candidate_ids,
    candidate_scores,
    candidate_lens,
    topp,
    stop_flags,
    seq_lens_encoder,
    seq_lens_this_time,
    end_tokens,
    is_block_step,
    cu_seqlens_q_output,
    reasoning_status,
    max_dec_len,
    step_idx,
    max_seq_len,
    verify_window,
    verify_strategy,
    reject_all,
    accept_all,
):
    """Reference implementation of verify_draft_tokens in Python.

    For each request: Phase 1 walks the draft tokens and accepts while the
    chosen strategy agrees (TOPP membership or target-token equality); Phase 2
    emits one final token at the first rejected / last position. Writes
    step_output_ids / step_output_len in place and returns them.
    """
    real_bsz = seq_lens_this_time.shape[0]
    max_step_tokens = step_input_ids.shape[1]
    end_length = end_tokens.shape[0]
    # TOPP inputs may be absent for GREEDY / TARGET_MATCH runs.
    max_candidate_len = candidate_ids.shape[1] if candidate_ids is not None else 1

    # NOTE(review): each entry is the *first* draw from a fresh Random(0), so
    # every position shares one identical value — presumably mirrors the
    # kernel's per-thread curand seeding; confirm against the CUDA side.
    dev_curand_states = [random.Random(0).random() for _ in range(max_step_tokens)]

    step_output_ids_flat = step_output_ids.reshape(-1)
    # Kernel initializes step_output_ids to -1 for all slots
    step_output_ids_flat[:] = -1
    step_input_ids_flat = step_input_ids.reshape(-1)
    candidate_ids_flat = candidate_ids.reshape(-1) if candidate_ids is not None else None
    candidate_scores_flat = candidate_scores.reshape(-1) if candidate_scores is not None else None

    for bid in range(real_bsz):
        # Offset of this request's tokens in the flattened per-token arrays.
        start_token_id = cu_seqlens_q_output[bid]

        # Skipped requests (block-stepped or already stopped) emit nothing.
        if is_block_step[bid] or stop_flags[bid]:
            step_output_len[bid] = 0
            continue

        # Per-request views into the flattened buffers.
        step_input_ids_now = step_input_ids_flat[bid * max_step_tokens :]
        target_tokens_now = target_tokens[start_token_id:] if target_tokens is not None else None
        candidate_ids_now = (
            candidate_ids_flat[start_token_id * max_candidate_len :] if candidate_ids_flat is not None else None
        )
        candidate_lens_now = candidate_lens[start_token_id:] if candidate_lens is not None else None
        candidate_scores_now = (
            candidate_scores_flat[start_token_id * max_candidate_len :] if candidate_scores_flat is not None else None
        )

        ctx = _VerifyContext(
            bid,
            max_step_tokens,
            end_length,
            end_tokens,
            max_dec_len,
            step_input_ids_now,
            step_output_ids_flat,
            int(step_idx[bid]),
        )

        # Phase 1: Verify
        i = 0
        # seq_lens_this_time - 1 verify positions: draft token i+1 is judged
        # against the target/candidates at position i.
        while i < seq_lens_this_time[bid] - 1:
            # Global reject conditions: forced rejection, prefill request,
            # or reasoning-phase override.
            if reject_all or seq_lens_encoder[bid] != 0 or reasoning_status[bid] == 1:
                break
            if accept_all:
                # Forced acceptance still honors EOS / max_dec_len stops.
                if ctx.emit_token(i, step_input_ids_now[i + 1]):
                    break
                i += 1
                continue

            accepted = False
            if verify_strategy == 0:  # TOPP
                actual_cand_len = min(candidate_lens_now[i], max_candidate_len)
                # Accept if the draft token appears in the target's top-p set.
                accepted = is_in(
                    candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len],
                    step_input_ids_now[i + 1],
                    actual_cand_len,
                )
            elif verify_strategy in (1, 2):  # GREEDY / TARGET_MATCH
                accepted = target_tokens_now[i] == step_input_ids_now[i + 1]

            if accepted:
                if ctx.emit_token(i, step_input_ids_now[i + 1]):
                    break
            else:
                # First rejection ends Phase 1; i stays at the rejected slot.
                break
            i += 1

        # Phase 2: Sample for rejected/last position
        if not ctx.stopped:
            if verify_strategy == 0:
                if candidate_lens_now is not None and len(candidate_lens_now) > i:
                    actual_cand_len = min(candidate_lens_now[i], max_candidate_len)
                    # Resample from the target's top-p distribution at slot i.
                    accept_token = topp_sampling_kernel(
                        candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len],
                        candidate_scores_now[i * max_candidate_len : (i + 1) * max_candidate_len],
                        dev_curand_states[i],
                        actual_cand_len,
                        topp[bid],
                    )
                else:
                    # No candidates available: fall back to the first draft token.
                    accept_token = int(step_input_ids_now[0])
            elif verify_strategy in (1, 2):
                accept_token = (
                    int(target_tokens_now[i])
                    if target_tokens_now is not None and len(target_tokens_now) > i
                    else int(step_input_ids_now[0])
                )
            else:
                # Unknown strategy: take the top candidate if present.
                accept_token = (
                    int(candidate_ids_now[i * max_candidate_len])
                    if candidate_ids_now is not None
                    else int(step_input_ids_now[0])
                )
            ctx.emit_final_token(i, accept_token)

        step_output_len[bid] = ctx.output_len_now

    return step_output_ids, step_output_len
|
|
|
|
|
|
# ============================================================
|
|
# Input generation
|
|
# ============================================================
|
|
|
|
|
|
def gen_verify_draft_tokens_inputs(
    real_bsz: int = 32,
    max_draft_tokens: int = 16,
    max_seq_len: int = 256,
    max_candidate_len: int = 8,
    verify_window: int = 2,
    end_length: int = 4,
    verify_strategy: int = 1,
    reject_all: bool = False,
    accept_all: bool = False,
    match_ratio: float = 0.0,
    seed: int = 2025,
) -> Dict[str, Any]:
    """Generate test inputs for verify_draft_tokens kernel.

    Args:
        match_ratio: Fraction of draft token positions where target/candidates
        are forced to match step_input_ids, so the acceptance path is exercised.
        0.0 = fully random (mostly rejects), 1.0 = all positions match.
    """
    # All randomness flows through one seeded Generator for reproducibility;
    # the call order below must not change or every config's data shifts.
    rng = np.random.default_rng(seed)

    seq_lens_encoder = np.zeros(real_bsz, dtype=np.int32)
    seq_lens_this_time = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32)
    step_input_ids = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)

    # Total token count across the batch; sizes the flattened per-token arrays.
    sum_seq = int(np.sum(seq_lens_this_time))

    # Each strategy only consumes its own verification inputs; the rest are None.
    if verify_strategy in (1, 2):  # GREEDY / TARGET_MATCH
        target_tokens = rng.integers(0, 1000, size=(sum_seq,), dtype=np.int64)
        candidate_ids = None
        candidate_scores = None
        candidate_lens = None
    else:  # TOPP
        target_tokens = None
        candidate_ids = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64)
        candidate_scores = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32)
        # Normalize each row so candidate scores form a probability distribution.
        candidate_scores = candidate_scores / candidate_scores.sum(axis=1, keepdims=True)
        candidate_lens = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32)

    end_tokens = rng.integers(1, 1000, size=end_length, dtype=np.int64)
    is_block_step = rng.integers(0, 2, size=real_bsz, dtype=bool)

    # Exclusive prefix sums of seq_lens_this_time (token offsets per request).
    cu_seqlens_q_output = np.zeros(real_bsz + 1, dtype=np.int32)
    for i in range(real_bsz):
        cu_seqlens_q_output[i + 1] = cu_seqlens_q_output[i] + seq_lens_this_time[i]
    # Truncated to real_bsz entries (total is dropped) — presumably the layout
    # the kernel expects; confirm against the op definition.
    cu_seqlens_q_output = cu_seqlens_q_output[:real_bsz].astype(np.int32)

    topp = rng.uniform(0.8, 1.0, size=real_bsz).astype(np.float32)
    reasoning_status = np.zeros(real_bsz, dtype=np.int32)
    step_output_ids = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64)
    step_output_len = np.zeros(real_bsz, dtype=np.int32)
    stop_flags = np.zeros(real_bsz, dtype=bool)

    # Force match_ratio fraction of positions so acceptance path is tested
    if match_ratio > 0.0:
        offset = 0
        for bid in range(real_bsz):
            slt = int(seq_lens_this_time[bid])
            n_match = max(1, int((slt - 1) * match_ratio))  # slt-1 verify positions
            for pos in range(min(n_match, slt - 1)):
                draft_token = int(step_input_ids[bid, pos + 1])
                # Ensure draft_token is not an end_token (would cause early stop)
                while draft_token in end_tokens[:end_length]:
                    draft_token = (draft_token + 1) % 1000
                step_input_ids[bid, pos + 1] = draft_token
                if verify_strategy in (1, 2) and target_tokens is not None:
                    # Match the target token so GREEDY/TARGET_MATCH accepts.
                    target_tokens[offset + pos] = draft_token
                elif verify_strategy == 0 and candidate_ids is not None:
                    # Plant the draft token in the candidate set so TOPP accepts.
                    candidate_ids[offset + pos, 0] = draft_token
                    candidate_lens[offset + pos] = max(candidate_lens[offset + pos], 1)
            offset += slt

    return {
        "step_output_ids": step_output_ids,
        "step_output_len": step_output_len,
        "step_input_ids": step_input_ids,
        "target_tokens": target_tokens,
        "candidate_ids": candidate_ids,
        "candidate_scores": candidate_scores,
        "candidate_lens": candidate_lens,
        "topp": topp,
        "stop_flags": stop_flags,
        "seq_lens_encoder": seq_lens_encoder,
        "seq_lens_this_time": seq_lens_this_time,
        "end_tokens": end_tokens,
        "is_block_step": is_block_step,
        "cu_seqlens_q_output": cu_seqlens_q_output,
        "reasoning_status": reasoning_status,
        "max_dec_len": rng.integers(50, 200, size=real_bsz, dtype=np.int64),
        "step_idx": rng.integers(0, 30, size=real_bsz, dtype=np.int64),
        "max_seq_len": max_seq_len,
        "verify_window": verify_window,
        "verify_strategy": verify_strategy,
        "reject_all": reject_all,
        "accept_all": accept_all,
    }
|
|
|
|
|
|
# ============================================================
|
|
# Test configs
|
|
# ============================================================
|
|
|
|
# Each entry is one subTest in test_verify_configs: kwargs forwarded to
# gen_verify_draft_tokens_inputs plus a "name" label.
TEST_CONFIGS = [
    # --- strategy coverage (random, mostly rejects) ---
    {
        "name": "greedy_small_batch",
        "real_bsz": 1,
        "max_draft_tokens": 9,
        "max_seq_len": 11,
        "max_candidate_len": 4,
        "verify_window": 2,
        "end_length": 5,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
    },
    {
        "name": "greedy_medium_batch",
        "real_bsz": 33,
        "max_draft_tokens": 5,
        "max_seq_len": 10111,
        "max_candidate_len": 5,
        "verify_window": 2,
        "end_length": 6,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
    },
    {
        "name": "topp_small_batch",
        "real_bsz": 6,
        "max_draft_tokens": 4,
        "max_seq_len": 10001,
        "max_candidate_len": 6,
        "verify_window": 2,
        "end_length": 7,
        "verify_strategy": VerifyStrategy.TOPP.value,
        "seed": 42,
    },
    {
        "name": "target_match_medium",
        "real_bsz": 7,
        "max_draft_tokens": 3,
        "max_seq_len": 777,
        "max_candidate_len": 7,
        "verify_window": 2,
        "end_length": 5,
        "verify_strategy": VerifyStrategy.TARGET_MATCH.value,
        "seed": 42,
    },
    {
        "name": "greedy_large_batch",
        "real_bsz": 55,
        "max_draft_tokens": 5,
        "max_seq_len": 31,
        "max_candidate_len": 9,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
    },
    # --- partial acceptance (match_ratio forces draft tokens to match target/candidates) ---
    {
        "name": "greedy_half_accept",
        "real_bsz": 8,
        "max_draft_tokens": 8,
        "max_seq_len": 256,
        "max_candidate_len": 4,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
        "match_ratio": 0.5,
    },
    {
        "name": "greedy_full_accept",
        "real_bsz": 8,
        "max_draft_tokens": 8,
        "max_seq_len": 256,
        "max_candidate_len": 4,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
        "match_ratio": 1.0,
    },
    {
        "name": "topp_half_accept",
        "real_bsz": 8,
        "max_draft_tokens": 8,
        "max_seq_len": 256,
        "max_candidate_len": 6,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.TOPP.value,
        "seed": 42,
        "match_ratio": 0.5,
    },
    {
        "name": "topp_full_accept",
        "real_bsz": 8,
        "max_draft_tokens": 8,
        "max_seq_len": 256,
        "max_candidate_len": 6,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.TOPP.value,
        "seed": 42,
        "match_ratio": 1.0,
    },
    {
        "name": "target_match_accept",
        "real_bsz": 8,
        "max_draft_tokens": 6,
        "max_seq_len": 256,
        "max_candidate_len": 4,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.TARGET_MATCH.value,
        "seed": 42,
        "match_ratio": 0.7,
    },
    # --- reject_all / accept_all (kernel-level flags) ---
    {
        "name": "reject_all",
        "real_bsz": 8,
        "max_draft_tokens": 5,
        "max_seq_len": 100,
        "max_candidate_len": 5,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
        "reject_all": True,
    },
    {
        "name": "accept_all",
        "real_bsz": 8,
        "max_draft_tokens": 5,
        "max_seq_len": 100,
        "max_candidate_len": 5,
        "verify_window": 2,
        "end_length": 3,
        "verify_strategy": VerifyStrategy.TOPP.value,
        "seed": 42,
        "accept_all": True,
    },
    # --- edge cases ---
    {
        "name": "empty_batch",
        "real_bsz": 1,
        "max_draft_tokens": 1,
        "max_seq_len": 10,
        "max_candidate_len": 2,
        "verify_window": 1,
        "end_length": 4,
        "verify_strategy": VerifyStrategy.GREEDY.value,
        "seed": 42,
    },
]
|
|
|
|
|
|
# ============================================================
|
|
# Test suite
|
|
# ============================================================
|
|
|
|
|
|
class TestVerifyDraftTokens(unittest.TestCase):
    """End-to-end checks of the verify_draft_tokens op against the Python reference."""

    def setUp(self):
        # Every case drives the custom GPU op, so a CUDA build is mandatory.
        if not paddle.is_compiled_with_cuda():
            self.skipTest("Requires CUDA")

    # ------ shared run + check helper ------

    def _run_and_compare(self, inputs: Dict[str, Any], label: str = ""):
        """Convert→run kernel→run ref→compare."""
        device_inputs = to_paddle_inputs(inputs)
        run_kernel(device_inputs, inputs)
        ref_ids, ref_lens = run_ref(inputs)
        compare_results(device_inputs, ref_ids, ref_lens, inputs, label)
        return device_inputs

    # ------ test cases ------

    def test_verify_configs(self):
        """Test all configs in TEST_CONFIGS (strategies, reject/accept, edge cases)."""
        for cfg in TEST_CONFIGS:
            with self.subTest(name=cfg["name"]):
                kwargs = dict(cfg)
                kwargs.pop("name")
                self._run_and_compare(gen_verify_draft_tokens_inputs(**kwargs), label=cfg["name"])

    def test_eos_handling(self):
        """Test EOS token in draft triggers early stop."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42
        )
        # Plant an end token inside request 0's draft so verification stops there.
        inputs["step_input_ids"][0, 2] = inputs["end_tokens"][0]
        self._run_and_compare(inputs, label="eos_handling")

    def test_max_dec_len_truncation(self):
        """Test max_dec_len causes token replacement with end_tokens[0]."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0
        )
        # Set step_idx close to max_dec_len so it triggers during verification
        inputs["step_idx"][:] = [48, 10, 10, 10]
        inputs["max_dec_len"][:] = [50, 200, 200, 200]
        inputs["is_block_step"][:] = False
        inputs["stop_flags"][:] = False
        # Ensure no accidental EOS in draft tokens
        for bid in range(4):
            for pos in range(5):
                while inputs["step_input_ids"][bid, pos] in inputs["end_tokens"]:
                    inputs["step_input_ids"][bid, pos] = (inputs["step_input_ids"][bid, pos] + 1) % 1000
        self._run_and_compare(inputs, label="max_dec_len_truncation")

    def test_verify_strategy_enum(self):
        """Enum values must match the integer codes the kernel receives."""
        expected = {
            VerifyStrategy.TOPP: 0,
            VerifyStrategy.GREEDY: 1,
            VerifyStrategy.TARGET_MATCH: 2,
        }
        for strategy, code in expected.items():
            self.assertEqual(strategy.value, code)

    def test_verify_strategy_from_string(self):
        """from_string is case-insensitive and rejects unknown names."""
        cases = [
            ("topp", VerifyStrategy.TOPP),
            ("TOPP", VerifyStrategy.TOPP),
            ("greedy", VerifyStrategy.GREEDY),
            ("target_match", VerifyStrategy.TARGET_MATCH),
        ]
        for text, expected in cases:
            self.assertEqual(VerifyStrategy.from_string(text), expected)
        with self.assertRaises(ValueError):
            VerifyStrategy.from_string("invalid")
|
|
|
|
|
|
# Allow running this test module directly (e.g. `python test_verify_draft_tokens.py`).
if __name__ == "__main__":
    unittest.main()
|