[CI]【Hackathon 9th Sprint No.46】NO.46 功能模块 fastdeploy/model_executor/guided_decoding/xgrammar_backend.py 单测补充 (#5042)

* test

* rename ut

* remove test max_rollback_tokens

* update

* 精简代码

* fix: torch use mock

---------

Co-authored-by: CSWYF3634076 <wangyafeng@baidu.com>
This commit is contained in:
essos
2025-12-30 17:05:26 +08:00
committed by GitHub
parent 0bcf924e10
commit b03a4f3e3d
@@ -0,0 +1,423 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import json
import os
import sys
import unittest
from unittest.mock import Mock, patch
sys.modules["torch"] = Mock()
sys.modules["xgrammar"] = Mock()
sys.modules["fastdeploy"] = Mock()
sys.modules["fastdeploy.config"] = Mock()
sys.modules["fastdeploy.engine"] = Mock()
sys.modules["fastdeploy.engine.request"] = Mock()
sys.modules["fastdeploy.model_executor"] = Mock()
sys.modules["fastdeploy.model_executor.guided_decoding"] = Mock()
sys.modules["fastdeploy.utils"] = Mock()
sys.modules["fastdeploy.platforms"] = Mock()
class MockGrammarMatcher:
def __init__(self, compiled_grammar, **kwargs):
self.fill_next_token_bitmask = Mock()
self.reset = Mock()
self.accept_token = Mock(return_value=True)
self.is_terminated = Mock(return_value=False)
def mock_allocate_token_bitmask(batch_size, vocab_size):
mock_tensor = Mock()
mock_tensor.shape = (batch_size, vocab_size)
mock_tensor.dtype = "int32"
return mock_tensor
sys.modules["xgrammar"].CompiledGrammar = type("CompiledGrammar", (), {})
sys.modules["xgrammar"].Grammar = type(
"Grammar",
(),
{
"from_json_schema": staticmethod(Mock),
"from_ebnf": staticmethod(Mock),
"from_structural_tag": staticmethod(Mock),
},
)
sys.modules["xgrammar"].GrammarMatcher = MockGrammarMatcher
sys.modules["xgrammar"].GrammarCompiler = Mock
sys.modules["xgrammar"].TokenizerInfo = Mock
sys.modules["xgrammar"].StructuralTagItem = Mock
sys.modules["xgrammar"].allocate_token_bitmask = mock_allocate_token_bitmask
sys.modules["xgrammar"].apply_token_bitmask_inplace = lambda logits, bitmask, indices=None: None
class MockLogitsProcessorBase:
def __init__(self, enable_reasoning=False):
self.enable_reasoning = enable_reasoning
class MockBackendBase:
def __init__(self, fd_config=None):
self.hf_tokenizer = Mock()
class MockBaseChecker:
pass
sys.modules["fastdeploy.model_executor.guided_decoding"].LogitsProcessorBase = MockLogitsProcessorBase
sys.modules["fastdeploy.model_executor.guided_decoding"].BackendBase = MockBackendBase
sys.modules["fastdeploy.model_executor.guided_decoding"].BaseChecker = MockBaseChecker
class MockGrammarCompiler:
def __init__(self, tokenizer_info, **kwargs):
self.tokenizer_info = tokenizer_info
def compile_json_schema(self, schema_str, **kwargs):
return Mock()
def compile_regex(self, pattern_str):
return Mock()
def compile_grammar(self, grammar_str):
return Mock()
def compile_structural_tag(self, tags, triggers):
return Mock()
sys.modules["xgrammar"].GrammarCompiler = MockGrammarCompiler
sys.modules["xgrammar"].TokenizerInfo = type("TokenizerInfo", (), {"from_huggingface": staticmethod(Mock)})
sys.modules["xgrammar"].StructuralTagItem = lambda begin, schema, end: Mock(begin=begin, schema=schema, end=end)
sys.modules["fastdeploy.utils"].llm_logger = type("llm_logger", (), {"error": Mock, "info": Mock, "warning": Mock})()
# Import the target module
spec = importlib.util.spec_from_file_location(
"xgrammar_backend",
os.path.join(os.path.dirname(__file__), "../../../fastdeploy/model_executor/guided_decoding/xgrammar_backend.py"),
)
xgrammar_backend = importlib.util.module_from_spec(spec)
sys.modules["fastdeploy.model_executor.guided_decoding.xgrammar_backend"] = xgrammar_backend
spec.loader.exec_module(xgrammar_backend)
XGrammarProcessor = xgrammar_backend.XGrammarProcessor
XGrammarBackend = xgrammar_backend.XGrammarBackend
XGrammarChecker = xgrammar_backend.XGrammarChecker
class TestXGrammarProcessor(unittest.TestCase):
"""Test cases for XGrammarProcessor class."""
def setUp(self):
self.mock_compiled_grammar = Mock()
self.processor = XGrammarProcessor(
compiled_grammar=self.mock_compiled_grammar,
terminate_without_stop_token=False,
override_stop_tokens=None,
vocab_size=1000,
batch_size=4,
enable_thinking=False,
)
def test_processor_initialization(self):
"""Test processor initialization (covers lines 75-88)."""
# Default parameters
processor = XGrammarProcessor(compiled_grammar=self.mock_compiled_grammar)
self.assertIsNone(processor.vocab_size)
self.assertIsNone(processor.batch_size)
self.assertFalse(processor.terminate_without_stop_token)
self.assertIsNone(processor.override_stop_tokens)
self.assertFalse(processor.enable_reasoning)
# Custom parameters
processor = XGrammarProcessor(
compiled_grammar=self.mock_compiled_grammar,
terminate_without_stop_token=True,
override_stop_tokens=[1, 2, 3],
vocab_size=2000,
batch_size=8,
enable_thinking=True,
)
self.assertTrue(processor.terminate_without_stop_token)
self.assertEqual(processor.override_stop_tokens, [1, 2, 3])
self.assertEqual(processor.vocab_size, 2000)
self.assertEqual(processor.batch_size, 8)
self.assertTrue(processor.enable_reasoning)
def test_allocate_token_bitmask(self):
"""Test token bitmask allocation (covers line 97)."""
if self.processor.batch_size is None or self.processor.vocab_size is None:
return
bitmask = self.processor.allocate_token_bitmask()
self.assertIsInstance(bitmask, Mock)
self.assertEqual(bitmask.shape, (self.processor.batch_size, self.processor.vocab_size))
def test_fill_token_bitmask(self):
"""Test filling token bitmask (covers line 110)."""
mock_bitmask = Mock()
self.processor.fill_token_bitmask(mock_bitmask, 2)
self.processor.matcher.fill_next_token_bitmask.assert_called_once_with(mock_bitmask, 2)
def test_reset(self):
"""Test resetting processor state (covers line 119)."""
self.processor.reset()
self.processor.matcher.reset.assert_called_once()
def test_accept_token_when_already_terminated(self):
"""Test accept_token when already terminated (covers lines 130-132)."""
self.processor.is_terminated = True
result = self.processor.accept_token(42)
self.assertFalse(result)
self.processor.matcher.accept_token.assert_not_called()
def test_accept_token_sets_terminated_flag(self):
"""Test accept_token sets terminated flag (covers lines 136-138)."""
self.processor.matcher.accept_token.return_value = True
self.processor.matcher.is_terminated.side_effect = [False, True]
result = self.processor.accept_token(42)
self.assertTrue(result)
self.assertTrue(self.processor.is_terminated)
def test_copy(self):
"""Test creating a copy of the processor (covers line 147)."""
copied = self.processor.copy()
self.assertIsNot(copied, self.processor)
self.assertEqual(copied.compiled_grammar, self.processor.compiled_grammar)
self.assertEqual(copied.terminate_without_stop_token, self.processor.terminate_without_stop_token)
self.assertEqual(copied.override_stop_tokens, self.processor.override_stop_tokens)
self.assertEqual(copied.vocab_size, self.processor.vocab_size)
self.assertEqual(copied.batch_size, self.processor.batch_size)
class TestXGrammarBackend(unittest.TestCase):
"""Test cases for XGrammarBackend class."""
def setUp(self):
self.mock_fd_config = Mock()
self.mock_fd_config.model_config.vocab_size = 1000
self.mock_fd_config.scheduler_config.max_num_seqs = 4
self.mock_fd_config.structured_outputs_config.disable_any_whitespace = False
with patch("xgrammar.TokenizerInfo.from_huggingface"), patch("xgrammar.GrammarCompiler"):
self.backend = XGrammarBackend(self.mock_fd_config)
def _get_sample_structural_tag(self):
return json.dumps(
{
"structures": [{"begin": "<tool>", "schema": {"type": "string"}, "end": "</tool>"}],
"triggers": ["<tool>"],
}
)
def test_backend_initialization(self):
"""Test backend initialization (covers lines 175-196)."""
test_configs = [
(1000, 4, False, True),
(2000, 8, True, False),
]
for vocab_size, max_num_seqs, disable_any_white, expected_any_white in test_configs:
mock_fd_config = Mock()
mock_fd_config.model_config.vocab_size = vocab_size
mock_fd_config.scheduler_config.max_num_seqs = max_num_seqs
mock_fd_config.structured_outputs_config.disable_any_whitespace = disable_any_white
with patch("xgrammar.TokenizerInfo.from_huggingface"), patch("xgrammar.GrammarCompiler"):
backend = XGrammarBackend(mock_fd_config)
self.assertEqual(backend.vocab_size, vocab_size)
self.assertEqual(backend.batch_size, max_num_seqs)
self.assertEqual(backend.any_whitespace, expected_any_white)
def test_backend_initialization_tokenizer_failure(self):
"""Test backend initialization with tokenizer failure (covers lines 175-196)."""
mock_fd_config = Mock()
mock_fd_config.model_config.vocab_size = 1000
mock_fd_config.scheduler_config.max_num_seqs = 4
mock_fd_config.structured_outputs_config.disable_any_whitespace = False
with patch("xgrammar.TokenizerInfo.from_huggingface", side_effect=Exception("Tokenizer error")):
with self.assertRaises(Exception) as context:
XGrammarBackend(mock_fd_config)
self.assertIn("Failed to load XGrammar tokenizer", str(context.exception))
def test_create_processor(self):
"""Test creating a processor instance (covers line 217)."""
mock_compiled_grammar = Mock()
processor = self.backend._create_processor(
compiled_grammar=mock_compiled_grammar,
terminate_without_stop_token=True,
override_stop_tokens=[1, 2, 3],
enable_thinking=True,
)
self.assertIsInstance(processor, XGrammarProcessor)
self.assertEqual(processor.compiled_grammar, mock_compiled_grammar)
self.assertTrue(processor.terminate_without_stop_token)
self.assertEqual(processor.override_stop_tokens, [1, 2, 3])
self.assertEqual(processor.vocab_size, self.backend.vocab_size)
self.assertEqual(processor.batch_size, self.backend.batch_size)
self.assertTrue(processor.enable_reasoning)
def test_structural_tag_processor_parsing_and_tag_construction(self):
"""Test structural tag parsing (covers lines 290-305)."""
complex_tag = json.dumps(
{
"structures": [
{"begin": "<tool>", "schema": {"type": "string"}, "end": "</tool>"},
{"begin": "<function>", "schema": {"type": "object"}, "end": "</function>"},
],
"triggers": ["<tool>", "<function>"],
}
)
with patch.object(self.backend.grammar_compiler, "compile_structural_tag") as mock_compile:
mock_compile.return_value = Mock()
processor = self.backend._structural_tag_processor(complex_tag)
self.assertIsNotNone(processor)
mock_compile.assert_called_once()
def test_processor_error_logging(self):
"""Test processors log errors (covers lines 237-241, 255-259, 273-277)."""
from fastdeploy.utils import llm_logger
test_cases = [
("json", "compile_json_schema", "invalid schema", "_json_processor"),
("regex", "compile_regex", "[invalid", "_regex_processor"),
("grammar", "compile_grammar", "invalid", "_grammar_processor"),
]
for name, compile_method, test_input, processor_method in test_cases:
with self.subTest(name=name):
with patch.object(self.backend.grammar_compiler, compile_method, side_effect=Exception("Failed")):
with patch.object(llm_logger, "error") as mock_logger:
processor = getattr(self.backend, processor_method)(test_input)
self.assertIsNone(processor)
mock_logger.assert_called_once()
class TestXGrammarChecker(unittest.TestCase):
"""Test cases for XGrammarChecker class."""
def setUp(self):
self.checker = XGrammarChecker(disable_any_whitespace=False)
def _create_mock_request(self):
request = Mock()
request.guided_json = None
request.guided_grammar = None
request.guided_json_object = None
request.guided_choice = None
request.structural_tag = None
return request
def _get_sample_structural_tag(self):
return json.dumps(
{
"structures": [{"begin": "<tool>", "schema": {"type": "string"}, "end": "</tool>"}],
"triggers": ["<tool>"],
}
)
def _test_schema_format_error(self, field_name, field_value, patch_target, error_substring):
request = self._create_mock_request()
setattr(request, field_name, field_value)
with patch(patch_target, side_effect=RuntimeError("test error")):
_, error = self.checker.schema_format(request)
self.assertIsNotNone(error)
self.assertIn(error_substring, error)
def test_unsupported_json_schema(self):
"""Test detection of unsupported features in JSON schema (covers lines 339-375)."""
test_schemas = [
{"type": "object", "properties": {"age": {"type": "number", "multipleOf": 2}}},
{"type": "array", "uniqueItems": True},
{"type": "array", "contains": {"type": "number"}},
{"type": "array", "minContains": 1},
{"type": "array", "maxContains": 5},
{"type": "string", "format": "date-time"},
{"type": "object", "minProperties": 1},
{"type": "object", "maxProperties": 10},
{"type": "object", "propertyNames": {"type": "string"}},
{"type": "object", "patternProperties": {"^S_": {"type": "string"}}},
{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {"type": "object", "properties": {"value": {"type": "number", "multipleOf": 0.1}}},
}
},
},
]
for schema in test_schemas:
self.assertTrue(self.checker._unsupported_json_schema(schema), f"Failed for: {schema}")
def test_schema_format_guided_json_success(self):
"""Test successful formatting of guided_json (covers lines 388, 391-393)."""
request = self._create_mock_request()
original_json = {"type": "object", "properties": {"name": {"type": "string"}}}
request.guided_json = original_json
with patch("xgrammar.Grammar.from_json_schema"):
result_request, error = self.checker.schema_format(request)
self.assertIsNone(error)
expected_json = json.dumps(original_json)
self.assertEqual(result_request.guided_json, expected_json)
def test_schema_format_guided_json_invalid_format(self):
"""Test formatting of invalid guided_json (covers lines 391-393)."""
self._test_schema_format_error(
"guided_json", "invalid json", "xgrammar.Grammar.from_json_schema", "Invalid JSON format"
)
def test_schema_format_guided_json_unsupported_schema(self):
"""Test formatting of guided_json with unsupported schema (covers lines 396-397)."""
request = self._create_mock_request()
request.guided_json = {"type": "number", "multipleOf": 2}
with patch("xgrammar.Grammar.from_json_schema"):
with patch.object(self.checker, "_unsupported_json_schema", return_value=True):
_, error = self.checker.schema_format(request)
self.assertIsNotNone(error)
self.assertIn("unsupported JSON schema", error)
def test_schema_format_guided_grammar_invalid_format(self):
"""Test formatting of invalid guided_grammar (covers lines 406-408)."""
self._test_schema_format_error(
"guided_grammar", "invalid grammar", "xgrammar.Grammar.from_ebnf", "Invalid grammar format"
)
def test_schema_format_guided_json_object(self):
"""Test formatting of guided_json_object (covers lines 412-413)."""
request = self._create_mock_request()
request.guided_json_object = True
result_request, error = self.checker.schema_format(request)
self.assertIsNone(error)
self.assertEqual(result_request.guided_json, '{"type": "object"}')
def test_schema_format_structural_tag_success(self):
"""Test successful formatting of structural_tag (covers lines 439-440)."""
request = self._create_mock_request()
request.structural_tag = self._get_sample_structural_tag()
with patch("xgrammar.Grammar.from_structural_tag"):
_, error = self.checker.schema_format(request)
self.assertIsNone(error)
def test_schema_format_structural_tag_invalid_grammar(self):
"""Test formatting of structural_tag with invalid grammar (covers lines 438-440)."""
self._test_schema_format_error(
"structural_tag",
self._get_sample_structural_tag(),
"xgrammar.Grammar.from_structural_tag",
"Invalid structural_tag format",
)
if __name__ == "__main__":
unittest.main(verbosity=2)