Files
FastDeploy/tests/quantization/test_quantization_init.py
T
gongweibao a6351dea0b [BugFix][Optimization] Replace silent failures with catchable exceptions and informative error messages (#6533)
* init

* init

* fix format

* add

* add files

* add ut

* fix some

* add ut

* add more

* add

* fix pre-commit

* fix pre-commit

* fix cover

* skip long seq

* add

* add

* fix

* remove not need

* fix set attr

* fix comments

* fix comments

* fix failed tests

---------

Co-authored-by: gongweibao <gognweibao@baidu.com>
2026-03-16 21:32:43 +08:00

61 lines
2.1 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for the quantization package's module-level helpers:
_compute_hadamard_block_size and get_quantization_config.
"""
import unittest
from fastdeploy.model_executor.layers.quantization import (
_compute_hadamard_block_size,
get_quantization_config,
)
class TestComputeHadamardBlockSize(unittest.TestCase):
    """Tests for the _compute_hadamard_block_size function."""

    def test_basic_case(self):
        """Check that a basic call returns a positive power of two."""
        result = _compute_hadamard_block_size(4096, 2)
        self.assertGreater(result, 0)
        # A positive integer is a power of two iff clearing its lowest set
        # bit leaves zero. assertEqual (rather than assertTrue on a
        # comparison) shows the leftover bits when the check fails.
        self.assertEqual(
            result & (result - 1),
            0,
            "block size must be a power of 2",
        )

    def test_not_divisible_raises(self):
        """Check that a non-divisible moe_intermediate_size raises ValueError."""
        with self.assertRaises(ValueError) as ctx:
            _compute_hadamard_block_size(4095, 2)
        # Inspect the message after the `with` block has exited: a statement
        # placed after the raising call inside the block would never run.
        self.assertIn("must be divisible", str(ctx.exception))
class TestGetQuantizationConfig(unittest.TestCase):
    """Tests for the get_quantization_config function."""

    def test_valid_quantization_method(self):
        """Check that each listed method name yields a non-None config."""
        for method in ("wint4", "wint8", "block_wise_fp8", "w4afp8"):
            # subTest lets the remaining methods run after one fails and
            # labels every failure with the method name that caused it.
            with self.subTest(method=method):
                config_cls = get_quantization_config(method)
                self.assertIsNotNone(config_cls)

    def test_invalid_quantization_method_raises(self):
        """Check that an unknown method name raises ValueError."""
        with self.assertRaises(ValueError) as ctx:
            get_quantization_config("invalid_method")
        # Inspect the message after the `with` block has exited: a statement
        # placed after the raising call inside the block would never run.
        self.assertIn("Invalid quantization method", str(ctx.exception))
# Run the test cases in this module when the file is executed directly;
# importing the module (e.g. via a test runner) does not trigger a run.
if __name__ == "__main__":
    unittest.main()