Files
FastDeploy/custom_ops/xpu_ops/test/test_verify_draft_tokens.py
T
cmcamdy 13b9fe7299 [XPU] add verify draft tokens (#6947)
* [XPU] add verify draft tokens

* fix test

* fix code style

* use sync cpy

* fix code style

* fix kernel check

* fix random seed

* fix test

* fix check

* fix eos set

* fix verify

* fix verify
2026-04-15 10:18:33 +08:00

1040 lines
40 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for verify_draft_tokens kernel.
Verification strategies:
- TOPP (0): Verify draft token is in top-p candidate set
- GREEDY (1): Verify draft token matches target model's argmax
- TARGET_MATCH (2): Verify draft token matches target model's sampled token
"""
import random
import unittest
from typing import Any, Dict
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import verify_draft_tokens
from fastdeploy.spec_decode import VerifyStrategy
# Host (CPU) placement.
CPU_PLACE = paddle.CPUPlace()

# Placement used for kernel inputs: the first XPU when Paddle was compiled
# with XPU support, otherwise the CPU.
if paddle.is_compiled_with_xpu():
    DEVICE_PLACE = paddle.XPUPlace(0)
else:
    DEVICE_PLACE = paddle.CPUPlace()
# ============================================================
# Helpers: tensor creation / kernel invocation / comparison
# ============================================================
def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Build the kernel-side view of a numpy input dict.

    Plain Python scalars (int, bool, float, str) and ``None`` are passed
    through unchanged; every other value is converted with
    ``paddle.to_tensor`` and placed on ``DEVICE_PLACE``.

    Returns:
        A new dict with the same keys, in the same order, as ``inputs``.
    """
    passthrough_types = (int, bool, float, str)

    def _convert(value: Any) -> Any:
        if value is None or isinstance(value, passthrough_types):
            return value
        return paddle.to_tensor(value, place=DEVICE_PLACE)

    return {name: _convert(value) for name, value in inputs.items()}
def run_kernel(paddle_inputs: Dict[str, Any], inputs: Dict[str, Any]):
    """Invoke the verify_draft_tokens kernel with positional arguments.

    Args:
        paddle_inputs: Tensor arguments, looked up by name in the order the
            kernel expects them.
        inputs: Original input dict; only the trailing scalar arguments
            (max_seq_len, verify_window, verify_strategy, reject_all,
            accept_all) are read from it.

    The kernel's return value is not propagated; callers read results from
    ``paddle_inputs["step_output_ids"]`` / ``paddle_inputs["step_output_len"]``.
    """
    tensor_arg_names = (
        "step_output_ids",
        "step_output_len",
        "step_input_ids",
        "target_tokens",
        "candidate_ids",
        "candidate_scores",
        "candidate_lens",
        "topp",
        "stop_flags",
        "seq_lens_encoder",
        "seq_lens_this_time",
        "end_tokens",
        "is_block_step",
        "cu_seqlens_q_output",
        "reasoning_status",
        "max_dec_len",
        "step_idx",
    )
    scalar_arg_names = (
        "max_seq_len",
        "verify_window",
        "verify_strategy",
        "reject_all",
        "accept_all",
    )
    tensor_args = [paddle_inputs[name] for name in tensor_arg_names]
    scalar_args = [inputs[name] for name in scalar_arg_names]
    verify_draft_tokens(*tensor_args, *scalar_args)
def run_ref(inputs: Dict[str, Any]):
    """Run the Python reference without mutating the caller's arrays.

    Every numpy array argument is duplicated with ``.copy()`` before the call;
    non-array values are passed as they are.

    Returns:
        The ``(step_output_ids, step_output_len)`` pair produced by
        ``verify_draft_tokens_ref``.
    """
    arg_order = (
        "step_output_ids",
        "step_output_len",
        "step_input_ids",
        "target_tokens",
        "candidate_ids",
        "candidate_scores",
        "candidate_lens",
        "topp",
        "stop_flags",
        "seq_lens_encoder",
        "seq_lens_this_time",
        "end_tokens",
        "is_block_step",
        "cu_seqlens_q_output",
        "reasoning_status",
        "max_dec_len",
        "step_idx",
        "max_seq_len",
        "verify_window",
        "verify_strategy",
        "reject_all",
        "accept_all",
    )
    call_args = []
    for name in arg_order:
        value = inputs[name]
        call_args.append(value.copy() if isinstance(value, np.ndarray) else value)
    return verify_draft_tokens_ref(*call_args)
def compare_results(
    paddle_inputs: Dict[str, Any],
    step_output_ids_ref: np.ndarray,
    step_output_len_ref: np.ndarray,
    inputs: Dict[str, Any],
    label: str = "unknown",
):
    """Assert that the kernel output matches the reference output.

    Args:
        paddle_inputs: Dict whose ``step_output_ids`` / ``step_output_len``
            entries expose ``.numpy()`` and hold the kernel results.
        step_output_ids_ref: Reference output token ids.
        step_output_len_ref: Reference number of emitted tokens per sequence.
        inputs: Original input dict; ``verify_strategy`` and
            ``seq_lens_this_time`` are read from it.
        label: Test-case name appended to assertion messages.

    Raises:
        AssertionError: If lengths differ, or if the compared token ids differ.
    """
    kernel_ids = paddle_inputs["step_output_ids"].numpy()
    kernel_len = paddle_inputs["step_output_len"].numpy()

    np.testing.assert_array_equal(
        kernel_len,
        step_output_len_ref,
        err_msg=f"step_output_len mismatch ({label})",
    )

    if inputs["verify_strategy"] != 0:
        # Deterministic strategies: the whole output buffer must match.
        np.testing.assert_array_equal(
            kernel_ids,
            step_output_ids_ref,
            err_msg=f"step_output_ids mismatch ({label})",
        )
        return

    # TOPP — Phase 2 is stochastic, so the last emitted token of each sequence
    # is excluded and only the accepted prefix is compared.
    real_bsz = inputs["seq_lens_this_time"].shape[0]
    for bid in range(real_bsz):
        n_accepted = int(step_output_len_ref[bid]) - 1
        if n_accepted > 0:
            np.testing.assert_array_equal(
                kernel_ids[bid, :n_accepted],
                step_output_ids_ref[bid, :n_accepted],
                err_msg=f"step_output_ids (accepted) mismatch at bid={bid} ({label})",
            )
# ============================================================
# Reference helpers
# ============================================================
def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0):
    """Pick a candidate by walking the cumulative score until a threshold is met.

    The threshold is ``curand_value * topp``. The first of the leading
    ``candidate_len`` candidates whose running score sum reaches the threshold
    is returned; if none does, the first candidate is returned.

    ``tid`` is accepted but not used.
    """
    threshold = curand_value * topp
    cumulative = 0.0
    chosen = 0  # fall back to the first candidate when the threshold is never reached
    for idx in range(candidate_len):
        cumulative += candidate_scores[idx]
        if cumulative >= threshold:
            chosen = idx
            break
    return int(candidate_ids[chosen])
def is_in_end(token, end_tokens, end_length):
    """Return whether ``token`` is among the first ``end_length`` end tokens."""
    active_end_tokens = end_tokens[:end_length]
    return token in active_end_tokens
def is_in(candidate_list, token, length):
    """Return whether ``token`` occurs in the first ``length`` candidates."""
    considered = candidate_list[:length]
    return token in considered
class _VerifyContext:
    """Per-sequence emit state for the Python reference implementation.

    Per the original author's note this mirrors the kernel-side VerifyContext
    struct. It tracks how many tokens have been written for sequence ``bid``,
    the running step index, and whether the sequence has stopped.
    """

    def __init__(
        self,
        bid,
        max_step_tokens,
        end_length,
        end_tokens,
        max_dec_len,
        step_input_ids_now,
        step_output_ids_flat,
        cur_step_idx,
    ):
        self.bid = bid
        self.max_step_tokens = max_step_tokens
        self.end_length = end_length
        self.end_tokens = end_tokens
        self.max_dec_len = max_dec_len
        # Stored for parity with the constructor arguments; not read by the methods below.
        self.step_input_ids_now = step_input_ids_now
        self.step_output_ids_flat = step_output_ids_flat
        self.cur_step_idx = cur_step_idx
        self.output_len_now = 0
        self.stopped = False

    def _write_token(self, pos, token):
        """Write ``token`` at ``pos`` in this sequence's output row.

        Advances the step index and output length. When the step index has
        reached ``max_dec_len[bid]`` and the token is not already an end
        token, ``end_tokens[0]`` is written instead.

        Returns:
            True if the token was an end token or max_dec_len was reached.
        """
        self.cur_step_idx += 1
        eos = is_in_end(token, self.end_tokens, self.end_length)
        max_hit = self.cur_step_idx >= int(self.max_dec_len[self.bid])
        if max_hit and not eos:
            token = int(self.end_tokens[0])
        self.step_output_ids_flat[self.bid * self.max_step_tokens + pos] = token
        self.output_len_now += 1
        return eos or max_hit

    def emit_token(self, pos, token):
        """Emit a token to output. Returns True if sequence should stop."""
        if self._write_token(pos, token):
            self.stopped = True
            return True
        return False

    def emit_final_token(self, pos, token):
        """Emit the Phase 2 final token. Increments output_len_now.

        Unlike ``emit_token`` this neither sets ``stopped`` nor returns a value.
        """
        self._write_token(pos, token)
def verify_draft_tokens_ref(
    step_output_ids,
    step_output_len,
    step_input_ids,
    target_tokens,
    candidate_ids,
    candidate_scores,
    candidate_lens,
    topp,
    stop_flags,
    seq_lens_encoder,
    seq_lens_this_time,
    end_tokens,
    is_block_step,
    cu_seqlens_q_output,
    reasoning_status,
    max_dec_len,
    step_idx,
    max_seq_len,
    verify_window,
    verify_strategy,
    reject_all,
    accept_all,
):
    """Reference implementation of verify_draft_tokens in Python.

    For every sequence ``bid`` that is neither a block step nor stopped:

    * Phase 1 walks the draft tokens ``step_input_ids[bid, 1:]`` and emits each
      accepted one. Strategy 0 (TOPP) accepts a draft token found in the
      position's candidate list, with a verify_window fallback; strategies 1
      and 2 (GREEDY / TARGET_MATCH) accept a draft token equal to the target
      token. Phase 1 is skipped when ``reject_all`` is set, the sequence has a
      non-zero ``seq_lens_encoder`` or ``reasoning_status == 1``.
    * Phase 2 emits one more token for the first rejected (or last) position,
      unless Phase 1 already stopped the sequence.

    ``step_output_ids`` and ``step_output_len`` are written in place and also
    returned. ``max_seq_len`` is accepted but not used by this reference.

    Returns:
        Tuple ``(step_output_ids, step_output_len)``.
    """
    real_bsz = seq_lens_this_time.shape[0]
    max_step_tokens = step_input_ids.shape[1]
    end_length = end_tokens.shape[0]
    max_candidate_len = candidate_ids.shape[1] if candidate_ids is not None else 1

    # Each slot is the first draw of a generator seeded with 0, so every slot
    # holds the same value; draw it once instead of rebuilding the generator
    # for each slot.
    dev_curand_states = [random.Random(0).random()] * max_step_tokens

    step_output_ids_flat = step_output_ids.reshape(-1)
    step_input_ids_flat = step_input_ids.reshape(-1)
    candidate_ids_flat = candidate_ids.reshape(-1) if candidate_ids is not None else None
    candidate_scores_flat = candidate_scores.reshape(-1) if candidate_scores is not None else None

    for bid in range(real_bsz):
        start_token_id = cu_seqlens_q_output[bid]

        if is_block_step[bid] or stop_flags[bid]:
            step_output_len[bid] = 0
            continue

        # Per-sequence views: inputs are row-major by bid, while target and
        # candidate data are packed and start at cu_seqlens_q_output[bid].
        step_input_ids_now = step_input_ids_flat[bid * max_step_tokens :]
        target_tokens_now = target_tokens[start_token_id:] if target_tokens is not None else None
        candidate_ids_now = (
            candidate_ids_flat[start_token_id * max_candidate_len :] if candidate_ids_flat is not None else None
        )
        candidate_lens_now = candidate_lens[start_token_id:] if candidate_lens is not None else None
        candidate_scores_now = (
            candidate_scores_flat[start_token_id * max_candidate_len :] if candidate_scores_flat is not None else None
        )

        ctx = _VerifyContext(
            bid,
            max_step_tokens,
            end_length,
            end_tokens,
            max_dec_len,
            step_input_ids_now,
            step_output_ids_flat,
            int(step_idx[bid]),
        )

        # Phase 1: Verify
        i = 0
        while i < seq_lens_this_time[bid] - 1:
            if reject_all or seq_lens_encoder[bid] != 0 or reasoning_status[bid] == 1:
                break

            if accept_all:
                if ctx.emit_token(i, step_input_ids_now[i + 1]):
                    break
                i += 1
                continue

            accepted = False
            if verify_strategy == 0:  # TOPP
                actual_cand_len = min(candidate_lens_now[i], max_candidate_len)
                accepted = is_in(
                    candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len],
                    step_input_ids_now[i + 1],
                    actual_cand_len,
                )
                if not accepted:
                    # verify_window fallback: the draft equals this position's
                    # second candidate and the following verify_window drafts
                    # each equal their position's first candidate.
                    ii = i
                    if (
                        max_candidate_len >= 2
                        and candidate_ids_now[ii * max_candidate_len + 1] == step_input_ids_now[ii + 1]
                    ):
                        j, ii = 0, ii + 1
                        while j < verify_window and ii < seq_lens_this_time[bid] - 1:
                            if candidate_ids_now[ii * max_candidate_len] != step_input_ids_now[ii + 1]:
                                break
                            j += 1
                            ii += 1
                        if j >= verify_window:
                            # Bulk-accept positions [i, ii); an emitted token may stop the sequence.
                            for k in range(i, ii):
                                if ctx.emit_token(k, step_input_ids_now[k + 1]):
                                    i = k
                                    break
                            if ctx.stopped:
                                break
                            i = ii
                            continue
                    break
            elif verify_strategy in (1, 2):  # GREEDY / TARGET_MATCH
                accepted = target_tokens_now[i] == step_input_ids_now[i + 1]

            if accepted:
                if ctx.emit_token(i, step_input_ids_now[i + 1]):
                    break
            else:
                break
            i += 1

        # Phase 2: Sample for rejected/last position
        if not ctx.stopped:
            if verify_strategy == 0:
                if candidate_lens_now is not None and len(candidate_lens_now) > i:
                    actual_cand_len = min(candidate_lens_now[i], max_candidate_len)
                    accept_token = topp_sampling_kernel(
                        candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len],
                        candidate_scores_now[i * max_candidate_len : (i + 1) * max_candidate_len],
                        dev_curand_states[i],
                        actual_cand_len,
                        topp[bid],
                    )
                else:
                    accept_token = int(step_input_ids_now[0])
            elif verify_strategy in (1, 2):
                accept_token = (
                    int(target_tokens_now[i])
                    if target_tokens_now is not None and len(target_tokens_now) > i
                    else int(step_input_ids_now[0])
                )
            else:
                # Any other strategy value: take the position's first candidate.
                accept_token = (
                    int(candidate_ids_now[i * max_candidate_len])
                    if candidate_ids_now is not None
                    else int(step_input_ids_now[0])
                )
            ctx.emit_final_token(i, accept_token)

        step_output_len[bid] = ctx.output_len_now

    return step_output_ids, step_output_len
# ============================================================
# Input generation
# ============================================================
def gen_verify_draft_tokens_inputs(
    real_bsz: int = 32,
    max_draft_tokens: int = 16,
    max_seq_len: int = 256,
    max_candidate_len: int = 8,
    verify_window: int = 2,
    end_length: int = 4,
    verify_strategy: int = 1,
    reject_all: bool = False,
    accept_all: bool = False,
    match_ratio: float = 0.0,
    seed: int = 2025,
) -> Dict[str, Any]:
    """Generate test inputs for verify_draft_tokens kernel.

    Strategies 1 and 2 get ``target_tokens`` and no candidate arrays; any other
    strategy gets ``candidate_ids`` / ``candidate_scores`` / ``candidate_lens``
    and no ``target_tokens``. Output buffers, ``stop_flags``,
    ``seq_lens_encoder`` and ``reasoning_status`` start zeroed.

    Args:
        match_ratio: Fraction of draft token positions where target/candidates
            are forced to match step_input_ids, so the acceptance path is exercised.
            0.0 = fully random (mostly rejects), 1.0 = all positions match.
        seed: Seed of the numpy generator; equal arguments give equal data.

    Returns:
        Dict keyed by the argument names of the kernel and the reference.
    """
    rng = np.random.default_rng(seed)
    uses_target_tokens = verify_strategy in (1, 2)  # GREEDY / TARGET_MATCH

    # NOTE: the order of the rng draws below fixes the generated data; keep it.
    seq_lens_this_time = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32)
    step_input_ids = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)
    total_tokens = int(np.sum(seq_lens_this_time))

    target_tokens = candidate_ids = candidate_scores = candidate_lens = None
    if uses_target_tokens:
        target_tokens = rng.integers(0, 1000, size=(total_tokens,), dtype=np.int64)
    else:  # TOPP
        candidate_ids = rng.integers(0, 1000, size=(total_tokens, max_candidate_len), dtype=np.int64)
        raw_scores = rng.random(size=(total_tokens, max_candidate_len)).astype(np.float32)
        # Normalise each row so the scores of a position sum to 1.
        candidate_scores = raw_scores / raw_scores.sum(axis=1, keepdims=True)
        candidate_lens = rng.integers(1, max_candidate_len + 1, size=total_tokens, dtype=np.int32)

    end_tokens = rng.integers(1, 1000, size=end_length, dtype=np.int64)
    is_block_step = rng.integers(0, 2, size=real_bsz, dtype=bool)
    topp = rng.uniform(0.8, 1.0, size=real_bsz).astype(np.float32)
    max_dec_len = rng.integers(50, 200, size=real_bsz, dtype=np.int64)
    step_idx = rng.integers(0, 30, size=real_bsz, dtype=np.int64)

    # Exclusive prefix sum of seq_lens_this_time: start offset of each sequence
    # inside the packed target/candidate arrays.
    cu_seqlens_q_output = np.zeros(real_bsz, dtype=np.int32)
    cu_seqlens_q_output[1:] = np.cumsum(seq_lens_this_time[:-1])

    # Force match_ratio fraction of positions so acceptance path is tested
    if match_ratio > 0.0:
        active_end_tokens = end_tokens[:end_length]
        for bid in range(real_bsz):
            base = int(cu_seqlens_q_output[bid])
            n_verify = int(seq_lens_this_time[bid]) - 1  # positions checked in Phase 1
            n_forced = min(max(1, int(n_verify * match_ratio)), n_verify)
            for pos in range(n_forced):
                forced_token = int(step_input_ids[bid, pos + 1])
                # Ensure the draft token is not an end token (would cause early stop)
                while forced_token in active_end_tokens:
                    forced_token = (forced_token + 1) % 1000
                step_input_ids[bid, pos + 1] = forced_token
                if uses_target_tokens:
                    target_tokens[base + pos] = forced_token
                elif verify_strategy == 0:
                    candidate_ids[base + pos, 0] = forced_token
                    candidate_lens[base + pos] = max(candidate_lens[base + pos], 1)

    return {
        "step_output_ids": np.zeros((real_bsz, max_draft_tokens), dtype=np.int64),
        "step_output_len": np.zeros(real_bsz, dtype=np.int32),
        "step_input_ids": step_input_ids,
        "target_tokens": target_tokens,
        "candidate_ids": candidate_ids,
        "candidate_scores": candidate_scores,
        "candidate_lens": candidate_lens,
        "topp": topp,
        "stop_flags": np.zeros(real_bsz, dtype=bool),
        "seq_lens_encoder": np.zeros(real_bsz, dtype=np.int32),
        "seq_lens_this_time": seq_lens_this_time,
        "end_tokens": end_tokens,
        "is_block_step": is_block_step,
        "cu_seqlens_q_output": cu_seqlens_q_output,
        "reasoning_status": np.zeros(real_bsz, dtype=np.int32),
        "max_dec_len": max_dec_len,
        "step_idx": step_idx,
        "max_seq_len": max_seq_len,
        "verify_window": verify_window,
        "verify_strategy": verify_strategy,
        "reject_all": reject_all,
        "accept_all": accept_all,
    }
# ============================================================
# Test configs
# ============================================================
def _make_config(
    name,
    strategy,
    *,
    real_bsz,
    max_draft_tokens,
    max_seq_len,
    max_candidate_len,
    end_length,
    verify_window=2,
    **extra,
):
    """Build one TEST_CONFIGS entry.

    Every entry uses seed 42. ``extra`` carries the optional generator
    arguments (match_ratio, reject_all, accept_all).
    """
    config = {
        "name": name,
        "real_bsz": real_bsz,
        "max_draft_tokens": max_draft_tokens,
        "max_seq_len": max_seq_len,
        "max_candidate_len": max_candidate_len,
        "verify_window": verify_window,
        "end_length": end_length,
        "verify_strategy": strategy.value,
        "seed": 42,
    }
    config.update(extra)
    return config


# Shape shared by every reject_all / accept_all entry below.
_FLAG_SHAPE = {
    "real_bsz": 8,
    "max_draft_tokens": 5,
    "max_seq_len": 100,
    "max_candidate_len": 5,
    "end_length": 3,
}

TEST_CONFIGS = [
    # --- strategy coverage (random, mostly rejects) ---
    _make_config(
        "greedy_small_batch", VerifyStrategy.GREEDY,
        real_bsz=1, max_draft_tokens=9, max_seq_len=11, max_candidate_len=4, end_length=5,
    ),
    _make_config(
        "greedy_medium_batch", VerifyStrategy.GREEDY,
        real_bsz=33, max_draft_tokens=5, max_seq_len=10111, max_candidate_len=5, end_length=6,
    ),
    _make_config(
        "topp_small_batch", VerifyStrategy.TOPP,
        real_bsz=6, max_draft_tokens=4, max_seq_len=10001, max_candidate_len=6, end_length=7,
    ),
    _make_config(
        "target_match_medium", VerifyStrategy.TARGET_MATCH,
        real_bsz=7, max_draft_tokens=3, max_seq_len=777, max_candidate_len=7, end_length=5,
    ),
    _make_config(
        "greedy_large_batch", VerifyStrategy.GREEDY,
        real_bsz=55, max_draft_tokens=5, max_seq_len=31, max_candidate_len=9, end_length=3,
    ),
    # --- partial acceptance (match_ratio forces draft tokens to match target/candidates) ---
    _make_config(
        "greedy_half_accept", VerifyStrategy.GREEDY,
        real_bsz=8, max_draft_tokens=8, max_seq_len=256, max_candidate_len=4, end_length=3,
        match_ratio=0.5,
    ),
    _make_config(
        "greedy_full_accept", VerifyStrategy.GREEDY,
        real_bsz=8, max_draft_tokens=8, max_seq_len=256, max_candidate_len=4, end_length=3,
        match_ratio=1.0,
    ),
    _make_config(
        "topp_half_accept", VerifyStrategy.TOPP,
        real_bsz=8, max_draft_tokens=8, max_seq_len=256, max_candidate_len=6, end_length=3,
        match_ratio=0.5,
    ),
    _make_config(
        "topp_full_accept", VerifyStrategy.TOPP,
        real_bsz=8, max_draft_tokens=8, max_seq_len=256, max_candidate_len=6, end_length=3,
        match_ratio=1.0,
    ),
    _make_config(
        "target_match_accept", VerifyStrategy.TARGET_MATCH,
        real_bsz=8, max_draft_tokens=6, max_seq_len=256, max_candidate_len=4, end_length=3,
        match_ratio=0.7,
    ),
    # --- reject_all / accept_all (kernel-level flags) ---
    _make_config("reject_all", VerifyStrategy.GREEDY, reject_all=True, **_FLAG_SHAPE),
    _make_config("accept_all", VerifyStrategy.TOPP, accept_all=True, **_FLAG_SHAPE),
    _make_config("reject_all_topp", VerifyStrategy.TOPP, reject_all=True, **_FLAG_SHAPE),
    _make_config("reject_all_target_match", VerifyStrategy.TARGET_MATCH, reject_all=True, **_FLAG_SHAPE),
    _make_config("accept_all_greedy", VerifyStrategy.GREEDY, accept_all=True, **_FLAG_SHAPE),
    _make_config("accept_all_target_match", VerifyStrategy.TARGET_MATCH, accept_all=True, **_FLAG_SHAPE),
    # --- edge cases ---
    # Despite its name this is a one-sequence batch with a single token slot, not a zero-size batch.
    _make_config(
        "empty_batch", VerifyStrategy.GREEDY,
        real_bsz=1, max_draft_tokens=1, max_seq_len=10, max_candidate_len=2, end_length=4,
        verify_window=1,
    ),
]
# ============================================================
# Test suite
# ============================================================
class TestVerifyDraftTokens(unittest.TestCase):
    """Checks the verify_draft_tokens kernel against the Python reference implementation."""

    # ------ shared helpers ------
    def _run_and_compare(self, inputs: Dict[str, Any], label: str = ""):
        """Convert→run kernel→run ref→compare.

        Returns the dict of paddle tensors handed to the kernel, which is also
        where compare_results reads the kernel output from, so callers can add
        assertions without launching the kernel again.
        """
        paddle_inputs = to_paddle_inputs(inputs)
        run_kernel(paddle_inputs, inputs)
        ids_ref, len_ref = run_ref(inputs)
        compare_results(paddle_inputs, ids_ref, len_ref, inputs, label)
        return paddle_inputs

    @staticmethod
    def _enable_all_sequences(inputs: Dict[str, Any]) -> None:
        """Clear is_block_step and stop_flags so that no sequence is skipped."""
        inputs["is_block_step"][:] = False
        inputs["stop_flags"][:] = False

    @staticmethod
    def _remove_eos_from_step_inputs(inputs: Dict[str, Any]) -> None:
        """Bump every step_input_ids entry that equals an end token (mod 1000)."""
        step_input_ids = inputs["step_input_ids"]
        n_rows, n_cols = step_input_ids.shape
        for bid in range(n_rows):
            for j in range(n_cols):
                while step_input_ids[bid, j] in inputs["end_tokens"]:
                    step_input_ids[bid, j] = (step_input_ids[bid, j] + 1) % 1000

    # ------ test cases ------
    def test_verify_configs(self):
        """Test all configs in TEST_CONFIGS (strategies, reject/accept, edge cases)."""
        for cfg in TEST_CONFIGS:
            with self.subTest(name=cfg["name"]):
                test_cfg = {k: v for k, v in cfg.items() if k != "name"}
                inputs = gen_verify_draft_tokens_inputs(**test_cfg)
                self._run_and_compare(inputs, label=cfg["name"])

    def test_eos_handling(self):
        """Test EOS token in draft triggers early stop."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42
        )
        inputs["step_input_ids"][0, 2] = inputs["end_tokens"][0]
        self._run_and_compare(inputs, label="eos_handling")

    def test_max_dec_len_truncation(self):
        """Test max_dec_len causes token replacement with end_tokens[0]."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0
        )
        # Set step_idx close to max_dec_len so it triggers during verification
        inputs["step_idx"][:] = [48, 10, 10, 10]
        inputs["max_dec_len"][:] = [50, 200, 200, 200]
        self._enable_all_sequences(inputs)
        # Ensure no accidental EOS in draft tokens
        self._remove_eos_from_step_inputs(inputs)
        self._run_and_compare(inputs, label="max_dec_len_truncation")

    def test_verify_strategy_enum(self):
        """Pin the integer values of the VerifyStrategy members."""
        self.assertEqual(VerifyStrategy.TOPP.value, 0)
        self.assertEqual(VerifyStrategy.GREEDY.value, 1)
        self.assertEqual(VerifyStrategy.TARGET_MATCH.value, 2)

    def test_verify_strategy_from_string(self):
        """Check VerifyStrategy.from_string for known names and an invalid one."""
        self.assertEqual(VerifyStrategy.from_string("topp"), VerifyStrategy.TOPP)
        self.assertEqual(VerifyStrategy.from_string("TOPP"), VerifyStrategy.TOPP)
        self.assertEqual(VerifyStrategy.from_string("greedy"), VerifyStrategy.GREEDY)
        self.assertEqual(VerifyStrategy.from_string("target_match"), VerifyStrategy.TARGET_MATCH)
        with self.assertRaises(ValueError):
            VerifyStrategy.from_string("invalid")

    def test_topp_verify_window_fallback(self):
        """Test TOPP verify_window fallback: top-2 match + consecutive top-1 matches."""
        real_bsz, max_draft_tokens, max_candidate_len, verify_window = 1, 8, 4, 2
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=real_bsz,
            max_draft_tokens=max_draft_tokens,
            verify_strategy=VerifyStrategy.TOPP.value,
            max_candidate_len=max_candidate_len,
            verify_window=verify_window,
            seed=42,
        )
        # Rebuild arrays for full seq_lens_this_time
        new_slt = max_draft_tokens + 1
        inputs["seq_lens_this_time"] = np.array([new_slt], dtype=np.int32)
        inputs["cu_seqlens_q_output"] = np.array([0], dtype=np.int32)
        rng = np.random.default_rng(42)
        sum_seq = new_slt
        inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64)
        inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32)
        inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True)
        inputs["candidate_lens"] = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32)
        # Draft tokens
        draft_tokens = [100, 200, 300, 400, 500, 600, 700]
        for i, token in enumerate(draft_tokens):
            inputs["step_input_ids"][0, i + 1] = token
        # Position 0: draft NOT in candidates, but top-2 matches draft
        inputs["candidate_ids"][0] = [999, 100, 998, 997]
        # Positions 1,2: top-1 matches next draft tokens
        inputs["candidate_ids"][1] = [200, 888, 777, 666]
        inputs["candidate_ids"][2] = [300, 555, 444, 333]
        inputs["candidate_lens"][:3] = 4
        inputs["is_block_step"] = np.zeros(real_bsz, dtype=bool)
        self._run_and_compare(inputs, label="verify_window_fallback")

    def test_topp_verify_window_no_fallback(self):
        """Test TOPP when verify_window fallback does NOT trigger."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=1,
            max_draft_tokens=5,
            verify_strategy=VerifyStrategy.TOPP.value,
            max_candidate_len=4,
            verify_window=2,
            seed=42,
        )
        # NOTE(review): is_block_step is left as generated here, so if the random
        # flag is True both implementations skip this sequence — confirm intended.
        inputs["step_input_ids"][0, 1:] = [999, 998, 997, 996]
        inputs["candidate_ids"][:] = 0
        inputs["candidate_ids"][0] = [1, 2, 3, 4]
        inputs["candidate_lens"][0] = 4
        inputs["seq_lens_this_time"][0] = 5
        self._run_and_compare(inputs, label="verify_window_no_fallback")

    def test_stop_flags_skip(self):
        """Test that sequences with stop_flags=True are skipped (output_len=0)."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0
        )
        inputs["is_block_step"][:] = False
        inputs["stop_flags"][:] = [True, False, True, False]
        outputs = self._run_and_compare(inputs, label="stop_flags_skip")
        # Double-check stopped sequences produce output_len=0
        kernel_len = outputs["step_output_len"].numpy()
        for bid in (0, 2):
            self.assertEqual(kernel_len[bid], 0, f"stopped seq bid={bid} should have output_len=0")

    def test_prefill_skip(self):
        """Test that prefill requests (seq_lens_encoder != 0) skip Phase 1, only output 1 token."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0
        )
        self._enable_all_sequences(inputs)
        # Set bid 0 and 2 as prefill requests
        inputs["seq_lens_encoder"][0] = 10
        inputs["seq_lens_encoder"][2] = 5
        self._run_and_compare(inputs, label="prefill_skip")

    def test_reasoning_status_skip(self):
        """Test that reasoning_status=1 skips Phase 1, only outputs 1 token."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0
        )
        self._enable_all_sequences(inputs)
        # Set bid 1 and 3 as reasoning mode
        inputs["reasoning_status"][1] = 1
        inputs["reasoning_status"][3] = 1
        self._run_and_compare(inputs, label="reasoning_status_skip")

    def test_reject_all_and_accept_all_priority(self):
        """Test that reject_all takes priority over accept_all when both are True."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4,
            max_draft_tokens=5,
            verify_strategy=VerifyStrategy.GREEDY.value,
            seed=42,
            match_ratio=1.0,
            reject_all=True,
            accept_all=True,
        )
        self._enable_all_sequences(inputs)
        outputs = self._run_and_compare(inputs, label="reject_all_and_accept_all")
        # All sequences should produce exactly 1 token (Phase 2 only)
        kernel_len = outputs["step_output_len"].numpy()
        for bid in range(4):
            self.assertEqual(kernel_len[bid], 1, f"reject_all should produce exactly 1 token at bid={bid}")

    def test_mixed_batch_heterogeneous(self):
        """Test a batch with mixed states: normal, stopped, prefill, reasoning, block_step."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=6, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=0.8
        )
        # One row per bid: (is_block_step, stop_flags, seq_lens_encoder, reasoning_status)
        per_sequence_state = [
            (False, False, 0, 0),  # bid 0: normal decode
            (False, True, 0, 0),  # bid 1: stopped
            (False, False, 8, 0),  # bid 2: prefill
            (False, False, 0, 1),  # bid 3: reasoning mode
            (True, False, 0, 0),  # bid 4: block step
            (False, False, 0, 0),  # bid 5: normal decode
        ]
        for bid, (block_step, stopped, encoder_len, reasoning) in enumerate(per_sequence_state):
            inputs["is_block_step"][bid] = block_step
            inputs["stop_flags"][bid] = stopped
            inputs["seq_lens_encoder"][bid] = encoder_len
            inputs["reasoning_status"][bid] = reasoning
        self._run_and_compare(inputs, label="mixed_batch_heterogeneous")

    def test_single_token_sequence(self):
        """Test seq_lens_this_time=1: Phase 1 is skipped entirely, only Phase 2 outputs 1 token."""
        for strategy in [VerifyStrategy.GREEDY.value, VerifyStrategy.TOPP.value, VerifyStrategy.TARGET_MATCH.value]:
            with self.subTest(strategy=strategy):
                inputs = gen_verify_draft_tokens_inputs(
                    real_bsz=4, max_draft_tokens=8, verify_strategy=strategy, seed=42
                )
                inputs["seq_lens_this_time"][:] = 1
                # Recompute cu_seqlens_q_output for all-1 seq_lens
                inputs["cu_seqlens_q_output"] = np.array([0, 1, 2, 3], dtype=np.int32)
                # Regenerate target/candidate arrays for new sum_seq=4
                sum_seq = 4
                rng = np.random.default_rng(42)
                if strategy in (1, 2):
                    inputs["target_tokens"] = rng.integers(0, 1000, size=(sum_seq,), dtype=np.int64)
                else:
                    # Same value as the generator's default max_candidate_len.
                    max_candidate_len = 8
                    inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64)
                    inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32)
                    inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True)
                    inputs["candidate_lens"] = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32)
                self._enable_all_sequences(inputs)
                self._run_and_compare(inputs, label=f"single_token_strategy_{strategy}")

    def test_max_dec_len_exact_boundary(self):
        """Test step_idx == max_dec_len - 1: first emit triggers max_len_hit immediately."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0
        )
        self._enable_all_sequences(inputs)
        # Set step_idx = max_dec_len - 1, so the first emitted token reaches max_dec_len
        inputs["max_dec_len"][:] = 50
        inputs["step_idx"][:] = 49
        # Ensure no accidental EOS in draft tokens
        self._remove_eos_from_step_inputs(inputs)
        outputs = self._run_and_compare(inputs, label="max_dec_len_exact_boundary")
        # All sequences should produce exactly 1 token (first emit triggers stop)
        kernel_len = outputs["step_output_len"].numpy()
        for bid in range(4):
            self.assertEqual(kernel_len[bid], 1, f"max_dec_len boundary should produce 1 token at bid={bid}")

    def test_eos_during_verify_window_bulk_accept(self):
        """Test EOS token in the middle of verify_window bulk-accept range stops correctly."""
        real_bsz, max_draft_tokens, max_candidate_len, verify_window = 1, 10, 4, 2
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=real_bsz,
            max_draft_tokens=max_draft_tokens,
            verify_strategy=VerifyStrategy.TOPP.value,
            max_candidate_len=max_candidate_len,
            verify_window=verify_window,
            seed=42,
        )
        new_slt = max_draft_tokens
        inputs["seq_lens_this_time"] = np.array([new_slt], dtype=np.int32)
        inputs["cu_seqlens_q_output"] = np.array([0], dtype=np.int32)
        rng = np.random.default_rng(42)
        sum_seq = new_slt
        inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64)
        inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32)
        inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True)
        inputs["candidate_lens"] = np.full(sum_seq, max_candidate_len, dtype=np.int32)
        inputs["is_block_step"] = np.zeros(real_bsz, dtype=bool)
        inputs["stop_flags"] = np.zeros(real_bsz, dtype=bool)
        inputs["max_dec_len"][:] = 200
        eos_token = int(inputs["end_tokens"][0])
        # Draft tokens: 100, 200, EOS, 400, 500, ...
        draft_tokens = [100, 200, eos_token, 400, 500, 600, 700, 800, 900]
        for i, token in enumerate(draft_tokens):
            inputs["step_input_ids"][0, i + 1] = token
        # Position 0: draft NOT in top-1, but top-2 matches draft -> verify_window triggers
        inputs["candidate_ids"][0] = [999, 100, 998, 997]
        # Position 1: top-1 matches next draft
        inputs["candidate_ids"][1] = [200, 888, 777, 666]
        # Position 2: top-1 matches next draft (which is EOS)
        inputs["candidate_ids"][2] = [eos_token, 555, 444, 333]
        # Position 3 onwards: top-1 matches (shouldn't be reached due to EOS)
        inputs["candidate_ids"][3] = [400, 222, 111, 100]
        self._run_and_compare(inputs, label="eos_during_verify_window")

    def test_topp_max_candidate_len_1(self):
        """Test TOPP with max_candidate_len=1: verify_window fallback cannot trigger."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4,
            max_draft_tokens=6,
            verify_strategy=VerifyStrategy.TOPP.value,
            max_candidate_len=1,
            verify_window=2,
            seed=42,
            match_ratio=0.5,
        )
        self._enable_all_sequences(inputs)
        self._run_and_compare(inputs, label="topp_max_candidate_len_1")

    def test_phase2_eos_token(self):
        """Test Phase 2 target token is an EOS token."""
        inputs = gen_verify_draft_tokens_inputs(
            real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42
        )
        self._enable_all_sequences(inputs)
        # Make all draft tokens NOT match target (all reject at position 0)
        inputs["step_input_ids"][:, 1:] = 999
        if inputs["target_tokens"] is not None:
            inputs["target_tokens"][:] = 888
        # Now set the Phase 2 token (target_tokens at position 0 for each bid) to EOS
        eos_token = int(inputs["end_tokens"][0])
        offset = 0
        for bid in range(4):
            inputs["target_tokens"][offset] = eos_token
            offset += int(inputs["seq_lens_this_time"][bid])
        self._run_and_compare(inputs, label="phase2_eos_token")
# Run the suite when this file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()