mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] support qkv&gate linear fusion (#6455)
* [Feature] support qkv & gate linear fusion
* add test
This commit is contained in:
@@ -1106,3 +1106,223 @@ class KVBatchLinear(nn.Layer):
|
||||
return self.forward_v_b(x)
|
||||
else:
|
||||
raise ValueError(f"proj_type must be 'k' or 'v', got {proj_type}")
|
||||
|
||||
|
||||
class QKVGateParallelLinear(ColumnParallelLinear):
    """
    Column-parallel linear layer whose weight fuses the attention QKV
    projection with a gate projection into a single matrix.

    Per-rank fused layout along the output dimension is
    ``[q | k | v | gate]``, where q and gate each span
    ``num_heads_per_rank`` heads and k/v each span
    ``kv_num_heads_per_rank`` heads of ``head_dim`` elements.
    """

    def __init__(
        self,
        fd_config,
        prefix,
        with_bias=False,
        num_heads: Optional[int] = None,
        kv_num_heads: Optional[int] = None,
        hidden_size: Optional[int] = None,
        head_dim: Optional[int] = None,
        skip_quant: bool = False,
        weight_dtype: str = "",
    ):
        self.prefix = prefix

        # Checkpoint keys: this layer's prefix contains "qkvg"/"qkvg_proj",
        # while the checkpoint stores separate "qkv" and "gate" tensors.
        self.qkv_weight_key = f"{prefix}.weight".replace("qkvg", "qkv")
        self.gate_weight_key = f"{prefix}.weight".replace("qkvg_proj", "gate")
        self.qkv_bias_key = f"{prefix}.bias".replace("qkvg", "qkv")
        self.gate_bias_key = f"{prefix}.bias".replace("qkvg_proj", "gate")

        # Explicit arguments override the values from the model config.
        self.num_heads = fd_config.model_config.num_attention_heads if num_heads is None else num_heads
        self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
        self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
        self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
        self.tp_size = fd_config.parallel_config.tensor_parallel_size
        self.local_rank = fd_config.parallel_config.tensor_parallel_rank
        self.num_heads_per_rank = divide(self.num_heads, self.tp_size)

        if self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0:
            # Fewer KV heads than TP ranks: replicate each KV head across
            # `num_kv_head_replicas` ranks so every rank holds exactly one
            # k head and one v head. The global output size counts 2*num_heads
            # (q + gate) plus tp_size k heads and tp_size v heads
            # (presumably split back by tp_size in the parent
            # ColumnParallelLinear — TODO confirm).
            self.kv_num_heads_per_rank = 1
            self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
            output_size = (2 * self.num_heads + 2 * self.tp_size) * self.head_dim
        else:
            # Regular case: KV heads divide evenly across TP ranks.
            self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
            self.num_kv_head_replicas = 1
            output_size = (2 * self.num_heads + 2 * self.kv_num_heads) * self.head_dim
        input_size = self.hidden_size
        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=output_size,
            with_bias=with_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
        )

    def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
        """Return the per-rank shard width (in elements) for shard id
        "q", "k" or "v"; None for any other id."""
        shard_size_mapping = {
            "q": self.num_heads_per_rank * head_dim,
            "k": self.kv_num_heads_per_rank * head_dim,
            "v": self.kv_num_heads_per_rank * head_dim,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        """Dispatch a checkpoint tensor to the qkv or gate loading path.

        Args:
            param: the fused destination parameter.
            loaded_weight: the tensor read from the checkpoint.
            loaded_shard_id: must be "qkv" or "gate".
        """
        assert loaded_shard_id in [
            "qkv",
            "gate",
        ], f"loaded_shard_id must be one of ['qkv', 'gate'], but got {loaded_shard_id}"

        if loaded_shard_id == "qkv":
            # None means the fused qkv tensor: it is split into q/k/v inside.
            self.qkv_weight_loader(param, loaded_weight, None)
        else:
            self.gate_weight_loader(param, loaded_weight)

    def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):
        """Copy a checkpoint qkv tensor (fused, or one of its q/k/v shards)
        into the matching slice of the fused parameter.

        When ``loaded_shard_id`` is None the tensor holds the full fused
        [q | k | v] weight and is recursively re-dispatched per shard;
        otherwise it is a single "q"/"k"/"v" tensor that is TP-sliced and
        written into its slot of the per-rank [q | k | v | gate] layout.
        """
        # output_dim appears to be the axis index of the output dimension on
        # param (0 or 1) — TODO confirm against the parent class.
        output_dim = getattr(param, "output_dim", None)
        assert output_dim is not None
        dim = -1 if output_dim else 0

        # Per-rank output extent covers (q + gate + k + v) heads; recover head_dim.
        head_dim = param.shape[dim] // (2 * self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
        weight_need_transpose = getattr(param, "weight_need_transpose", False)
        if loaded_shard_id is None:
            if weight_need_transpose:
                loaded_weight = get_tensor(loaded_weight)
                loaded_weight = loaded_weight.transpose([1, 0])
                # Avoid redundant transpose of fused weights when weight_loader is called iteratively
                param.weight_need_transpose = False
            # Loaded weight is already fused on disk: split it into q/k/v
            # using GLOBAL (pre-TP) head counts and recurse per shard.
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.num_heads * head_dim),
                ("k", self.num_heads * head_dim, self.kv_num_heads * head_dim),
                ("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim),
            ]
            for shard_id, shard_offset, shard_size in shard_offsets:
                loaded_weight_shard = slice_fn(
                    loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
                )
                self.qkv_weight_loader(param, loaded_weight_shard, shard_id)
        else:
            # split q k v
            assert loaded_shard_id in ["q", "k", "v"]
            if weight_need_transpose:
                loaded_weight = get_tensor(loaded_weight)
                loaded_weight = loaded_weight.transpose([1, 0])
            # Tensor parallelism splits the weight along the output_dim
            if self.tp_size > 1 and output_dim is not None:
                block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
                # For k/v, replicated ranks share the same source shard.
                shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
                shard_offset = shard_id * block_size
                shard_size = block_size
                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)

            if not param._is_initialized():
                param.initialize()

            # Destination offsets within the per-rank [q | k | v | gate] layout.
            if loaded_shard_id == "q":
                param_shard_offset = 0
                param_shard_size = self.num_heads_per_rank * head_dim
            elif loaded_shard_id == "k":
                param_shard_offset = self.num_heads_per_rank * head_dim
                param_shard_size = self.kv_num_heads_per_rank * head_dim
            else:
                # loaded_shard_id == "v"
                param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim
                param_shard_size = self.kv_num_heads_per_rank * head_dim
            if hasattr(param, "tensor_track"):
                # Record which slice of the fused parameter has been filled.
                param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

            param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
            assert param.shape == loaded_weight.shape, (
                f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
            )
            # Ensure loaded weight dtype matches model param dtype
            if loaded_weight.dtype != param.dtype:
                if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
                    # Reinterpret the int8 bits as fp8 — not a value cast.
                    loaded_weight = loaded_weight.view(param.dtype)
                else:
                    loaded_weight = loaded_weight.cast(param.dtype)
            h2d_copy(param, loaded_weight)

    def gate_weight_loader(self, param, loaded_weight):
        """Copy the checkpoint gate tensor into the trailing gate slice of
        the fused per-rank [q | k | v | gate] parameter, TP-slicing it by
        this rank's q-head partition."""
        output_dim = getattr(param, "output_dim", None)
        assert output_dim is not None
        dim = -1 if output_dim else 0
        # Per-rank output extent covers (q + gate + k + v) heads; recover head_dim.
        head_dim = param.shape[dim] // (2 * self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
        weight_need_transpose = getattr(param, "weight_need_transpose", False)

        if weight_need_transpose:
            loaded_weight = get_tensor(loaded_weight)
            loaded_weight = loaded_weight.transpose([1, 0])

        # Tensor parallelism splits the weight along the output_dim
        if self.tp_size > 1 and output_dim is not None:
            # Gate has the same head partitioning as q.
            block_size = self.num_heads_per_rank * head_dim
            shard_offset = self.local_rank * block_size
            shard_size = block_size
            loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)

        if not param._is_initialized():
            param.initialize()

        # Gate lives after q + k + v in the fused layout.
        param_shard_offset = (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * head_dim
        param_shard_size = self.num_heads_per_rank * head_dim

        if hasattr(param, "tensor_track"):
            # Record which slice of the fused parameter has been filled.
            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

        param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
        assert param.shape == loaded_weight.shape, (
            f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
        # Ensure loaded weight dtype matches model param dtype
        if loaded_weight.dtype != param.dtype:
            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
                # Reinterpret the int8 bits as fp8 — not a value cast.
                loaded_weight = loaded_weight.view(param.dtype)
            else:
                loaded_weight = loaded_weight.cast(param.dtype)
        h2d_copy(param, loaded_weight)

    def load_weight(self, state_dict: dict):
        """
        Load the weight from the state dictionary.

        Pops the separate qkv and gate tensors, concatenates them along the
        last axis into the fused [qkv | gate] layout, and hands the result to
        the quantization method.

        NOTE(review): no TP slicing happens on this path — presumably the
        quant method or an earlier stage handles sharding; confirm.

        Args:
            state_dict (dict): A dictionary containing the weights
        """
        qkv_weight_tensor = get_tensor(state_dict.pop(self.qkv_weight_key))
        gate_weight_tensor = get_tensor(state_dict.pop(self.gate_weight_key))
        qkvg_weight_tensor = paddle.concat([qkv_weight_tensor, gate_weight_tensor], axis=-1)

        self.quant_method.process_loaded_weights(self, qkvg_weight_tensor)

    def load_state_dict(self, state_dict: dict):
        """
        Load the checkpoint state dictionary into the layer.

        Validates that both the qkv and gate keys are present, loads the
        weight (via the pre-quantized path when applicable), then fuses and
        sets the bias if the layer has one.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
        # weight
        assert (
            self.qkv_weight_key in state_dict.keys() and self.gate_weight_key in state_dict.keys()
        ), f"{self.qkv_weight_key} or {self.gate_weight_key} not found in state_dict"

        if self.is_quantized:
            self.load_prequant_weight(state_dict)
        else:
            self.load_weight(state_dict)

        # bias
        if self.with_bias:
            assert (
                self.qkv_bias_key in state_dict.keys() and self.gate_bias_key in state_dict.keys()
            ), f"{self.qkv_bias_key} or {self.gate_bias_key} not found in state_dict"
            qkv_bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.qkv_bias_key)))
            gate_bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.gate_bias_key)))
            # Bias follows the same fused [qkv | gate] layout as the weight.
            bias_tensor = paddle.concat([qkv_bias_tensor, gate_bias_tensor], axis=-1)

            self.bias.set_value(bias_tensor)
|
||||
|
||||
Reference in New Issue
Block a user