mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[BugFix][Models] Unify PaddleFormers fused QKV TP loading and stabilize fallback TP path (#6555)
* [BugFix][Models] avoid custom all-reduce in PaddleFormers fallback TP path and tighten TP-aware layout matching * [BugFix][Models] unify PaddleFormers fused QKV TP loading and align fallback tests
This commit is contained in:
@@ -39,11 +39,12 @@ from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.utils import WeightsMapper
|
||||
from fastdeploy.model_executor.utils import WeightsMapper, slice_fn
|
||||
|
||||
|
||||
class PaddleFormersRMSNormWrapper(nn.Layer):
|
||||
@@ -67,6 +68,188 @@ class PaddleFormersRMSNormWrapper(nn.Layer):
|
||||
return out
|
||||
|
||||
|
||||
class PaddleFormersQKVParallelLinear(QKVParallelLinear):
    """PF-specific QKV loader that packs local shards in PF interleaved order.

    Split ``q``/``k``/``v`` shards are buffered per parameter until all three
    have arrived, then written into the fused parameter grouped per local KV
    head as ``[q_groups query heads, 1 key head, 1 value head]``. A directly
    fused ``qkv_proj`` tensor (``loaded_shard_id is None``) is first split back
    into q/k/v and then goes through the same path.
    """

    def __init__(self, fd_config, prefix: str, with_bias: bool = False):
        """Build the layer and set up the per-parameter shard buffer."""
        super().__init__(fd_config=fd_config, prefix=prefix, with_bias=with_bias)
        # id(param) -> {"q" | "k" | "v": local shard}; filled by weight_loader.
        self._pending_local_shards: dict[int, dict[str, paddle.Tensor]] = {}
        # Lower-cased model_format; empty string when the config has none.
        self._model_format = str(getattr(fd_config.model_config, "model_format", "") or "").lower()

    @staticmethod
    def _to_tensor(t: paddle.Tensor | object) -> paddle.Tensor:
        """Return ``t`` unchanged if it is a paddle.Tensor, else convert it."""
        return t if isinstance(t, paddle.Tensor) else paddle.to_tensor(t)

    def _extract_local_shard(self, param: paddle.Tensor, loaded_weight: paddle.Tensor, loaded_shard_id: str):
        """Return this rank's portion of a q/k/v checkpoint tensor.

        Args:
            param: Fused QKV parameter; must carry an ``output_dim`` attribute.
            loaded_weight: Checkpoint tensor for a single projection.
            loaded_shard_id: Shard identifier; ``"q"`` is indexed by
                ``local_rank``, anything else by
                ``local_rank // num_kv_head_replicas``.

        Returns:
            The tensor, transposed when ``param.weight_need_transpose`` is
            truthy and sliced along ``output_dim`` when TP > 1 and the
            checkpoint is not pre-sharded.

        Raises:
            ValueError: If ``param`` has no ``output_dim`` or a transpose is
                requested for a non-2D tensor.
        """
        output_dim = getattr(param, "output_dim", None)
        if output_dim is None:
            raise ValueError("Missing output_dim for QKV parameter.")

        # A truthy output_dim selects the last axis as the output axis.
        dim = -1 if output_dim else 0
        denom = self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank
        head_dim = int(param.shape[dim]) // int(denom)

        weight = self._to_tensor(loaded_weight)
        if getattr(param, "weight_need_transpose", False):
            if weight.ndim != 2:
                raise ValueError(f"Expected 2D tensor for transpose, got shape={list(weight.shape)}")
            weight = weight.transpose([1, 0])

        # output_dim is known to be non-None here (checked above).
        if self.tp_size > 1 and not self.fd_config.load_config.is_pre_sharded:
            block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
            shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
            shard_offset = shard_id * block_size
            weight = slice_fn(weight, output_dim, start=shard_offset, end=shard_offset + block_size)

        return weight

    @staticmethod
    def _to_hidden_major(weight: paddle.Tensor, expected_out: int, name: str) -> paddle.Tensor:
        """Return a 2D shard whose second axis has size ``expected_out``.

        The shard is returned as-is when its second axis already matches
        (this also wins for square shards) and transposed when only its first
        axis matches.

        Raises:
            ValueError: If the shard is not 2D or neither axis matches.
        """
        if weight.ndim != 2:
            raise ValueError(f"Expected 2D {name} shard, got shape={list(weight.shape)}")

        s0, s1 = int(weight.shape[0]), int(weight.shape[1])
        if s1 == expected_out:
            return weight
        if s0 == expected_out:
            return weight.transpose([1, 0])
        raise ValueError(
            f"Cannot normalize {name} shard shape={list(weight.shape)} to hidden-major with expected_out={expected_out}."
        )

    def _pack_pf_interleaved_local(
        self,
        q_local: paddle.Tensor,
        k_local: paddle.Tensor,
        v_local: paddle.Tensor,
        output_dim: bool,
    ):
        """Interleave local q/k/v shards per KV head.

        1D inputs (biases) are packed into a flat tensor. 2D inputs are first
        normalized with ``_to_hidden_major`` and packed along the second axis;
        the result is transposed back when ``output_dim`` is falsy.

        Raises:
            ValueError: On a non-positive local KV head count, a query head
                count not divisible by it, or mismatched first axes after
                normalization.
        """
        kv_local = int(self.kv_num_heads_per_rank)
        if kv_local <= 0:
            raise ValueError("Invalid kv_num_heads_per_rank, must be > 0.")
        if self.num_heads_per_rank % kv_local != 0:
            raise ValueError(
                f"num_heads_per_rank={self.num_heads_per_rank} is not divisible by kv_num_heads_per_rank={kv_local}"
            )
        q_groups_local = self.num_heads_per_rank // kv_local

        if q_local.ndim == 1:
            q = q_local.reshape([kv_local, q_groups_local, self.head_dim])
            k = k_local.reshape([kv_local, 1, self.head_dim])
            v = v_local.reshape([kv_local, 1, self.head_dim])
            return paddle.concat([q, k, v], axis=1).reshape([-1])

        q_out = kv_local * q_groups_local * self.head_dim
        kv_out = kv_local * self.head_dim

        q_hm = self._to_hidden_major(q_local, q_out, "q")
        k_hm = self._to_hidden_major(k_local, kv_out, "k")
        v_hm = self._to_hidden_major(v_local, kv_out, "v")

        hidden_size = int(q_hm.shape[0])
        if int(k_hm.shape[0]) != hidden_size or int(v_hm.shape[0]) != hidden_size:
            raise ValueError(
                "Q/K/V hidden dimension mismatch after normalization: "
                f"q={list(q_hm.shape)}, k={list(k_hm.shape)}, v={list(v_hm.shape)}"
            )

        # Concatenate on the per-KV-head axis so each KV head owns a
        # contiguous [q_groups + 2, head_dim] slab in the packed output.
        q = q_hm.reshape([hidden_size, kv_local, q_groups_local, self.head_dim])
        k = k_hm.reshape([hidden_size, kv_local, 1, self.head_dim])
        v = v_hm.reshape([hidden_size, kv_local, 1, self.head_dim])
        packed_hidden_major = paddle.concat([q, k, v], axis=2).reshape([hidden_size, -1])

        if output_dim:
            return packed_hidden_major
        return packed_hidden_major.transpose([1, 0])

    def _split_pf_fused_qkv(self, loaded_weight: paddle.Tensor, is_bias: bool):
        """Split an interleaved fused ``qkv_proj`` tensor back into q, k and v.

        The fused width (axis 0 for a bias, axis 1 for a weight) must equal
        either the global or the per-rank interleaved width; that decides
        which head counts are used for the split.

        Returns:
            Tuple ``(q, k, v)``; 1D tensors for a bias, 2D tensors keeping the
            first axis of the input for a weight.

        Raises:
            ValueError: If the model format is not ``"paddle"``, the tensor
                rank is unexpected, the width matches neither layout, or the
                head counts are not divisible.
        """
        if self._model_format != "paddle":
            raise ValueError(
                "Direct qkv_proj loading is only supported for model_format='paddle'. "
                "Use split q_proj/k_proj/v_proj weights for other formats."
            )

        weight = self._to_tensor(loaded_weight)
        if is_bias:
            if weight.ndim != 1:
                raise ValueError(f"Unexpected fused qkv bias dims: {list(weight.shape)}, expected 1D.")
            width = int(weight.shape[0])
        else:
            if weight.ndim != 2:
                raise ValueError(f"Unexpected fused qkv weight dims: {list(weight.shape)}, expected 2D.")
            width = int(weight.shape[1])

        global_width = int((self.num_heads + 2 * self.kv_num_heads) * self.head_dim)
        local_width = int((self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * self.head_dim)

        # The global layout is checked first, so it wins when both widths coincide.
        if width == global_width:
            num_heads, num_kv_heads = self.num_heads, self.kv_num_heads
        elif width == local_width:
            num_heads, num_kv_heads = self.num_heads_per_rank, self.kv_num_heads_per_rank
        else:
            raise ValueError(
                f"Cannot validate fused qkv_proj width={width}. "
                f"Expect global={global_width} or local={local_width} for PF interleaved layout."
            )

        if num_heads % num_kv_heads != 0:
            raise ValueError(f"Invalid head config: num_heads={num_heads}, num_kv_heads={num_kv_heads}")
        q_groups = num_heads // num_kv_heads

        if is_bias:
            fused = weight.reshape([num_kv_heads, q_groups + 2, self.head_dim])
            q = fused[:, :q_groups, :].reshape([-1])
            k = fused[:, q_groups : q_groups + 1, :].reshape([-1])
            v = fused[:, q_groups + 1 :, :].reshape([-1])
            return q, k, v

        hidden_size = int(weight.shape[0])
        fused = weight.reshape([hidden_size, num_kv_heads, q_groups + 2, self.head_dim])
        q = fused[:, :, :q_groups, :].reshape([hidden_size, -1])
        k = fused[:, :, q_groups : q_groups + 1, :].reshape([hidden_size, -1])
        v = fused[:, :, q_groups + 1 :, :].reshape([hidden_size, -1])
        return q, k, v

    def weight_loader(self, param, loaded_weight, loaded_shard_id: str | None = None):
        """Load a q/k/v shard, or a fused qkv tensor, into ``param``.

        * ``loaded_shard_id is None``: ``loaded_weight`` is treated as a fused
          interleaved tensor, split, and re-entered once per q/k/v.
        * ``loaded_shard_id`` not in ``{"q", "k", "v"}``: delegated to the
          parent class.
        * Otherwise the local shard is buffered. ``param._pf_qkv_pending`` is
          set to True until all three shards are present; the packed tensor is
          then written with ``param.set_value`` and the flag set to False.

        Raises:
            ValueError: If the packed tensor's shape differs from ``param``'s.
        """
        if loaded_shard_id is None:
            is_bias = len(param.shape) == 1
            q_shard, k_shard, v_shard = self._split_pf_fused_qkv(loaded_weight, is_bias=is_bias)
            self.weight_loader(param, q_shard, "q")
            self.weight_loader(param, k_shard, "k")
            self.weight_loader(param, v_shard, "v")
            return

        if loaded_shard_id not in {"q", "k", "v"}:
            super().weight_loader(param, loaded_weight, loaded_shard_id)
            return

        local_shard = self._extract_local_shard(param, loaded_weight, loaded_shard_id)
        key = id(param)
        pending = self._pending_local_shards.setdefault(key, {})
        pending[loaded_shard_id] = local_shard

        if len(pending) < 3:
            setattr(param, "_pf_qkv_pending", True)
            return

        try:
            packed = self._pack_pf_interleaved_local(
                pending["q"],
                pending["k"],
                pending["v"],
                output_dim=bool(getattr(param, "output_dim", True)),
            )
            if not param._is_initialized():
                param.initialize()
            if packed.dtype != param.dtype:
                packed = packed.cast(param.dtype)
            if list(param.shape) != list(packed.shape):
                raise ValueError(f"Packed qkv shape mismatch: packed={list(packed.shape)} param={list(param.shape)}")

            param.set_value(packed)
        finally:
            # Release the buffered shards even when packing fails, so the
            # tensors are not retained by this layer. On failure the pending
            # flag is left as set by the earlier shard loads, so the
            # parameter is still reported as not loaded.
            self._pending_local_shards.pop(key, None)

        setattr(param, "_pf_qkv_pending", False)
|
||||
|
||||
|
||||
def getattr_iter(obj, names, default=None):
|
||||
for name in names:
|
||||
if hasattr(obj, name):
|
||||
@@ -109,7 +292,11 @@ def fastdeploy_append_attention_forward(
|
||||
if scaling is not None:
|
||||
self_attn.scale = float(scaling)
|
||||
|
||||
# 统一获取 heads 信息
|
||||
tp_size = 1
|
||||
if hasattr(self_attn, "fd_config") and hasattr(self_attn.fd_config, "parallel_config"):
|
||||
tp_size = int(getattr(self_attn.fd_config.parallel_config, "tensor_parallel_size", 1) or 1)
|
||||
|
||||
# Resolve head-related metadata.
|
||||
num_heads = (
|
||||
getattr(module, "num_heads", None)
|
||||
or getattr(config, "num_attention_heads", None)
|
||||
@@ -125,7 +312,7 @@ def fastdeploy_append_attention_forward(
|
||||
num_heads = int(num_heads) if num_heads is not None else None
|
||||
num_kv_heads = int(num_kv_heads) if num_kv_heads is not None else None
|
||||
|
||||
# 仅支持 3D(HSD/SHD) 或 4D(BHSD/BSHD, 且 B=1) 输入
|
||||
# Support only 3D (HSD/SHD) or 4D (BHSD/BSHD with B=1) inputs.
|
||||
def squeeze_to_3d(t: paddle.Tensor, name: str) -> paddle.Tensor:
|
||||
if t.ndim == 4:
|
||||
if int(t.shape[0]) != 1:
|
||||
@@ -144,9 +331,11 @@ def fastdeploy_append_attention_forward(
|
||||
return False
|
||||
if actual_heads == expected_heads:
|
||||
return True
|
||||
return actual_heads > 0 and expected_heads % actual_heads == 0
|
||||
if tp_size > 1 and expected_heads % tp_size == 0:
|
||||
expected_heads //= tp_size
|
||||
return actual_heads == expected_heads
|
||||
|
||||
# 使用 Q/K 共同判断布局;歧义时默认 hsd(兼容 Paddle 常见路径)
|
||||
# Determine layout from Q/K/V head axes; keep default behavior on ambiguity.
|
||||
is_hsd = (
|
||||
heads_match(int(q.shape[0]), num_heads)
|
||||
and heads_match(int(k.shape[0]), num_kv_heads)
|
||||
@@ -172,7 +361,7 @@ def fastdeploy_append_attention_forward(
|
||||
f"heads={num_heads}/{num_kv_heads}"
|
||||
)
|
||||
|
||||
# Q/K/V flatten 后序列长度必须一致
|
||||
# Sequence lengths must match after flattening Q/K/V.
|
||||
q_seq, k_seq, v_seq = int(q_flat.shape[0]), int(k_flat.shape[0]), int(v_flat.shape[0])
|
||||
if not (q_seq == k_seq == v_seq):
|
||||
raise ValueError(
|
||||
@@ -180,7 +369,7 @@ def fastdeploy_append_attention_forward(
|
||||
f"raw query={list(query.shape)}, key={list(key.shape)}, value={list(value.shape)}."
|
||||
)
|
||||
|
||||
# 若 forward_meta 带了 ids_remove_padding,则强校验 Q 序列长度
|
||||
# If forward_meta provides ids_remove_padding, strictly validate Q sequence length.
|
||||
ids_remove_padding = getattr(forward_meta, "ids_remove_padding", None)
|
||||
if ids_remove_padding is not None:
|
||||
expected_seq = int(ids_remove_padding.shape[0])
|
||||
@@ -246,16 +435,12 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
supported_fused_qkv_models = ["qwen3", "qwen2"]
|
||||
|
||||
tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
if tp_size > 1:
|
||||
self._use_fused_qkv = False
|
||||
logger.info(f"Fusion disabled for TP={tp_size} due to shape incompatibility")
|
||||
self._use_fused_qkv = model_type in supported_fused_qkv_models
|
||||
if self._use_fused_qkv:
|
||||
self.paddleformers_config.fuse_attention_qkv = True
|
||||
logger.info(f"Enabled fuse_attention_qkv for model_type={model_type}, tp={tp_size}")
|
||||
else:
|
||||
self._use_fused_qkv = model_type in supported_fused_qkv_models
|
||||
if self._use_fused_qkv:
|
||||
self.paddleformers_config.fuse_attention_qkv = True
|
||||
logger.info(f"Enabled fuse_attention_qkv for model_type={model_type}")
|
||||
else:
|
||||
logger.debug(f"QKV fusion not enabled for model_type={model_type}")
|
||||
logger.debug(f"QKV fusion not enabled for model_type={model_type}")
|
||||
|
||||
# PaddleFormers fused optimize option
|
||||
self._use_fused_ffn = model_type in supported_fused_qkv_models
|
||||
@@ -283,6 +468,8 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
|
||||
# Linear and Norm replace for FD optimized versions and TP support
|
||||
self.recursive_replace()
|
||||
# Patch PF attention head counts to TP-local values for fused qkv reshape
|
||||
self._localize_pf_attention_heads()
|
||||
# Attention instances for FD Attention backend
|
||||
self.attention_instances = self.create_attention_instances()
|
||||
self.paddleformers_config.attention_instances = self.attention_instances
|
||||
@@ -378,15 +565,12 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
with_bias = hasattr(child_module, "bias") and child_module.bias is not None
|
||||
|
||||
if style == "colwise":
|
||||
# For qkv_proj when fused QKV is enabled:
|
||||
# Use ColumnParallelLinear (not QKVParallelLinear) because we fuse weights
|
||||
# into PaddleFormers' per-KV-head interleaved format in load_weights()
|
||||
# qkv_proj uses PF-specific TP-aware loader to support
|
||||
# unified split-QKV loading across TP1/TP>1.
|
||||
if "qkv_proj" in qual_name and self._use_fused_qkv:
|
||||
new_module = ColumnParallelLinear(
|
||||
new_module = PaddleFormersQKVParallelLinear(
|
||||
self.fd_config,
|
||||
prefix=qual_name,
|
||||
input_size=in_features,
|
||||
output_size=out_features,
|
||||
with_bias=with_bias,
|
||||
)
|
||||
# For up_gate_proj when fused FFN is enabled:
|
||||
@@ -448,6 +632,44 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
|
||||
_recursive_replace(self.model, prefix="model")
|
||||
|
||||
def _localize_pf_attention_heads(self):
    """Rewrite PF attention head counts so they describe one TP rank.

    PF Attention.__init__ copies the global head counts from the config
    onto each instance (num_heads, num_key_value_heads, etc.).
    config.tensor_model_parallel_size cannot be raised above 1 here, since
    that would enable PF's own TP linears and clash with
    recursive_replace, so the instance attributes are overwritten after
    the model has been built.

    This is a no-op unless TP > 1 and fused qkv is enabled; only the fused
    path reshapes the qkv_proj output with these counts.
    """
    world = self.fd_config.parallel_config.tensor_parallel_size
    if world <= 1 or not self._use_fused_qkv:
        return

    total_heads = int(self.text_config.num_attention_heads)
    total_kv = int(getattr(self.text_config, "num_key_value_heads", total_heads))
    heads_per_rank = total_heads // world
    # Keep at least one KV head per rank when there are fewer KV heads than ranks.
    kv_per_rank = max(1, total_kv // world)
    groups_per_rank = heads_per_rank // kv_per_rank

    patched_count = 0
    for _, layer in self.model.named_sublayers():
        # Sublayers exposing num_key_value_groups are treated as PF attention modules.
        if hasattr(layer, "num_key_value_groups"):
            layer.num_heads = heads_per_rank
            layer.num_key_value_heads = kv_per_rank
            layer.num_key_value_groups = groups_per_rank
            patched_count += 1

    if not patched_count:
        return

    logger.info(
        f"Localized {patched_count} PF attention modules: "
        f"heads {total_heads}->{heads_per_rank}, kv {total_kv}->{kv_per_rank}, tp={world}"
    )
|
||||
|
||||
def _get_tp_plan(self) -> dict[str, str]:
|
||||
"""Get TP plan for linear layer replacement.
|
||||
|
||||
@@ -658,19 +880,19 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
process_fn = process_weights_after_loading(sublayers_dict, self.fd_config)
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
# === 前缀别名处理 ===
|
||||
# === Checkpoint prefix alias handling ===
|
||||
model_type = str(getattr(self.paddleformers_config, "model_type", "") or "").lower()
|
||||
ckpt_prefix_aliases = {model_type, model_type.replace("-", "_"), model_type.replace("_", "")} - {""}
|
||||
ckpt_alias_markers = (".layers.", ".embed_tokens.", ".lm_head.", ".norm.", ".final_layernorm.", ".rotary_emb.")
|
||||
|
||||
def resolve_param_name(weight_name: str) -> str | None:
|
||||
# 动态收集前缀别名
|
||||
# Collect prefix aliases dynamically.
|
||||
if "." in weight_name:
|
||||
prefix = weight_name.split(".", 1)[0]
|
||||
if prefix not in {"model", "lm_head"} and any(m in weight_name for m in ckpt_alias_markers):
|
||||
ckpt_prefix_aliases.add(prefix)
|
||||
|
||||
# 生成候选名称
|
||||
# Generate candidate parameter names.
|
||||
candidates = [weight_name]
|
||||
candidates.append(weight_name[6:] if weight_name.startswith("model.") else "model." + weight_name)
|
||||
if "." in weight_name:
|
||||
@@ -680,7 +902,7 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
|
||||
return next((c for c in candidates if c in params_dict), None)
|
||||
|
||||
# === 权重映射配置 ===
|
||||
# === Stacked parameter mapping config ===
|
||||
stacked_params_mapping = [
|
||||
("embed_tokens.embeddings", "embed_tokens", None),
|
||||
("lm_head.linear", "lm_head", None),
|
||||
@@ -688,116 +910,71 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
if self._use_fused_ffn:
|
||||
stacked_params_mapping += [("up_gate_proj", "gate_proj", "gate"), ("up_gate_proj", "up_proj", "up")]
|
||||
|
||||
# === QKV 融合相关 ===
|
||||
mc = self.fd_config.model_config
|
||||
model_format = str(getattr(mc, "model_format", "") or "").lower()
|
||||
qkv_buffer, qkv_bias_buffer = {}, {}
|
||||
# === QKV loading helpers ===
|
||||
qkv_split_layers: set[str] = set()
|
||||
qkv_direct_pending: dict[tuple[str, bool], tuple[str, paddle.Tensor]] = {}
|
||||
|
||||
def parse_qkv_name(name: str) -> tuple[str, str, str] | None:
|
||||
for proj, ptype in [("q_proj", "q"), ("k_proj", "k"), ("v_proj", "v")]:
|
||||
if proj in name:
|
||||
layer_key = name.replace(f".{proj}.weight", "").replace(f".{proj}.bias", "")
|
||||
return layer_key, ptype, name.replace(proj, "qkv_proj")
|
||||
def parse_qkv_shard_name(name: str) -> tuple[str, str, str] | None:
|
||||
shard_suffixes = (
|
||||
(".q_proj.weight", "q"),
|
||||
(".k_proj.weight", "k"),
|
||||
(".v_proj.weight", "v"),
|
||||
(".q_proj.bias", "q"),
|
||||
(".k_proj.bias", "k"),
|
||||
(".v_proj.bias", "v"),
|
||||
)
|
||||
for suffix, shard_id in shard_suffixes:
|
||||
if name.endswith(suffix):
|
||||
layer_key = name.replace(suffix, "")
|
||||
qkv_param_name = name.replace(".q_proj.", ".qkv_proj.")
|
||||
qkv_param_name = qkv_param_name.replace(".k_proj.", ".qkv_proj.")
|
||||
qkv_param_name = qkv_param_name.replace(".v_proj.", ".qkv_proj.")
|
||||
return layer_key, shard_id, qkv_param_name
|
||||
return None
|
||||
|
||||
def fuse_qkv(q, k, v, is_bias: bool) -> paddle.Tensor:
|
||||
num_heads, num_kv_heads, head_dim = mc.num_attention_heads, mc.num_key_value_heads, mc.head_dim
|
||||
hidden_size, num_kv_groups = mc.hidden_size, num_heads // num_kv_heads
|
||||
q_out, kv_out = num_heads * head_dim, num_kv_heads * head_dim
|
||||
def parse_direct_qkv_name(name: str) -> tuple[str, bool] | None:
|
||||
if name.endswith(".qkv_proj.weight"):
|
||||
return name.replace(".qkv_proj.weight", ""), False
|
||||
if name.endswith(".qkv_proj.bias"):
|
||||
return name.replace(".qkv_proj.bias", ""), True
|
||||
return None
|
||||
|
||||
if is_bias:
|
||||
# 校验 bias 维度和形状
|
||||
if q.ndim != 1 or k.ndim != 1 or v.ndim != 1:
|
||||
raise ValueError(f"Unexpected qkv bias dims: q={q.shape}, k={k.shape}, v={v.shape}; expected 1D.")
|
||||
if q.shape[0] != q_out or k.shape[0] != kv_out or v.shape[0] != kv_out:
|
||||
raise ValueError(
|
||||
f"Unexpected qkv bias shapes: q={q.shape}, k={k.shape}, v={v.shape}; "
|
||||
f"expected q=[{q_out}], k/v=[{kv_out}]."
|
||||
)
|
||||
return paddle.concat(
|
||||
[
|
||||
q.reshape([num_kv_heads, num_kv_groups, head_dim]),
|
||||
k.reshape([num_kv_heads, 1, head_dim]),
|
||||
v.reshape([num_kv_heads, 1, head_dim]),
|
||||
],
|
||||
axis=1,
|
||||
).reshape([-1])
|
||||
|
||||
# 校验 weight 形状和 model_format
|
||||
q_shape, k_shape, v_shape = [int(x) for x in q.shape], [int(x) for x in k.shape], [int(x) for x in v.shape]
|
||||
torch_layout = (
|
||||
q_shape == [q_out, hidden_size]
|
||||
and k_shape == [kv_out, hidden_size]
|
||||
and v_shape == [kv_out, hidden_size]
|
||||
)
|
||||
paddle_layout = (
|
||||
q_shape == [hidden_size, q_out]
|
||||
and k_shape == [hidden_size, kv_out]
|
||||
and v_shape == [hidden_size, kv_out]
|
||||
)
|
||||
|
||||
if model_format == "torch":
|
||||
if not torch_layout:
|
||||
raise ValueError(
|
||||
f"model_format=torch requires torch layout, got q={q_shape}, k={k_shape}, v={v_shape}."
|
||||
)
|
||||
q, k, v = q.T, k.T, v.T
|
||||
elif model_format == "paddle":
|
||||
if not paddle_layout:
|
||||
raise ValueError(
|
||||
f"model_format=paddle requires paddle layout, got q={q_shape}, k={k_shape}, v={v_shape}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_format: {model_format}. Expect 'torch' or 'paddle'.")
|
||||
|
||||
# 转置后校验
|
||||
if q.shape[0] != hidden_size or k.shape[0] != hidden_size or v.shape[0] != hidden_size:
|
||||
raise ValueError(
|
||||
f"QKV shape mismatch after normalization: q={list(q.shape)}, k={list(k.shape)}, v={list(v.shape)}."
|
||||
)
|
||||
|
||||
fused = paddle.concat(
|
||||
[
|
||||
q.reshape([hidden_size, num_kv_heads, num_kv_groups, head_dim]),
|
||||
k.reshape([hidden_size, num_kv_heads, 1, head_dim]),
|
||||
v.reshape([hidden_size, num_kv_heads, 1, head_dim]),
|
||||
],
|
||||
axis=2,
|
||||
).reshape([hidden_size, -1])
|
||||
|
||||
return fused.T if model_format == "torch" else fused
|
||||
|
||||
# === 辅助函数 ===
|
||||
def load_param(name: str, tensor: paddle.Tensor, shard_id=None, no_transpose: bool = False):
|
||||
# === Helper functions ===
|
||||
def load_param(name: str, tensor: paddle.Tensor, shard_id=None):
|
||||
param = params_dict[name]
|
||||
if no_transpose and hasattr(param, "weight_need_transpose"):
|
||||
param.weight_need_transpose = False
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
|
||||
weight_loader(param, tensor, shard_id)
|
||||
if shard_id in {"q", "k", "v"} and bool(getattr(param, "_pf_qkv_pending", False)):
|
||||
return False
|
||||
process_fn(re.sub(r"\.(weight|bias)$", "", name), param)
|
||||
return True
|
||||
|
||||
# === 主循环 ===
|
||||
# === Main loading loop ===
|
||||
loaded_count = skipped_count = 0
|
||||
|
||||
for weight_name, weight in weights:
|
||||
# 1. QKV 融合处理
|
||||
if self._use_fused_qkv and (qkv_info := parse_qkv_name(weight_name)):
|
||||
layer_key, proj_type, qkv_param_name = qkv_info
|
||||
is_bias = ".bias" in weight_name
|
||||
buf = qkv_bias_buffer if is_bias else qkv_buffer
|
||||
buf.setdefault(layer_key, {})[proj_type] = weight
|
||||
|
||||
if len(buf[layer_key]) == 3:
|
||||
# 1. Handle fused QKV path in a unified split-shard style.
|
||||
if self._use_fused_qkv:
|
||||
if qkv_info := parse_qkv_shard_name(weight_name):
|
||||
layer_key, proj_type, qkv_param_name = qkv_info
|
||||
qkv_split_layers.add(layer_key)
|
||||
resolved = resolve_param_name(qkv_param_name)
|
||||
if resolved:
|
||||
fused = fuse_qkv(buf[layer_key]["q"], buf[layer_key]["k"], buf[layer_key]["v"], is_bias)
|
||||
load_param(resolved, fused, no_transpose=not is_bias)
|
||||
loaded_count += 3
|
||||
try:
|
||||
load_param(resolved, weight, shard_id=proj_type)
|
||||
loaded_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load qkv shard {weight_name} -> {resolved}: {e}")
|
||||
skipped_count += 1
|
||||
else:
|
||||
logger.warning(f"QKV {'bias ' if is_bias else ''}param {qkv_param_name} not found")
|
||||
skipped_count += 3
|
||||
del buf[layer_key]
|
||||
continue
|
||||
logger.warning(f"QKV shard mapping not found: {weight_name} -> {qkv_param_name}")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
if direct_qkv_info := parse_direct_qkv_name(weight_name):
|
||||
layer_key, is_bias = direct_qkv_info
|
||||
qkv_direct_pending[(layer_key, is_bias)] = (weight_name, weight)
|
||||
continue
|
||||
|
||||
# 2. Stacked params mapping
|
||||
for param_name, src_name, shard_id in stacked_params_mapping:
|
||||
@@ -810,7 +987,7 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
logger.warning(f"Stacked mapping: {weight_name} -> NOT FOUND")
|
||||
break
|
||||
else:
|
||||
# 3. 直接加载
|
||||
# 3. Direct load.
|
||||
resolved = resolve_param_name(weight_name)
|
||||
if resolved:
|
||||
try:
|
||||
@@ -822,9 +999,40 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
else:
|
||||
skipped_count += 1
|
||||
|
||||
# 4. Handle direct qkv_proj.* only when split q/k/v is absent for that layer.
|
||||
if self._use_fused_qkv and qkv_direct_pending:
|
||||
for (layer_key, is_bias), (weight_name, weight) in qkv_direct_pending.items():
|
||||
if layer_key in qkv_split_layers:
|
||||
logger.info(
|
||||
f"Skip direct qkv {'bias' if is_bias else 'weight'} for {layer_key}: "
|
||||
"split q/k/v shards are present."
|
||||
)
|
||||
continue
|
||||
|
||||
resolved = resolve_param_name(weight_name)
|
||||
if resolved:
|
||||
try:
|
||||
load_param(resolved, weight)
|
||||
loaded_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load direct fused qkv {weight_name} -> {resolved}: {e}")
|
||||
skipped_count += 1
|
||||
else:
|
||||
logger.warning(f"Direct fused qkv param not found: {weight_name}")
|
||||
skipped_count += 1
|
||||
|
||||
if self._use_fused_qkv:
|
||||
pending_qkv_params = [
|
||||
name for name, param in params_dict.items() if bool(getattr(param, "_pf_qkv_pending", False))
|
||||
]
|
||||
if pending_qkv_params:
|
||||
raise RuntimeError(
|
||||
"Incomplete QKV shard loading detected for parameters: " + ", ".join(sorted(pending_qkv_params))
|
||||
)
|
||||
|
||||
logger.info(f"Weight loading: {loaded_count} loaded, {skipped_count} skipped")
|
||||
|
||||
# === tie_word_embeddings 处理 ===
|
||||
# === tie_word_embeddings handling ===
|
||||
if hasattr(self, "lm_head") and getattr(self, "tie_word_embeddings", False):
|
||||
embed = self.model.get_input_embeddings()
|
||||
if hasattr(embed, "embeddings") and hasattr(embed.embeddings, "weight"):
|
||||
|
||||
@@ -17,7 +17,6 @@ Focused tests to increase coverage of base.py
|
||||
Tests actual code paths that were previously uncovered.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
@@ -960,6 +959,7 @@ class TestAttentionForwardEdgeCases:
|
||||
num_heads: int | None = None,
|
||||
num_kv_heads: int | None = None,
|
||||
expected_seq_len: int | None = None,
|
||||
tp_size: int = 1,
|
||||
):
|
||||
from fastdeploy.model_executor.models.paddleformers.base import (
|
||||
fastdeploy_append_attention_forward,
|
||||
@@ -974,6 +974,9 @@ class TestAttentionForwardEdgeCases:
|
||||
mock_attention = SimpleNamespace(
|
||||
forward=Mock(side_effect=fake_forward),
|
||||
)
|
||||
mock_attention.fd_config = SimpleNamespace(
|
||||
parallel_config=SimpleNamespace(tensor_parallel_size=tp_size),
|
||||
)
|
||||
if num_heads is not None:
|
||||
mock_attention.num_heads = num_heads
|
||||
if num_kv_heads is not None:
|
||||
@@ -1070,7 +1073,7 @@ class TestAttentionForwardEdgeCases:
|
||||
key = paddle.to_tensor((np.arange(20, dtype=np.float32) + 100).reshape([1, 2, 5, 2]))
|
||||
value = paddle.to_tensor((np.arange(20, dtype=np.float32) + 200).reshape([1, 2, 5, 2]))
|
||||
|
||||
qkv = self._run_attention(query, key, value, num_heads=8, num_kv_heads=4, expected_seq_len=5)
|
||||
qkv = self._run_attention(query, key, value, num_heads=8, num_kv_heads=4, expected_seq_len=5, tp_size=2)
|
||||
self._assert_qkv_concat_matches_known_layout(qkv, query, key, value)
|
||||
|
||||
def test_gqa_shd_layout_detection(self):
|
||||
@@ -1135,9 +1138,10 @@ class TestRecursiveReplaceAdvanced:
|
||||
"""Test recursive_replace advanced cases to cover more lines."""
|
||||
|
||||
def test_fused_qkv_replacement(self, mock_fd_config):
|
||||
"""Test that qkv_proj with fused QKV uses ColumnParallelLinear (lines 330-337)."""
|
||||
"""Test that qkv_proj with fused QKV uses PaddleFormersQKVParallelLinear."""
|
||||
from fastdeploy.model_executor.models.paddleformers.base import (
|
||||
PaddleFormersModelBase,
|
||||
PaddleFormersQKVParallelLinear,
|
||||
)
|
||||
|
||||
fd_config, _ = mock_fd_config
|
||||
@@ -1172,8 +1176,8 @@ class TestRecursiveReplaceAdvanced:
|
||||
|
||||
model.recursive_replace()
|
||||
|
||||
# qkv_proj should become ColumnParallelLinear
|
||||
assert isinstance(model.model.qkv_proj, ColumnParallelLinear)
|
||||
# qkv_proj should become PaddleFormersQKVParallelLinear
|
||||
assert isinstance(model.model.qkv_proj, PaddleFormersQKVParallelLinear)
|
||||
|
||||
def test_fused_ffn_replacement(self, mock_fd_config):
|
||||
"""Test that up_gate_proj with fused FFN uses MergedColumnParallelLinear (lines 340-347)."""
|
||||
@@ -1565,8 +1569,8 @@ class TestGetTPPlan:
|
||||
class TestFusionSettings:
|
||||
"""Test __init__ fusion settings to cover lines 201-202, 206-207, 214-216."""
|
||||
|
||||
def test_tp_greater_than_1_disables_fused_qkv(self, mock_fd_config_tp2):
|
||||
"""Test that TP>1 disables fused QKV (lines 201-202)."""
|
||||
def test_tp_greater_than_1_keeps_fused_qkv_for_qwen(self, mock_fd_config_tp2):
|
||||
"""Test that Qwen keeps fused QKV enabled under TP>1."""
|
||||
from fastdeploy.model_executor.models.paddleformers.base import (
|
||||
PaddleFormersModelBase,
|
||||
)
|
||||
@@ -1612,8 +1616,9 @@ class TestFusionSettings:
|
||||
|
||||
model = TestModel(fd_config)
|
||||
|
||||
# With TP=2, fused QKV should be disabled
|
||||
assert model._use_fused_qkv is False
|
||||
# With TP=2 and qwen model type, fused QKV stays enabled.
|
||||
assert model._use_fused_qkv is True
|
||||
assert mock_pf_config.fuse_attention_qkv is True
|
||||
|
||||
def test_qwen3_tp1_enables_fused_qkv_and_ffn(self, mock_fd_config_qwen3):
|
||||
"""Test that Qwen3 with TP=1 enables fused QKV and FFN (lines 206-207, 214-216)."""
|
||||
@@ -1962,7 +1967,7 @@ class TestLoadWeights:
|
||||
self.process_weights_patcher.stop()
|
||||
|
||||
def test_load_fused_qkv_weights(self, mock_fd_config):
|
||||
"""Test loading and fusing Q/K/V weights (lines 635-741)."""
|
||||
"""Test split q/k/v shards are routed to qkv_proj with shard ids."""
|
||||
from fastdeploy.model_executor.models.paddleformers.base import (
|
||||
PaddleFormersModelBase,
|
||||
)
|
||||
@@ -2027,15 +2032,17 @@ class TestLoadWeights:
|
||||
# Run load_weights
|
||||
model.load_weights(weights)
|
||||
|
||||
# Verification
|
||||
# Verification: split shards are forwarded via shard_id.
|
||||
assert qkv_param.weight_loader.called
|
||||
call_args = qkv_param.weight_loader.call_args
|
||||
assert call_args is not None
|
||||
fused_weight = call_args[0][1]
|
||||
assert sorted(fused_weight.shape) == [4096, 12288]
|
||||
calls = qkv_param.weight_loader.call_args_list
|
||||
assert len(calls) == 3
|
||||
assert [c.args[2] for c in calls] == ["q", "k", "v"]
|
||||
assert list(calls[0].args[1].shape) == [4096, 4096]
|
||||
assert list(calls[1].args[1].shape) == [4096, 4096]
|
||||
assert list(calls[2].args[1].shape) == [4096, 4096]
|
||||
|
||||
def test_load_fused_qkv_weights_torch_writeback_shape(self, mock_fd_config):
|
||||
"""Torch model_format should write fused qkv weight in storage layout [out, in]."""
|
||||
"""Torch model_format should route split q/k/v shards without in-test fusion."""
|
||||
from fastdeploy.model_executor.models.paddleformers.base import (
|
||||
PaddleFormersModelBase,
|
||||
)
|
||||
@@ -2086,11 +2093,15 @@ class TestLoadWeights:
|
||||
model.load_weights(weights)
|
||||
|
||||
assert qkv_param.weight_loader.called
|
||||
fused_weight_for_load = qkv_param.weight_loader.call_args[0][1]
|
||||
assert list(fused_weight_for_load.shape) == [6144, 4096]
|
||||
calls = qkv_param.weight_loader.call_args_list
|
||||
assert len(calls) == 3
|
||||
assert [c.args[2] for c in calls] == ["q", "k", "v"]
|
||||
assert list(calls[0].args[1].shape) == [4096, 4096]
|
||||
assert list(calls[1].args[1].shape) == [1024, 4096]
|
||||
assert list(calls[2].args[1].shape) == [1024, 4096]
|
||||
|
||||
def test_load_fused_qkv_weights_strict_torch_mismatched_source_raises(self, mock_fd_config):
|
||||
"""Strict torch policy should raise when source tensors are in paddle layout."""
|
||||
def test_load_fused_qkv_weights_torch_accepts_mismatched_source_shapes(self, mock_fd_config):
|
||||
"""Split q/k/v routing remains shape-agnostic at this unit-test layer."""
|
||||
from fastdeploy.model_executor.models.paddleformers.base import (
|
||||
PaddleFormersModelBase,
|
||||
)
|
||||
@@ -2141,16 +2152,16 @@ class TestLoadWeights:
|
||||
("model.layers.0.self_attn.v_proj.weight", v_weight),
|
||||
]
|
||||
|
||||
load_weights_src = inspect.getsource(PaddleFormersModelBase.load_weights)
|
||||
if "requires torch layout" in load_weights_src:
|
||||
with pytest.raises(ValueError, match="model_format=torch requires torch layout"):
|
||||
model.load_weights(weights)
|
||||
else:
|
||||
model.load_weights(weights)
|
||||
assert qkv_param.weight_loader.called
|
||||
model.load_weights(weights)
|
||||
calls = qkv_param.weight_loader.call_args_list
|
||||
assert len(calls) == 3
|
||||
assert [c.args[2] for c in calls] == ["q", "k", "v"]
|
||||
assert list(calls[0].args[1].shape) == [4096, 4096]
|
||||
assert list(calls[1].args[1].shape) == [4096, 1024]
|
||||
assert list(calls[2].args[1].shape) == [4096, 1024]
|
||||
|
||||
def test_load_fused_qkv_weights_unsupported_model_format_raises(self, mock_fd_config):
|
||||
"""Unsupported model_format should raise in fused QKV path."""
|
||||
def test_load_fused_qkv_weights_split_path_ignores_model_format(self, mock_fd_config):
|
||||
"""Split q/k/v routing should not depend on model_format value."""
|
||||
from fastdeploy.model_executor.models.paddleformers.base import (
|
||||
PaddleFormersModelBase,
|
||||
)
|
||||
@@ -2201,16 +2212,13 @@ class TestLoadWeights:
|
||||
("model.layers.0.self_attn.v_proj.weight", v_weight),
|
||||
]
|
||||
|
||||
load_weights_src = inspect.getsource(PaddleFormersModelBase.load_weights)
|
||||
if "Unsupported model_format" in load_weights_src:
|
||||
with pytest.raises(ValueError, match="Unsupported model_format"):
|
||||
model.load_weights(weights)
|
||||
else:
|
||||
model.load_weights(weights)
|
||||
assert qkv_param.weight_loader.called
|
||||
model.load_weights(weights)
|
||||
calls = qkv_param.weight_loader.call_args_list
|
||||
assert len(calls) == 3
|
||||
assert [c.args[2] for c in calls] == ["q", "k", "v"]
|
||||
|
||||
def test_load_fused_qkv_biases(self, mock_fd_config):
|
||||
"""QKV bias fusion should load q/k/v biases into qkv_proj.bias."""
|
||||
"""QKV bias shards should be routed to qkv_proj.bias with shard ids."""
|
||||
from fastdeploy.model_executor.models.paddleformers.base import (
|
||||
PaddleFormersModelBase,
|
||||
)
|
||||
@@ -2261,11 +2269,13 @@ class TestLoadWeights:
|
||||
]
|
||||
|
||||
model.load_weights(weights)
|
||||
if qkv_bias_param.weight_loader.called:
|
||||
fused_bias = qkv_bias_param.weight_loader.call_args[0][1]
|
||||
assert list(fused_bias.shape) == [6144]
|
||||
else:
|
||||
pytest.skip("Current load_weights implementation does not fuse qkv bias in this branch")
|
||||
assert qkv_bias_param.weight_loader.called
|
||||
calls = qkv_bias_param.weight_loader.call_args_list
|
||||
assert len(calls) == 3
|
||||
assert [c.args[2] for c in calls] == ["q", "k", "v"]
|
||||
assert list(calls[0].args[1].shape) == [4096]
|
||||
assert list(calls[1].args[1].shape) == [1024]
|
||||
assert list(calls[2].args[1].shape) == [1024]
|
||||
|
||||
def test_load_fused_ffn_weights(self, mock_fd_config):
|
||||
"""Test loading and fusing FFN weights (lines 619-624 + stacked mapping logic)."""
|
||||
@@ -2370,6 +2380,96 @@ class TestLoadWeights:
|
||||
# Verify set_value called on lm_head
|
||||
assert model.lm_head.linear.weight.set_value.called
|
||||
|
||||
def test_load_weights_qkv_direct_is_skipped_when_split_exists(self, mock_fd_config):
    """A direct qkv_proj.* tensor is ignored for a layer that also ships split q/k/v.

    The weight stream holds q_proj, k_proj and v_proj tensors followed by a
    fused qkv_proj tensor for the same layer. The mocked qkv parameter loader
    must be invoked exactly three times, with shard ids "q", "k", "v".
    """
    from fastdeploy.model_executor.models.paddleformers.base import (
        PaddleFormersModelBase,
    )

    fd_config, _unused = mock_fd_config

    class TestModel(PaddleFormersModelBase):
        pass

    def bare_layer_init(self, *args, **kwargs):
        # Stand-in for nn.Layer.__init__: only create the empty containers.
        for container in ("_sub_layers", "_parameters", "_buffers", "_loaddict_holder"):
            setattr(self, container, {})

    with (
        patch.object(nn.Layer, "__init__", bare_layer_init),
        patch.object(TestModel, "create_attention_instances", return_value={}),
    ):
        model = TestModel(fd_config)
        model._use_fused_qkv = True
        model._use_fused_ffn = False

        # The only registered parameter is the fused qkv weight with a mock loader.
        fused_param = MagicMock(spec=paddle.Tensor)
        fused_param.weight_loader = Mock()
        named_params = {"model.layers.0.self_attn.qkv_proj.weight": fused_param}
        model.named_parameters = Mock(return_value=named_params.items())
        model.named_sublayers = Mock(return_value={}.items())

        checkpoint = [
            ("model.layers.0.self_attn.q_proj.weight", paddle.randn([4096, 4096])),
            ("model.layers.0.self_attn.k_proj.weight", paddle.randn([4096, 4096])),
            ("model.layers.0.self_attn.v_proj.weight", paddle.randn([4096, 4096])),
            ("model.layers.0.self_attn.qkv_proj.weight", paddle.randn([4096, 12288])),
        ]
        model.load_weights(checkpoint)

        # Only split q/k/v shards should be loaded for this layer.
        loader = fused_param.weight_loader
        assert loader.call_count == 3
        shard_ids = [call.args[2] for call in loader.call_args_list]
        assert shard_ids == ["q", "k", "v"]
|
||||
|
||||
def test_load_weights_direct_qkv_not_found_and_tie_warning(self, mock_fd_config):
    """Exercise the two warning paths of load_weights.

    The model exposes no parameters, so a direct fused qkv weight has no
    target, and the input-embedding object has no ``weight`` attribute while
    ``tie_word_embeddings`` is enabled. Both warnings must be logged and the
    lm_head weight must not be written.
    """
    from fastdeploy.model_executor.models.paddleformers.base import (
        PaddleFormersModelBase,
    )

    fd_config, _unused = mock_fd_config

    class TestModel(PaddleFormersModelBase):
        pass

    def bare_layer_init(self, *args, **kwargs):
        # Stand-in for nn.Layer.__init__: only create the empty containers.
        for container in ("_sub_layers", "_parameters", "_buffers", "_loaddict_holder"):
            setattr(self, container, {})

    with (
        patch.object(nn.Layer, "__init__", bare_layer_init),
        patch.object(TestModel, "create_attention_instances", return_value={}),
        patch("fastdeploy.model_executor.models.paddleformers.base.logger.warning") as mock_warning,
    ):
        model = TestModel(fd_config)
        model._use_fused_qkv = True
        model._use_fused_ffn = False
        model.tie_word_embeddings = True

        model.lm_head = MagicMock()
        model.lm_head.linear.weight.set_value = Mock()

        # Missing embeddings.weight to hit warning branch.
        model.model = MagicMock()
        model.model.get_input_embeddings.return_value = SimpleNamespace()

        # No parameters and no sublayers are registered on the model.
        model.named_parameters = Mock(return_value=iter([]))
        model.named_sublayers = Mock(return_value=iter([]))

        checkpoint = [
            ("model.layers.0.self_attn.qkv_proj.weight", paddle.randn([4096, 12288])),
        ]
        model.load_weights(checkpoint)

        # Collect the first positional argument of every warning call.
        logged = [str(call.args[0]) for call in mock_warning.call_args_list if call.args]
        assert any("Direct fused qkv param not found" in text for text in logged)
        assert any("tie_word_embeddings=True" in text for text in logged)
        assert not model.lm_head.linear.weight.set_value.called
|
||||
|
||||
|
||||
class TestLinearNoWeight:
|
||||
"""Test Linear layer replacement when weight is None (lines 321-322)."""
|
||||
@@ -2427,5 +2527,111 @@ class TestLinearNoWeight:
|
||||
assert isinstance(model.model.q_proj, ColumnParallelLinear)
|
||||
|
||||
|
||||
class TestPaddleFormersQKVParallelLinearUnit:
    """Helper-level unit tests for PaddleFormersQKVParallelLinear."""

    @staticmethod
    def _build_layer(model_format: str = "paddle"):
        """Return a PaddleFormersQKVParallelLinear created without running __init__.

        The instance is allocated with ``object.__new__`` and then given the
        attributes these tests rely on: a single-rank setup with 4 query heads,
        2 kv heads and ``head_dim`` of 2.

        Args:
            model_format: Value stored in ``_model_format`` on the instance.
        """
        from fastdeploy.model_executor.models.paddleformers.base import (
            PaddleFormersQKVParallelLinear,
        )

        stub = object.__new__(PaddleFormersQKVParallelLinear)
        attributes = {
            "_pending_local_shards": {},
            "_model_format": model_format,
            "tp_size": 1,
            "local_rank": 0,
            "num_heads": 4,
            "kv_num_heads": 2,
            "num_heads_per_rank": 4,
            "kv_num_heads_per_rank": 2,
            "num_kv_head_replicas": 1,
            "head_dim": 2,
            "fd_config": SimpleNamespace(load_config=SimpleNamespace(is_pre_sharded=False)),
        }
        for attr_name, attr_value in attributes.items():
            setattr(stub, attr_name, attr_value)
        return stub

    def test_extract_local_shard_with_transpose_and_tp_slice(self):
        """The q shard for rank 1 of 2 equals columns 4:8 of the transposed weight."""
        layer = self._build_layer()
        # Reconfigure the single-rank stub as rank 1 of a two-way TP group.
        layer.tp_size = 2
        layer.local_rank = 1
        layer.num_heads_per_rank = 2
        layer.kv_num_heads_per_rank = 1
        layer.head_dim = 2

        param = SimpleNamespace(output_dim=True, shape=[4, 8], weight_need_transpose=True)
        # Source tensor is [out, in]; with weight_need_transpose it becomes [in, out].
        loaded = paddle.arange(32, dtype="float32").reshape([8, 4])

        shard = layer._extract_local_shard(param, loaded, "q")
        assert list(shard.shape) == [4, 4]

        reference = loaded.transpose([1, 0])[:, 4:8]
        assert bool(paddle.allclose(shard, reference))

    def test_to_hidden_major_and_pack_paths(self):
        """Packing with output_dim=False yields [16, 3]; bad shapes raise ValueError."""
        layer = self._build_layer()
        # q_out=8, kv_out=4 for current head setup.
        q_weight = paddle.randn([8, 3], dtype="float32")  # [out, hidden] -> should transpose
        k_weight = paddle.randn([4, 3], dtype="float32")
        v_weight = paddle.randn([4, 3], dtype="float32")

        packed = layer._pack_pf_interleaved_local(q_weight, k_weight, v_weight, output_dim=False)
        assert list(packed.shape) == [16, 3]

        # A 1D tensor is rejected.
        with pytest.raises(ValueError, match="Expected 2D"):
            layer._to_hidden_major(paddle.randn([2], dtype="float32"), 2, "q")
        # A 2D tensor where neither dimension equals the expected size is rejected.
        with pytest.raises(ValueError, match="Cannot normalize"):
            layer._to_hidden_major(paddle.randn([3, 5], dtype="float32"), 4, "q")

    def test_split_pf_fused_qkv_and_weight_loader_pending_finalize(self):
        """Cover fused split, the pending-then-finalize loader path, and format rejection."""
        layer = self._build_layer(model_format="paddle")

        class DummyParam:
            """Minimal parameter stand-in that records the value passed to set_value."""

            def __init__(self, shape, output_dim=True):
                self.shape = shape
                self.output_dim = output_dim
                self.weight_need_transpose = False
                self.dtype = paddle.float32
                self._initialized = False
                self.saved = None

            def _is_initialized(self):
                return self._initialized

            def initialize(self):
                self._initialized = True

            def set_value(self, value):
                self.saved = value

        # split fused weight path
        fused_weight = paddle.randn([3, 16], dtype="float32")
        q, k, v = layer._split_pf_fused_qkv(fused_weight, is_bias=False)
        weight_shapes = [list(part.shape) for part in (q, k, v)]
        assert weight_shapes[0] == [3, 8]
        assert weight_shapes[1] == [3, 4]
        assert weight_shapes[2] == [3, 4]

        fused_bias = paddle.randn([16], dtype="float32")
        qb, kb, vb = layer._split_pf_fused_qkv(fused_bias, is_bias=True)
        bias_shapes = [list(part.shape) for part in (qb, kb, vb)]
        assert bias_shapes[0] == [8]
        assert bias_shapes[1] == [4]
        assert bias_shapes[2] == [4]

        # pending -> finalize path: the flag stays set until the last shard arrives.
        param = DummyParam(shape=[3, 16], output_dim=True)
        layer.weight_loader(param, q, "q")
        assert bool(getattr(param, "_pf_qkv_pending", False))
        layer.weight_loader(param, k, "k")
        assert bool(getattr(param, "_pf_qkv_pending", False))
        layer.weight_loader(param, v, "v")
        assert not bool(getattr(param, "_pf_qkv_pending", False))
        assert param.saved is not None
        assert list(param.saved.shape) == [3, 16]

        # direct fused qkv in non-paddle format should be rejected.
        torch_layer = self._build_layer(model_format="torch")
        with pytest.raises(ValueError, match="only supported for model_format='paddle'"):
            torch_layer.weight_loader(param, fused_weight, None)
||||
|
||||
|
||||
if __name__ == "__main__":
    # pytest.main returns the session exit code; propagate it so that running
    # this file as a script exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
|
||||
@@ -58,7 +58,7 @@ class MockPretrainedConfig:
|
||||
|
||||
|
||||
class MockLinearLayer(paddle.nn.Layer):
|
||||
"""Mock for ColumnParallelLinear/RowParallelLinear to avoid Fleet."""
|
||||
"""Mock for TP linear layers in fallback tests to avoid Fleet initialization."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
@@ -100,6 +100,9 @@ def mock_distributed_layers(monkeypatch):
|
||||
# Mock on base module (where the imports are used)
|
||||
monkeypatch.setattr("fastdeploy.model_executor.models.paddleformers.base.ColumnParallelLinear", MockLinearLayer)
|
||||
monkeypatch.setattr("fastdeploy.model_executor.models.paddleformers.base.RowParallelLinear", MockLinearLayer)
|
||||
monkeypatch.setattr(
|
||||
"fastdeploy.model_executor.models.paddleformers.base.PaddleFormersQKVParallelLinear", MockLinearLayer
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"fastdeploy.model_executor.models.paddleformers.base.MergedColumnParallelLinear", MockLinearLayer
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user