[Feature] support qkv&gate linear fusion (#6455)

* [Feature] support qkv&gate linear fusion

* add test
Longzhi Wang
2026-02-24 15:20:29 +08:00
committed by GitHub
parent 38c3e02470
commit 22566168c3
3 changed files with 692 additions and 0 deletions
@@ -1106,3 +1106,223 @@ class KVBatchLinear(nn.Layer):
return self.forward_v_b(x)
else:
raise ValueError(f"proj_type must be 'k' or 'v', got {proj_type}")

class QKVGateParallelLinear(ColumnParallelLinear):
    """
    Column-parallel linear layer that fuses the q, k, v and gate projections
    into a single GEMM. The fused output is laid out as [q | k | v | gate]
    along the output dimension.
    """

def __init__(
self,
fd_config,
prefix,
with_bias=False,
num_heads: Optional[int] = None,
kv_num_heads: Optional[int] = None,
hidden_size: Optional[int] = None,
head_dim: Optional[int] = None,
skip_quant: bool = False,
weight_dtype: str = "",
):
self.prefix = prefix
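        # The checkpoint stores qkv and gate as separate tensors; map the fused "qkvg" prefix back to their individual keys.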
self.qkv_weight_key = f"{prefix}.weight".replace("qkvg", "qkv")
self.gate_weight_key = f"{prefix}.weight".replace("qkvg_proj", "gate")
self.qkv_bias_key = f"{prefix}.bias".replace("qkvg", "qkv")
self.gate_bias_key = f"{prefix}.bias".replace("qkvg_proj", "gate")
self.num_heads = fd_config.model_config.num_attention_heads if num_heads is None else num_heads
self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.num_heads_per_rank = divide(self.num_heads, self.tp_size)
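        # If there are fewer kv heads than tp ranks, each rank gets a single kv
        # head and kv heads are replicated across groups of ranks.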
if self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0:
self.kv_num_heads_per_rank = 1
self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
output_size = (2 * self.num_heads + 2 * self.tp_size) * self.head_dim
else:
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
self.num_kv_head_replicas = 1
output_size = (2 * self.num_heads + 2 * self.kv_num_heads) * self.head_dim
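        # Example with hypothetical sizes: num_heads=32, kv_num_heads=8, tp_size=4,
        # head_dim=128 -> output_size = (2 * 32 + 2 * 8) * 128 = 10240 (q + gate + k + v).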
input_size = self.hidden_size
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
skip_quant=skip_quant,
weight_dtype=weight_dtype,
)

    def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
shard_size_mapping = {
"q": self.num_heads_per_rank * head_dim,
"k": self.kv_num_heads_per_rank * head_dim,
"v": self.kv_num_heads_per_rank * head_dim,
}
return shard_size_mapping.get(loaded_shard_id)

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
assert loaded_shard_id in [
"qkv",
"gate",
], f"loaded_shard_id must be one of ['qkv', 'gate'], but got {loaded_shard_id}"
if loaded_shard_id == "qkv":
self.qkv_weight_loader(param, loaded_weight, None)
else:
self.gate_weight_loader(param, loaded_weight)

    def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
dim = -1 if output_dim else 0
        # Per-rank fused width covers q + gate + k + v heads; infer head_dim from it.
head_dim = param.shape[dim] // (2 * self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if loaded_shard_id is None:
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Avoid redundant transpose of fused weights when weight_loader is called iteratively
param.weight_need_transpose = False
# Loaded weight is already fused on disk
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.num_heads * head_dim),
("k", self.num_heads * head_dim, self.kv_num_heads * head_dim),
("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
)
self.qkv_weight_loader(param, loaded_weight_shard, shard_id)
else:
# split q k v
assert loaded_shard_id in ["q", "k", "v"]
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.tp_size > 1 and output_dim is not None:
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size
shard_size = block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)
if not param._is_initialized():
param.initialize()
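            # Per-rank fused parameter layout is [q | k | v | gate]; locate this shard's slot.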
if loaded_shard_id == "q":
param_shard_offset = 0
param_shard_size = self.num_heads_per_rank * head_dim
elif loaded_shard_id == "k":
param_shard_offset = self.num_heads_per_rank * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
else:
# loaded_shard_id == "v"
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
h2d_copy(param, loaded_weight)

    def gate_weight_loader(self, param, loaded_weight):
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
dim = -1 if output_dim else 0
        # Per-rank fused width covers q + gate + k + v heads; infer head_dim from it.
head_dim = param.shape[dim] // (2 * self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.tp_size > 1 and output_dim is not None:
block_size = self.num_heads_per_rank * head_dim
shard_offset = self.local_rank * block_size
shard_size = block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)
if not param._is_initialized():
param.initialize()
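        # Gate occupies the last slot of the fused [q | k | v | gate] parameter.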
param_shard_offset = (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * head_dim
param_shard_size = self.num_heads_per_rank * head_dim
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
h2d_copy(param, loaded_weight)

    def load_weight(self, state_dict: dict):
"""
Load the weight from the state dictionary.
Args:
state_dict (dict): A dictionary containing the weights
"""
qkv_weight_tensor = get_tensor(state_dict.pop(self.qkv_weight_key))
gate_weight_tensor = get_tensor(state_dict.pop(self.gate_weight_key))
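        # Fuse qkv and gate along the output axis before handing off to the quant method.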
qkvg_weight_tensor = paddle.concat([qkv_weight_tensor, gate_weight_tensor], axis=-1)
self.quant_method.process_loaded_weights(self, qkvg_weight_tensor)

    def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
# weight
assert (
            self.qkv_weight_key in state_dict and self.gate_weight_key in state_dict
), f"{self.qkv_weight_key} or {self.gate_weight_key} not found in state_dict"
if self.is_quantized:
self.load_prequant_weight(state_dict)
else:
self.load_weight(state_dict)
# bias
if self.with_bias:
assert (
                self.qkv_bias_key in state_dict and self.gate_bias_key in state_dict
), f"{self.qkv_bias_key} or {self.gate_bias_key} not found in state_dict"
qkv_bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.qkv_bias_key)))
gate_bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.gate_bias_key)))
bias_tensor = paddle.concat([qkv_bias_tensor, gate_bias_tensor], axis=-1)
self.bias.set_value(bias_tensor)
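
Usage sketch (illustrative, not part of the diff): assuming a FastDeploy-style
fd_config that carries model_config and parallel_config, and a checkpoint that
still holds the unfused tensors, wiring the fused layer in could look roughly
like the following; the prefix and checkpoint keys below are hypothetical.

    # Hypothetical sketch; fd_config, the prefix and the checkpoint keys are assumptions.
    qkvg_proj = QKVGateParallelLinear(
        fd_config=fd_config,
        prefix="model.layers.0.self_attn.qkvg_proj",  # mapped to ...qkv_proj.* and ...gate.* keys
        with_bias=True,
    )
    # load_state_dict expects the separate tensors:
    #   model.layers.0.self_attn.qkv_proj.weight / .bias
    #   model.layers.0.self_attn.gate.weight / .bias
    qkvg_proj.load_state_dict(state_dict)
    # A single GEMM now yields q, k, v and gate; callers split the output at the
    # [q | k | v | gate] boundaries derived from the per-rank head counts.
    qkvg_out = qkvg_proj(hidden_states)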