# Source: mirror of https://github.com/PaddlePaddle/FastDeploy.git
# (synced 2026-04-23 00:17:25 +08:00, commit 419bf79951)
# Change: "[Feature] support qkv&gate linear fusion * add test"
# File stats from mirror listing: 471 lines, 17 KiB, Python.
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
# Standard library.
import json
import os
import shutil
import unittest

# Third-party: Paddle is the tensor backend the layer under test runs on.
import paddle

# Project imports: configs needed to construct an FDConfig, plus the layer
# under test and the scheduler config it requires.
from fastdeploy.config import (
    CacheConfig,
    FDConfig,
    GraphOptimizationConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
)
from fastdeploy.model_executor.layers.linear import QKVGateParallelLinear
from fastdeploy.scheduler import SchedulerConfig

# All tensors created without an explicit dtype default to bfloat16, matching
# the "dtype": "bfloat16" entry written into the test config.json below.
paddle.set_default_dtype("bfloat16")
|
|
|
class DummyPlatform:
    """Mock platform class for testing.

    Answers the CUDA probe with the flag given at construction time;
    every other hardware-backend probe is hard-wired to ``False``.
    """

    def __init__(self, cuda=True):
        # Whether this fake platform should report itself as a CUDA device.
        self._cuda_enabled = cuda

    def is_cuda(self):
        """Return the CUDA flag supplied to the constructor."""
        return self._cuda_enabled

    # --- every non-CUDA backend probe answers False -----------------------

    def is_xpu(self):
        return False

    def is_iluvatar(self):
        return False

    def is_dcu(self):
        return False

    def is_maca(self):
        return False

    def is_gcu(self):
        return False

    def is_intel_hpu(self):
        return False
|
|
class TestQKVGateParallelLinear(unittest.TestCase):
    """Unit tests for QKVGateParallelLinear layer.

    A throwaway ``config.json`` is written to disk so ``ModelConfig`` can be
    built without a real model checkpoint; the tests then exercise the
    layer's shape bookkeeping, weight-key generation, and weight loading.
    """

    def setUp(self) -> None:
        """Set up test environment: create a temp model dir and ModelConfig."""
        # Relative path; build_config_json() resolves it against os.getcwd().
        self.tmp_dir = "./tmp_qkvg_test"
        self.model_name_or_path = None  # set by build_config_json()
        self.model_config = self.build_model_config()

    def tearDown(self) -> None:
        """Clean up test environment (remove the temporary config directory)."""
        # model_name_or_path is the absolute form of tmp_dir, so normally only
        # one rmtree actually fires; the second check guards a partial setUp.
        if self.model_name_or_path and os.path.exists(self.model_name_or_path):
            shutil.rmtree(self.model_name_or_path)
        if os.path.exists(self.tmp_dir):
            shutil.rmtree(self.tmp_dir)

    def build_model_config(self) -> ModelConfig:
        """Build a mock model config for testing.

        Returns:
            ModelConfig pointing at the temporary config.json directory.
        """
        config_path = self.build_config_json()
        return ModelConfig(
            {
                "model": config_path,
                "max_model_len": 2048,
            }
        )

    def build_config_json(self) -> str:
        """Build a temporary config JSON file.

        Writes a minimal HF-style config (16 query heads, 4 KV heads,
        hidden_size 1024 => head_dim 64) and returns the absolute path of
        the directory containing it.
        """
        config_dict = {
            "architectures": ["TestModelForCausalLM"],
            "hidden_size": 1024,
            "num_attention_heads": 16,
            "num_key_value_heads": 4,
            "num_hidden_layers": 12,
            "dtype": "bfloat16",
        }

        os.makedirs(self.tmp_dir, exist_ok=True)
        config_file = os.path.join(self.tmp_dir, "config.json")
        with open(config_file, "w") as f:
            json.dump(config_dict, f)
        # Remember the absolute path so tearDown can remove it.
        self.model_name_or_path = os.path.join(os.getcwd(), self.tmp_dir)
        return self.model_name_or_path

    def create_fd_config(self, tp_size: int = 1, tp_rank: int = 0) -> FDConfig:
        """Create a FDConfig for testing.

        Args:
            tp_size: tensor-parallel world size to simulate.
            tp_rank: tensor-parallel rank of the simulated worker.
        """
        return FDConfig(
            model_config=self.model_config,
            parallel_config=ParallelConfig(
                {
                    "tensor_parallel_size": tp_size,
                    "tensor_parallel_rank": tp_rank,
                    "expert_parallel_size": 1,
                    "data_parallel_size": 1,
                }
            ),
            load_config=LoadConfig({}),
            graph_opt_config=GraphOptimizationConfig({}),
            scheduler_config=SchedulerConfig({}),
            cache_config=CacheConfig({}),
        )

    def test_initialization_basic(self) -> None:
        """Test basic initialization of QKVGateParallelLinear."""
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
            with_bias=False,
        )

        # Check basic attributes (all derived from the mock config.json).
        self.assertEqual(layer.prefix, "test.qkvg_proj")
        self.assertEqual(layer.num_heads, 16)
        self.assertEqual(layer.kv_num_heads, 4)
        self.assertEqual(layer.hidden_size, 1024)
        self.assertEqual(layer.head_dim, 64)  # 1024 / 16
        self.assertEqual(layer.tp_size, 1)
        self.assertEqual(layer.local_rank, 0)

        # Check output_size calculation: (2*num_heads + 2*kv_num_heads) * head_dim
        # (2*16 + 2*4) * 64 = 40 * 64 = 2560
        # (q + gate use num_heads each; k + v use kv_num_heads each)
        expected_output_size = (2 * 16 + 2 * 4) * 64
        self.assertEqual(layer.output_size, expected_output_size)

    def test_initialization_with_custom_params(self) -> None:
        """Test initialization with custom parameters."""
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
            with_bias=True,
            num_heads=32,
            kv_num_heads=8,
            hidden_size=2048,
            head_dim=64,
        )

        self.assertEqual(layer.num_heads, 32)
        self.assertEqual(layer.kv_num_heads, 8)
        # Note: ColumnParallelLinear.__init__ overwrites hidden_size with fd_config.model_config.hidden_size
        # So hidden_size will be 1024 from model_config, not 2048 from the parameter
        self.assertEqual(layer.hidden_size, 1024)
        self.assertEqual(layer.head_dim, 64)
        self.assertTrue(layer.with_bias)

        # Note: output_size is divided by tp_size in ColumnParallelLinear parent
        # For tp_size=1, output_size stays the same
        # Check output_size: (2*32 + 2*8) * 64 = 80 * 64 = 5120
        self.assertEqual(layer.output_size, 5120)

        # Test with head_dim not provided - uses model_config.head_dim (64)
        layer2 = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
            with_bias=True,
            num_heads=32,
            kv_num_heads=8,
            hidden_size=2048,
        )
        # When head_dim is not provided, it uses fd_config.model_config.head_dim
        self.assertEqual(layer2.head_dim, 64)

    def test_initialization_with_kv_heads_less_than_tp_size(self) -> None:
        """Test initialization when kv_num_heads < tp_size."""
        fd_config = self.create_fd_config(tp_size=8, tp_rank=0)
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
            num_heads=16,
            kv_num_heads=4,
        )

        self.assertEqual(layer.tp_size, 8)
        self.assertEqual(layer.local_rank, 0)

        # When kv_num_heads < tp_size and tp_size % kv_num_heads == 0
        # kv_num_heads_per_rank = 1, num_kv_head_replicas = tp_size / kv_num_heads = 2
        self.assertEqual(layer.kv_num_heads_per_rank, 1)
        self.assertEqual(layer.num_kv_head_replicas, 2)

        # output_size calculation:
        # Full output: (2 * num_heads + 2 * tp_size) * head_dim
        # = (2 * 16 + 2 * 8) * 64 = (32 + 16) * 64 = 48 * 64 = 3072
        # (KV heads are replicated up to tp_size, hence "2 * tp_size" here)
        # Then ColumnParallelLinear divides by tp_size: 3072 / 8 = 384
        expected_output_size = 384
        self.assertEqual(layer.output_size, expected_output_size)

    def test_initialization_with_tensor_parallel(self) -> None:
        """Test initialization with tensor parallelism."""
        fd_config = self.create_fd_config(tp_size=4, tp_rank=2)
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
        )

        self.assertEqual(layer.tp_size, 4)
        self.assertEqual(layer.local_rank, 2)
        self.assertEqual(layer.num_heads_per_rank, 4)  # 16 / 4
        self.assertEqual(layer.kv_num_heads_per_rank, 1)  # 4 / 4

    def test_weight_keys(self) -> None:
        """Test that weight keys are correctly generated.

        The fused layer's prefix "...qkvg_proj" maps back to the separate
        checkpoint keys "...qkv_proj.*" and "...gate.*".
        """
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="model.layers.0.self_attn.qkvg_proj",
            with_bias=True,
        )

        # Check qkv weight key
        self.assertEqual(layer.qkv_weight_key, "model.layers.0.self_attn.qkv_proj.weight")
        # Check gate weight key
        self.assertEqual(layer.gate_weight_key, "model.layers.0.self_attn.gate.weight")
        # Check bias keys
        self.assertEqual(layer.qkv_bias_key, "model.layers.0.self_attn.qkv_proj.bias")
        self.assertEqual(layer.gate_bias_key, "model.layers.0.self_attn.gate.bias")

    def test_get_shard_size_mapping(self) -> None:
        """Test _get_shard_size_mapping method."""
        fd_config = self.create_fd_config(tp_size=4, tp_rank=0)
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
        )

        head_dim = 64
        # num_heads_per_rank = 16 / 4 = 4
        # kv_num_heads_per_rank = 4 / 4 = 1
        q_size = layer._get_shard_size_mapping("q", head_dim)
        k_size = layer._get_shard_size_mapping("k", head_dim)
        v_size = layer._get_shard_size_mapping("v", head_dim)

        self.assertEqual(q_size, 4 * 64)  # num_heads_per_rank * head_dim
        self.assertEqual(k_size, 1 * 64)  # kv_num_heads_per_rank * head_dim
        self.assertEqual(v_size, 1 * 64)  # kv_num_heads_per_rank * head_dim

        # Test unknown shard_id: the mapping is expected to yield None rather
        # than raise for unrecognized shard names.
        unknown_size = layer._get_shard_size_mapping("unknown", head_dim)
        self.assertIsNone(unknown_size)

    def test_weight_loader_valid_shard_id(self) -> None:
        """Test weight_loader with valid shard IDs ("qkv" and "gate")."""
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
        )

        # Test valid shard_id "qkv"
        # NOTE(review): param/weight shapes here are placeholders; only the
        # shard-id validation branch is being exercised — confirm against
        # weight_loader's implementation.
        param_mock = paddle.zeros([layer.output_size, layer.input_size], dtype="float16")
        weight_mock = paddle.zeros([layer.output_size // 2, layer.input_size], dtype="float16")

        # Should not raise an assertion error
        try:
            layer.weight_loader(param_mock, weight_mock, "qkv")
        except AssertionError as e:
            # Only fail if the error is the shard-id validation; other
            # assertion failures (e.g. shape checks) are tolerated here.
            if "loaded_shard_id must be one of" in str(e):
                self.fail("weight_loader should accept 'qkv' as a valid shard_id")

        # Test valid shard_id "gate"
        try:
            layer.weight_loader(param_mock, weight_mock, "gate")
        except AssertionError as e:
            if "loaded_shard_id must be one of" in str(e):
                self.fail("weight_loader should accept 'gate' as a valid shard_id")

    def test_weight_loader_invalid_shard_id(self) -> None:
        """Test weight_loader with invalid shard IDs."""
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
        )

        param_mock = paddle.zeros([layer.output_size, layer.input_size], dtype="float16")
        weight_mock = paddle.zeros([layer.output_size, layer.input_size], dtype="float16")

        # Should raise an assertion error for invalid shard_id
        with self.assertRaises(AssertionError) as context:
            layer.weight_loader(param_mock, weight_mock, "invalid")

        self.assertIn("loaded_shard_id must be one of", str(context.exception))

    def test_load_state_dict_success(self) -> None:
        """Test loading state_dict with valid qkv and gate weights."""
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
            with_bias=False,
        )

        # Calculate the qkv and gate weight sizes
        # Total output_size = (2*num_heads + 2*kv_num_heads) * head_dim
        # qkv part = (num_heads + 2*kv_num_heads) * head_dim
        # gate part = num_heads * head_dim
        qkv_weight_size = (layer.num_heads + 2 * layer.kv_num_heads) * layer.head_dim // layer.tp_size
        gate_weight_size = layer.num_heads * layer.head_dim // layer.tp_size

        # Create mock weights (stored [input_size, out] — column-major layout
        # used by Paddle linear weights).
        qkv_weight = paddle.randn([layer.input_size, qkv_weight_size], dtype="float16")
        gate_weight = paddle.randn([layer.input_size, gate_weight_size], dtype="float16")

        state_dict = {
            "test.qkv_proj.weight": qkv_weight.numpy(),
            "test.gate.weight": gate_weight.numpy(),
        }

        # This should not raise an error (weight keys may not match due to fallback)
        # We're testing the logic flow
        try:
            layer.load_state_dict(state_dict)
        except KeyError:
            # If this happens, the weight keys don't match expected format
            # This is OK for testing - we're verifying the logic flow
            pass

    def test_load_weight_and_split_correctness(self) -> None:
        """Test that QKVGateParallelLinear correctly loads and splits qkv and gate weights.

        Round-trips known random weights through load_state_dict and checks
        the fused parameter holds [qkv | gate] concatenated along dim 1.
        """
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
            with_bias=False,
        )

        # Calculate the qkv and gate weight sizes
        # qkv part = (num_heads + 2*kv_num_heads) * head_dim
        # gate part = num_heads * head_dim
        qkv_weight_size = (layer.num_heads + 2 * layer.kv_num_heads) * layer.head_dim // layer.tp_size
        gate_weight_size = layer.num_heads * layer.head_dim // layer.tp_size

        # Create mock weights with known values for verification
        import numpy as np

        input_size = layer.input_size
        # Use the same dtype as layer.weight to avoid precision issues during cast
        weight_dtype = layer.weight.dtype
        qkv_weight = paddle.randn([input_size, qkv_weight_size], dtype=weight_dtype)
        gate_weight = paddle.randn([input_size, gate_weight_size], dtype=weight_dtype)

        state_dict = {
            "test.qkv_proj.weight": qkv_weight.numpy(),
            "test.gate.weight": gate_weight.numpy(),
        }

        # Load the weights
        layer.load_state_dict(state_dict)

        # Verify the combined weight shape
        self.assertEqual(layer.weight.shape, [input_size, qkv_weight_size + gate_weight_size])

        # Split the loaded weight back into qkv and gate parts
        loaded_qkv_weight = layer.weight[:, :qkv_weight_size]
        loaded_gate_weight = layer.weight[:, qkv_weight_size:]

        # Compare the original and loaded weights
        # Use numpy.allclose for floating point comparison
        self.assertTrue(
            np.allclose(qkv_weight.numpy(), loaded_qkv_weight.numpy()),
            "Loaded qkv weights do not match original qkv weights",
        )
        self.assertTrue(
            np.allclose(gate_weight.numpy(), loaded_gate_weight.numpy()),
            "Loaded gate weights do not match original gate weights",
        )

    def test_load_state_dict_missing_weights(self) -> None:
        """Test load_state_dict with missing required weights."""
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
        )

        state_dict = {}  # Empty state dict

        with self.assertRaises(AssertionError) as context:
            layer.load_state_dict(state_dict)

        self.assertIn("not found in state_dict", str(context.exception))

    def test_load_state_dict_missing_bias(self) -> None:
        """Test load_state_dict with with_bias=True but missing bias."""
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
            with_bias=True,
        )

        # Create weights but no bias
        # NOTE(review): the 1024/512 widths are placeholders — the load is
        # expected to fail on the missing bias (or a shape check) either way.
        qkv_weight = paddle.randn([layer.input_size, 1024], dtype="float16")
        gate_weight = paddle.randn([layer.input_size, 512], dtype="float16")

        state_dict = {
            "test.qkv_proj.weight": qkv_weight.numpy(),
            "test.gate.weight": gate_weight.numpy(),
        }

        # This should raise an error about missing bias
        with self.assertRaises((AssertionError, KeyError)):
            layer.load_state_dict(state_dict)

    def test_weight_shapes_after_init(self) -> None:
        """Test that weight parameter has correct shape after initialization."""
        fd_config = self.create_fd_config()
        layer = QKVGateParallelLinear(
            fd_config=fd_config,
            prefix="test.qkvg_proj",
        )

        # Check that weight parameter exists and has correct shape
        self.assertIsNotNone(layer.weight)

        # Shape should be [input_size, output_size]
        # input_size = hidden_size = 1024
        # output_size = (2*16 + 2*4) * 64 = 2560
        expected_shape = [layer.input_size, layer.output_size]
        self.assertEqual(list(layer.weight.shape), expected_shape)

    def test_different_prefix_formats(self) -> None:
        """Test weight key generation with different prefix formats."""
        # (layer prefix, expected derived qkv weight key)
        test_cases = [
            ("model.layers.0.qkvg_proj", "model.layers.0.qkv_proj.weight"),
            ("layers.0.qkvg_proj", "layers.0.qkv_proj.weight"),
            ("simple.qkvg_proj", "simple.qkv_proj.weight"),
        ]

        for prefix, expected_qkv_key in test_cases:
            with self.subTest(prefix=prefix):
                fd_config = self.create_fd_config()
                layer = QKVGateParallelLinear(
                    fd_config=fd_config,
                    prefix=prefix,
                )
                self.assertEqual(layer.qkv_weight_key, expected_qkv_key)
                # Gate key replaces "qkvg_proj" with "gate"
                expected_gate_key = expected_qkv_key.replace("qkv_proj.weight", "gate.weight")
                self.assertEqual(layer.gate_weight_key, expected_gate_key)
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|