mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
00eb12f656
* [BugFix][Models] avoid custom all-reduce in PaddleFormers fallback TP path and tighten TP-aware layout matching * [BugFix][Models] unify PaddleFormers fused QKV TP loading and align fallback tests
1042 lines
45 KiB
Python
1042 lines
45 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
"""Generic PaddleFormers modeling backend base class."""
|
|
|
|
import re
|
|
from collections.abc import Iterable
|
|
from typing import TYPE_CHECKING
|
|
|
|
import paddle
|
|
from paddle import nn
|
|
from paddleformers.nn.attention.interface import ALL_ATTENTION_FUNCTIONS
|
|
from paddleformers.transformers import AutoModel, PretrainedModel
|
|
from paddleformers.utils.log import logger
|
|
|
|
from fastdeploy.model_executor.forward_meta import ForwardMeta # noqa: F401
|
|
from fastdeploy.model_executor.graph_optimization.decorator import (
|
|
support_graph_optimization,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from fastdeploy.config import FDConfig
|
|
|
|
from fastdeploy.model_executor.layers.attention.attention import Attention
|
|
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
|
from fastdeploy.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
|
from fastdeploy.model_executor.utils import WeightsMapper, slice_fn
|
|
|
|
|
|
class PaddleFormersRMSNormWrapper(nn.Layer):
    """Adapter exposing FD's RMSNorm through a PaddleFormers-style interface.

    FastDeploy's RMSNorm forward returns a ``(normalized, residual_out)``
    tuple, while PaddleFormers layers expect a single tensor. This wrapper
    drops the residual and returns only the normalized output.
    """

    def __init__(self, fd_rmsnorm: RMSNorm):
        super().__init__()
        self._fd_rmsnorm = fd_rmsnorm
        # Re-expose the underlying weight so checkpoint loading and external
        # access keep working against this wrapper.
        self.weight = fd_rmsnorm.weight

    def forward(self, x):
        # FD RMSNorm yields (out, residual_out); only the normalized output
        # is meaningful to PaddleFormers callers.
        normalized, _residual = self._fd_rmsnorm(x)
        return normalized
|
|
|
|
|
|
class PaddleFormersQKVParallelLinear(QKVParallelLinear):
    """PF-specific QKV loader that packs local shards in PF interleaved order."""

    def __init__(self, fd_config, prefix: str, with_bias: bool = False):
        super().__init__(fd_config=fd_config, prefix=prefix, with_bias=with_bias)
        # Staging area keyed by id(param): maps "q"/"k"/"v" to this rank's
        # local shard until all three arrive and can be packed at once.
        self._pending_local_shards: dict[int, dict[str, paddle.Tensor]] = {}
        # Normalized checkpoint format tag (e.g. "paddle"); empty if unset.
        self._model_format = str(getattr(fd_config.model_config, "model_format", "") or "").lower()

    @staticmethod
    def _to_tensor(t: paddle.Tensor | object) -> paddle.Tensor:
        # Materialize array-likes (numpy, lists) into paddle tensors lazily.
        return t if isinstance(t, paddle.Tensor) else paddle.to_tensor(t)

    def _extract_local_shard(self, param: paddle.Tensor, loaded_weight: paddle.Tensor, loaded_shard_id: str):
        """Slice this TP rank's portion out of a full q/k/v weight or bias.

        Raises:
            ValueError: if the parameter lacks ``output_dim`` metadata or a
                transpose is requested on a non-2D tensor.
        """
        output_dim = getattr(param, "output_dim", None)
        if output_dim is None:
            raise ValueError("Missing output_dim for QKV parameter.")

        # Head axis: last dim for output-major (2D) params, dim 0 otherwise.
        dim = -1 if output_dim else 0
        denom = self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank
        head_dim = int(param.shape[dim]) // int(denom)

        weight = self._to_tensor(loaded_weight)
        if getattr(param, "weight_need_transpose", False):
            if weight.ndim != 2:
                raise ValueError(f"Expected 2D tensor for transpose, got shape={list(weight.shape)}")
            weight = weight.transpose([1, 0])

        # Only slice when the checkpoint holds global (unsharded) weights.
        if self.tp_size > 1 and output_dim is not None and not self.fd_config.load_config.is_pre_sharded:
            block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
            # K/V heads may be replicated across ranks when kv_heads < tp_size,
            # hence the division by num_kv_head_replicas for k/v shards.
            shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
            shard_offset = shard_id * block_size
            weight = slice_fn(weight, output_dim, start=shard_offset, end=shard_offset + block_size)

        return weight

    @staticmethod
    def _to_hidden_major(weight: paddle.Tensor, expected_out: int, name: str) -> paddle.Tensor:
        """Normalize a 2D shard to [hidden, out] orientation.

        Accepts either orientation and transposes when the output axis is
        found on dim 0; raises if neither dim matches ``expected_out``.
        """
        if weight.ndim != 2:
            raise ValueError(f"Expected 2D {name} shard, got shape={list(weight.shape)}")

        s0, s1 = int(weight.shape[0]), int(weight.shape[1])
        if s1 == expected_out:
            return weight
        if s0 == expected_out:
            return weight.transpose([1, 0])
        raise ValueError(
            f"Cannot normalize {name} shard shape={list(weight.shape)} to hidden-major with expected_out={expected_out}."
        )

    def _pack_pf_interleaved_local(
        self,
        q_local: paddle.Tensor,
        k_local: paddle.Tensor,
        v_local: paddle.Tensor,
        output_dim: bool,
    ):
        """Pack local q/k/v shards into PF's interleaved fused-QKV layout.

        Per kv-head group the layout is [q_group..., k, v]; handles both 1D
        (bias) and 2D (weight) shards. Returns hidden-major output when
        ``output_dim`` is truthy, otherwise output-major.
        """
        kv_local = int(self.kv_num_heads_per_rank)
        if kv_local <= 0:
            raise ValueError("Invalid kv_num_heads_per_rank, must be > 0.")
        if self.num_heads_per_rank % kv_local != 0:
            raise ValueError(
                f"num_heads_per_rank={self.num_heads_per_rank} is not divisible by kv_num_heads_per_rank={kv_local}"
            )
        q_groups_local = self.num_heads_per_rank // kv_local

        # 1D case: biases — interleave along the single axis.
        if q_local.ndim == 1:
            q = q_local.reshape([kv_local, q_groups_local, self.head_dim])
            k = k_local.reshape([kv_local, 1, self.head_dim])
            v = v_local.reshape([kv_local, 1, self.head_dim])
            return paddle.concat([q, k, v], axis=1).reshape([-1])

        q_out = kv_local * q_groups_local * self.head_dim
        kv_out = kv_local * self.head_dim

        # Bring all shards to [hidden, out] before interleaving.
        q_hm = self._to_hidden_major(q_local, q_out, "q")
        k_hm = self._to_hidden_major(k_local, kv_out, "k")
        v_hm = self._to_hidden_major(v_local, kv_out, "v")

        hidden_size = int(q_hm.shape[0])
        if int(k_hm.shape[0]) != hidden_size or int(v_hm.shape[0]) != hidden_size:
            raise ValueError(
                "Q/K/V hidden dimension mismatch after normalization: "
                f"q={list(q_hm.shape)}, k={list(k_hm.shape)}, v={list(v_hm.shape)}"
            )

        q = q_hm.reshape([hidden_size, kv_local, q_groups_local, self.head_dim])
        k = k_hm.reshape([hidden_size, kv_local, 1, self.head_dim])
        v = v_hm.reshape([hidden_size, kv_local, 1, self.head_dim])
        packed_hidden_major = paddle.concat([q, k, v], axis=2).reshape([hidden_size, -1])

        if output_dim:
            return packed_hidden_major
        return packed_hidden_major.transpose([1, 0])

    def _split_pf_fused_qkv(self, loaded_weight: paddle.Tensor, is_bias: bool):
        """Split a PF fused qkv tensor back into (q, k, v) pieces.

        Only supported for model_format='paddle'; the fused width may be
        either the global or the TP-local interleaved width.
        """
        if self._model_format != "paddle":
            raise ValueError(
                "Direct qkv_proj loading is only supported for model_format='paddle'. "
                "Use split q_proj/k_proj/v_proj weights for other formats."
            )

        weight = self._to_tensor(loaded_weight)
        if is_bias:
            if weight.ndim != 1:
                raise ValueError(f"Unexpected fused qkv bias dims: {list(weight.shape)}, expected 1D.")
            width = int(weight.shape[0])
        else:
            if weight.ndim != 2:
                raise ValueError(f"Unexpected fused qkv weight dims: {list(weight.shape)}, expected 2D.")
            width = int(weight.shape[1])

        global_width = int((self.num_heads + 2 * self.kv_num_heads) * self.head_dim)
        local_width = int((self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * self.head_dim)

        # Pick global vs local head counts depending on the observed width.
        if width == global_width:
            num_heads, num_kv_heads = self.num_heads, self.kv_num_heads
        elif width == local_width:
            num_heads, num_kv_heads = self.num_heads_per_rank, self.kv_num_heads_per_rank
        else:
            raise ValueError(
                f"Cannot validate fused qkv_proj width={width}. "
                f"Expect global={global_width} or local={local_width} for PF interleaved layout."
            )

        if num_heads % num_kv_heads != 0:
            raise ValueError(f"Invalid head config: num_heads={num_heads}, num_kv_heads={num_kv_heads}")
        q_groups = num_heads // num_kv_heads

        # PF interleaved layout per kv-head: [q_group..., k, v] => q_groups + 2.
        if is_bias:
            fused = weight.reshape([num_kv_heads, q_groups + 2, self.head_dim])
            q = fused[:, :q_groups, :].reshape([-1])
            k = fused[:, q_groups : q_groups + 1, :].reshape([-1])
            v = fused[:, q_groups + 1 :, :].reshape([-1])
            return q, k, v

        hidden_size = int(weight.shape[0])
        fused = weight.reshape([hidden_size, num_kv_heads, q_groups + 2, self.head_dim])
        q = fused[:, :, :q_groups, :].reshape([hidden_size, -1])
        k = fused[:, :, q_groups : q_groups + 1, :].reshape([hidden_size, -1])
        v = fused[:, :, q_groups + 1 :, :].reshape([hidden_size, -1])
        return q, k, v

    def weight_loader(self, param, loaded_weight, loaded_shard_id: str | None = None):
        """Load a q/k/v (or fused qkv) tensor into the fused parameter.

        Fused inputs (no shard id) are split and re-dispatched as q/k/v.
        Individual q/k/v shards are staged until all three are present,
        then packed into PF interleaved order and written in one shot.
        """
        if loaded_shard_id is None:
            # Fused qkv checkpoint entry: split, then recurse per shard.
            is_bias = len(param.shape) == 1
            q_shard, k_shard, v_shard = self._split_pf_fused_qkv(loaded_weight, is_bias=is_bias)
            self.weight_loader(param, q_shard, "q")
            self.weight_loader(param, k_shard, "k")
            self.weight_loader(param, v_shard, "v")
            return

        if loaded_shard_id not in {"q", "k", "v"}:
            # Unknown shard kinds fall back to the generic loader.
            super().weight_loader(param, loaded_weight, loaded_shard_id)
            return

        local_shard = self._extract_local_shard(param, loaded_weight, loaded_shard_id)
        key = id(param)
        pending = self._pending_local_shards.setdefault(key, {})
        pending[loaded_shard_id] = local_shard

        if len(pending) < 3:
            # Mark the param as partially loaded until q, k and v all arrive.
            setattr(param, "_pf_qkv_pending", True)
            return

        packed = self._pack_pf_interleaved_local(
            pending["q"],
            pending["k"],
            pending["v"],
            output_dim=bool(getattr(param, "output_dim", True)),
        )
        if not param._is_initialized():
            param.initialize()
        if packed.dtype != param.dtype:
            packed = packed.cast(param.dtype)
        if list(param.shape) != list(packed.shape):
            raise ValueError(f"Packed qkv shape mismatch: packed={list(packed.shape)} param={list(param.shape)}")

        param.set_value(packed)
        # Release staged shards and clear the pending flag.
        del self._pending_local_shards[key]
        setattr(param, "_pf_qkv_pending", False)
|
|
|
|
|
|
def getattr_iter(obj, names, default=None):
    """Return the first attribute of *obj* found among *names*, else *default*."""
    _missing = object()
    for candidate in names:
        value = getattr(obj, candidate, _missing)
        if value is not _missing:
            return value
    return default
|
|
|
|
|
|
def maybe_prefix(prefix, name):
    """Join *name* onto *prefix* with a dot; return *name* unchanged when *prefix* is falsy."""
    return f"{prefix}.{name}" if prefix else name
|
|
|
|
|
|
def fastdeploy_append_attention_forward(
    module: paddle.nn.Layer,
    query: paddle.Tensor,
    key: paddle.Tensor,
    value: paddle.Tensor,
    attention_mask: paddle.Tensor,
    scaling: float | None = None,
    **kwargs,
):
    """Attention hook that routes PF attention through FD's Attention backend.

    Resolves the FD ``Attention`` instance for this layer from
    ``module.config.attention_instances``, normalizes Q/K/V to a flat
    [seq, heads*head_dim] layout, concatenates them and delegates to the FD
    attention forward with ``forward_meta``.

    Returns:
        (output, None) — the second element matches PF's expected
        (attn_output, attn_weights) tuple; weights are not computed.

    Raises:
        ValueError: on missing config/meta/layer index, unsupported input
            rank or batch size, ambiguous head layout, or sequence-length
            mismatches.
    """
    config = getattr(module, "config", None)
    if config is None:
        raise ValueError(f"Module {module} does not have 'config' attribute.")

    # Both are stashed on the PF config by PaddleFormersModelBase.
    attention_instances = getattr(config, "attention_instances", None)
    forward_meta = getattr(config, "forward_meta", None)

    if attention_instances is None:
        raise ValueError("attention_instances not found in module.config")
    if forward_meta is None:
        raise ValueError("forward_meta not found in module.config")

    # PF modules expose either layer_idx or layer_id depending on version.
    layer_idx = getattr(module, "layer_idx", getattr(module, "layer_id", None))
    if layer_idx is None:
        raise ValueError("layer_idx not found on attention module")

    self_attn = attention_instances[int(layer_idx)]
    if scaling is not None:
        # Honor PF's per-call softmax scaling override.
        self_attn.scale = float(scaling)

    tp_size = 1
    if hasattr(self_attn, "fd_config") and hasattr(self_attn.fd_config, "parallel_config"):
        tp_size = int(getattr(self_attn.fd_config.parallel_config, "tensor_parallel_size", 1) or 1)

    # Resolve head-related metadata.
    num_heads = (
        getattr(module, "num_heads", None)
        or getattr(config, "num_attention_heads", None)
        or getattr(self_attn, "num_heads", None)
    )
    num_kv_heads = (
        getattr(module, "num_key_value_heads", None)
        or getattr(config, "num_key_value_heads", None)
        or getattr(self_attn, "num_key_value_heads", None)
        or getattr(self_attn, "kv_num_heads", None)
        or num_heads
    )
    num_heads = int(num_heads) if num_heads is not None else None
    num_kv_heads = int(num_kv_heads) if num_kv_heads is not None else None

    # Support only 3D (HSD/SHD) or 4D (BHSD/BSHD with B=1) inputs.
    def squeeze_to_3d(t: paddle.Tensor, name: str) -> paddle.Tensor:
        if t.ndim == 4:
            if int(t.shape[0]) != 1:
                raise ValueError(f"{name} batch size {int(t.shape[0])} not supported")
            return t.squeeze(0)
        if t.ndim == 3:
            return t
        raise ValueError(f"{name} has unexpected dims {t.ndim}, expect 3 or 4")

    q = squeeze_to_3d(query, "query")
    k = squeeze_to_3d(key, "key")
    v = squeeze_to_3d(value, "value")

    def heads_match(actual_heads: int, expected_heads: int | None) -> bool:
        # Accept either the global head count or its TP-local value, so the
        # layout check works on both fallback and sharded paths.
        if expected_heads is None:
            return False
        if actual_heads == expected_heads:
            return True
        if tp_size > 1 and expected_heads % tp_size == 0:
            expected_heads //= tp_size
        return actual_heads == expected_heads

    # Determine layout from Q/K/V head axes; keep default behavior on ambiguity.
    is_hsd = (
        heads_match(int(q.shape[0]), num_heads)
        and heads_match(int(k.shape[0]), num_kv_heads)
        and heads_match(int(v.shape[0]), num_kv_heads)
    )
    is_shd = (
        heads_match(int(q.shape[1]), num_heads)
        and heads_match(int(k.shape[1]), num_kv_heads)
        and heads_match(int(v.shape[1]), num_kv_heads)
    )

    if is_hsd:
        # [heads, seq, dim] -> [seq, heads*dim]
        q_flat = q.transpose([1, 0, 2]).reshape([int(q.shape[1]), -1])
        k_flat = k.transpose([1, 0, 2]).reshape([int(k.shape[1]), -1])
        v_flat = v.transpose([1, 0, 2]).reshape([int(v.shape[1]), -1])
    elif is_shd:
        # [seq, heads, dim] -> [seq, heads*dim]
        q_flat = q.reshape([int(q.shape[0]), -1])
        k_flat = k.reshape([int(k.shape[0]), -1])
        v_flat = v.reshape([int(v.shape[0]), -1])
    else:
        raise ValueError(
            f"Invalid attention layout: q={list(q.shape)}, k={list(k.shape)}, v={list(v.shape)}, "
            f"heads={num_heads}/{num_kv_heads}"
        )

    # Sequence lengths must match after flattening Q/K/V.
    q_seq, k_seq, v_seq = int(q_flat.shape[0]), int(k_flat.shape[0]), int(v_flat.shape[0])
    if not (q_seq == k_seq == v_seq):
        raise ValueError(
            f"Sequence length mismatch after flattening: Q={q_seq}, K={k_seq}, V={v_seq}, "
            f"raw query={list(query.shape)}, key={list(key.shape)}, value={list(value.shape)}."
        )

    # If forward_meta provides ids_remove_padding, strictly validate Q sequence length.
    ids_remove_padding = getattr(forward_meta, "ids_remove_padding", None)
    if ids_remove_padding is not None:
        expected_seq = int(ids_remove_padding.shape[0])
        if q_seq != expected_seq:
            raise ValueError(f"Seq len mismatch: got {q_seq}, expect {expected_seq}")

    # FD's attention backend consumes a single concatenated qkv tensor.
    qkv = paddle.concat([q_flat, k_flat, v_flat], axis=-1)
    output = self_attn.forward(qkv=qkv, forward_meta=forward_meta)

    return output, None
|
|
|
|
|
|
# Register the FD-backed implementation so PF configs can select it via
# `_attn_implementation = "fastdeploy_append"`.
ALL_ATTENTION_FUNCTIONS._global_mapping["fastdeploy_append"] = fastdeploy_append_attention_forward
|
|
|
|
|
|
@support_graph_optimization
|
|
class PaddleFormersModelBase(nn.Layer):
|
|
"""
|
|
A mixin-style base class to provide PaddleFormers backend logic on top of nn.Layer.
|
|
This class subclasses nn.Layer and provides common methods to
|
|
initialize and manage a PaddleFormers model.
|
|
"""
|
|
|
|
    # Maps PaddleFormers checkpoint key prefixes to FastDeploy module paths.
    # Subclasses may declare their own `pf_to_fd_mapper`; __init_subclass__
    # merges them across the MRO (most-specific class wins on conflicts).
    pf_to_fd_mapper = WeightsMapper(
        orig_to_new_prefix={
            "": "model.",
            "model.model.": "model.",
            "model.embed_tokens.weight": "model.embed_tokens.embeddings.weight",
            "embed_tokens.weight": "model.embed_tokens.embeddings.weight",
            "model.lm_head.weight": "lm_head.linear.weight",
            "model.score.": "classifier.",
            "model.classifier.": "classifier.",
        }
    )
|
|
|
|
def __init_subclass__(cls, *args, **kwargs):
|
|
"""Merge pf_to_fd_mapper in MRO from most specific to least specific."""
|
|
super().__init_subclass__(*args, **kwargs)
|
|
|
|
# Collect all mappings from base classes
|
|
merged_mappings = {}
|
|
for base in reversed(cls.__mro__): # Reverse to go from least to most specific
|
|
if base_pf_to_fd_mapper := getattr(base, "pf_to_fd_mapper", None):
|
|
if hasattr(base_pf_to_fd_mapper, "orig_to_new_prefix"):
|
|
merged_mappings.update(base_pf_to_fd_mapper.orig_to_new_prefix)
|
|
|
|
# Create new mapper with merged mappings
|
|
cls.pf_to_fd_mapper = WeightsMapper(orig_to_new_prefix=merged_mappings)
|
|
|
|
    def __init__(self, fd_config: "FDConfig", **kwargs):
        """Build the PaddleFormers-backed model and adapt it for FastDeploy.

        Loads the PF config/model, enables PF fusion options for supported
        model types, replaces PF layers with FD TP-aware equivalents, wires
        FD attention instances, and swaps in a parallel embedding layer.
        """
        # NOTE(review): fd_config is forwarded positionally to nn.Layer's
        # __init__ (whose first parameter is normally name_scope) — confirm
        # this is intentional and not a leftover.
        super().__init__(fd_config)
        logger.info("Initializing PaddleFormers backend.")
        self.fd_config = fd_config  # FastDeploy's top-level FDConfig
        self.model_config = fd_config.model_config  # FastDeploy's ModelConfig

        from paddleformers.transformers import AutoConfig

        self.paddleformers_config = AutoConfig.from_pretrained(self.model_config.model)

        # PaddleFormers fused optimize option
        self.paddleformers_config.fuse_rms_norm = True
        model_type = getattr(self.paddleformers_config, "model_type", "").lower()
        # Model families known to work with PF's fused QKV/FFN paths here.
        supported_fused_qkv_models = ["qwen3", "qwen2"]

        tp_size = fd_config.parallel_config.tensor_parallel_size
        self._use_fused_qkv = model_type in supported_fused_qkv_models
        if self._use_fused_qkv:
            self.paddleformers_config.fuse_attention_qkv = True
            logger.info(f"Enabled fuse_attention_qkv for model_type={model_type}, tp={tp_size}")
        else:
            logger.debug(f"QKV fusion not enabled for model_type={model_type}")

        # PaddleFormers fused optimize option
        # NOTE: fused FFN reuses the same supported-model list as fused QKV.
        self._use_fused_ffn = model_type in supported_fused_qkv_models
        if self._use_fused_ffn:
            self.paddleformers_config.fuse_attention_ffn = True
            self.paddleformers_config.fuse_swiglu = True
            logger.info(f"Enabled fuse_attention_ffn and fuse_swiglu for model_type={model_type}")

        self.text_config = self.paddleformers_config  # The specific text model config
        # Sync important config values from text_config to model_config
        # This ensures fallback models use their actual config values instead of FD defaults
        self._sync_config_from_text_config()
        # For convenience, keep direct access to some FD configs
        self.quant_config = self.fd_config.quant_config
        self.parallel_config = self.fd_config.parallel_config
        self.tp_group = self.parallel_config.tp_group
        self.tp_rank = self.parallel_config.tensor_parallel_rank
        # Route PF attention through the FD hook registered above.
        self.paddleformers_config._attn_implementation = "fastdeploy_append"

        self.model: PretrainedModel = AutoModel.from_config(
            self.paddleformers_config,
            dtype=self.model_config.dtype,
        )
        self.model.eval()

        # Linear and Norm replace for FD optimized versions and TP support
        self.recursive_replace()
        # Patch PF attention head counts to TP-local values for fused qkv reshape
        self._localize_pf_attention_heads()
        # Attention instances for FD Attention backend
        self.attention_instances = self.create_attention_instances()
        self.paddleformers_config.attention_instances = self.attention_instances
        # Embedding replace for TP support
        input_embeddings = self.model.get_input_embeddings()
        # Some PF embeddings scale outputs; preserved and applied in embed_input_ids.
        self.embed_scale = getattr(input_embeddings, "embed_scale", None)
        embedding_dim = getattr_iter(self.text_config, ("embedding_size", "hidden_size"))
        if embedding_dim is None:
            raise ValueError(
                "Failed to determine embedding dimension from text_config: "
                "neither 'embedding_size' nor 'hidden_size' is set. "
                f"text_config type={type(self.text_config).__name__}."
            )
        self.model.set_input_embeddings(
            VocabParallelEmbedding(
                fd_config=self.fd_config,
                num_embeddings=self.text_config.vocab_size,
                embedding_dim=embedding_dim,
            )
        )
|
|
|
|
def _sync_config_from_text_config(self) -> None:
|
|
"""
|
|
Sync important config values from text_config (PaddleFormers/HF config)
|
|
to model_config. This ensures fallback models use their actual config
|
|
values instead of FD's defaults.
|
|
|
|
This is crucial for models with unique configs like:
|
|
- Gemma3: tie_word_embeddings=True, layer_types, sliding_window
|
|
- Mistral: sliding_window
|
|
- etc.
|
|
"""
|
|
mc = self.model_config
|
|
tc = self.text_config
|
|
|
|
sync_fields = [
|
|
"tie_word_embeddings",
|
|
"sliding_window",
|
|
"sliding_window_pattern",
|
|
"layer_types", # May be computed as property
|
|
"rope_theta",
|
|
"rope_scaling",
|
|
"head_dim",
|
|
"rms_norm_eps",
|
|
"rope_local_base_freq", # Gemma3 specific
|
|
"query_pre_attn_scalar", # Gemma3 specific
|
|
]
|
|
|
|
synced = []
|
|
for field in sync_fields:
|
|
text_value = getattr(tc, field, None)
|
|
if text_value is not None:
|
|
# Only sync if not already set or if FD default differs
|
|
current_value = getattr(mc, field, None) if hasattr(mc, field) else None
|
|
if current_value is None or current_value != text_value:
|
|
setattr(mc, field, text_value)
|
|
synced.append(f"{field}={text_value}")
|
|
|
|
    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Replaces:
        - nn.Linear with FD's tensor parallel linear classes (based on naming rules)
        - *RMSNorm with FD's RMSNorm
        """
        tp_plan = self._get_tp_plan()

        def _get_linear_style(qual_name: str) -> str:
            """Determine linear style based on layer name patterns."""
            for pattern, style in tp_plan.items():
                if re.search(pattern, qual_name):
                    return style
            # Names matching no TP-plan pattern stay unsharded.
            return "replicate"

        def _recursive_replace(module: nn.Layer, prefix: str):
            # Depth-first walk; leaf linears/norms are swapped in place on
            # the parent via setattr.
            for child_name, child_module in module.named_children():
                qual_name = maybe_prefix(prefix, child_name)
                new_module = child_module

                if isinstance(child_module, nn.Linear):
                    style = _get_linear_style(qual_name)

                    # PaddlePaddle nn.Linear: weight shape is [in_features, out_features]
                    # PyTorch nn.Linear: has in_features/out_features attributes
                    if hasattr(child_module, "weight") and child_module.weight is not None:
                        weight_shape = child_module.weight.shape
                        in_features = weight_shape[0]
                        out_features = weight_shape[1]
                    else:
                        in_features = getattr(child_module, "in_features", None)
                        out_features = getattr(child_module, "out_features", None)

                    with_bias = hasattr(child_module, "bias") and child_module.bias is not None

                    if style == "colwise":
                        # qkv_proj uses PF-specific TP-aware loader to support
                        # unified split-QKV loading across TP1/TP>1.
                        if "qkv_proj" in qual_name and self._use_fused_qkv:
                            new_module = PaddleFormersQKVParallelLinear(
                                self.fd_config,
                                prefix=qual_name,
                                with_bias=with_bias,
                            )
                        # For up_gate_proj when fused FFN is enabled:
                        # Use MergedColumnParallelLinear which handles gate+up weight loading
                        elif "up_gate_proj" in qual_name and self._use_fused_ffn:
                            new_module = MergedColumnParallelLinear(
                                self.fd_config,
                                prefix=qual_name,
                                input_size=in_features,
                                output_size=out_features,
                                with_bias=with_bias,
                            )
                        else:
                            new_module = ColumnParallelLinear(
                                self.fd_config,
                                prefix=qual_name,
                                input_size=in_features,
                                output_size=out_features,
                                with_bias=with_bias,
                            )
                    elif style == "rowwise":
                        new_module = RowParallelLinear(
                            self.fd_config,
                            prefix=qual_name,
                            input_size=in_features,
                            output_size=out_features,
                            with_bias=with_bias,
                        )
                    else:  # replicate
                        new_module = ReplicatedLinear(
                            self.fd_config,
                            prefix=qual_name,
                            input_size=in_features,
                            output_size=out_features,
                            with_bias=with_bias,
                        )

                # RMSNorm replacement: use wrapper to adapt FD's tuple return to single tensor
                elif child_module.__class__.__name__.endswith("RMSNorm"):
                    if hasattr(child_module, "weight") and child_module.weight is not None:
                        hidden_size = child_module.weight.shape[0]
                    else:
                        hidden_size = getattr(self.text_config, "hidden_size", None)
                    # PF norms expose either `epsilon` or `variance_epsilon`.
                    eps = getattr(child_module, "epsilon", getattr(child_module, "variance_epsilon", 1e-6))
                    fd_rmsnorm = RMSNorm(
                        self.fd_config,
                        hidden_size=hidden_size,
                        eps=eps,
                        prefix=qual_name,
                        begin_norm_axis=-1,  # Normalize only last dim (hidden), not entire flattened tensor
                    )
                    # Wrap with PaddleFormersRMSNormWrapper for interface compatibility
                    new_module = PaddleFormersRMSNormWrapper(fd_rmsnorm)
                else:
                    # Not a replaceable leaf: recurse into the container.
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)

        _recursive_replace(self.model, prefix="model")
|
|
|
|
def _localize_pf_attention_heads(self):
|
|
"""Patch PF attention modules' head counts to TP-local values.
|
|
|
|
PF Attention.__init__ reads global head counts from config and stores
|
|
them as instance attrs (num_heads, num_key_value_heads, etc.).
|
|
Since we cannot set config.tensor_model_parallel_size > 1 (it would
|
|
trigger PF's own TP linears, conflicting with recursive_replace),
|
|
we patch the instance attrs directly after model creation.
|
|
|
|
Only needed when fused qkv is enabled, because the PF forward path
|
|
reshapes qkv_proj output using these head counts.
|
|
"""
|
|
tp_size = self.fd_config.parallel_config.tensor_parallel_size
|
|
if tp_size <= 1 or not self._use_fused_qkv:
|
|
return
|
|
|
|
g_heads = int(self.text_config.num_attention_heads)
|
|
g_kv = int(getattr(self.text_config, "num_key_value_heads", g_heads))
|
|
local_heads = g_heads // tp_size
|
|
local_kv = max(1, g_kv // tp_size)
|
|
local_groups = local_heads // local_kv
|
|
|
|
patched = 0
|
|
for name, module in self.model.named_sublayers():
|
|
# PF attention modules store head counts as instance attrs used in forward reshape
|
|
if not hasattr(module, "num_key_value_groups"):
|
|
continue
|
|
module.num_heads = local_heads
|
|
module.num_key_value_heads = local_kv
|
|
module.num_key_value_groups = local_groups
|
|
patched += 1
|
|
|
|
if patched:
|
|
logger.info(
|
|
f"Localized {patched} PF attention modules: "
|
|
f"heads {g_heads}->{local_heads}, kv {g_kv}->{local_kv}, tp={tp_size}"
|
|
)
|
|
|
|
    def _get_tp_plan(self) -> dict[str, str]:
        """Get TP plan for linear layer replacement.

        Priority:
        1. Try to get from PaddleFormers model's _get_tensor_parallel_mappings classmethod
        2. Fall back to default naming-based rules

        Returns:
            Dict mapping regex patterns to style ("colwise", "rowwise", "replicate")
        """
        # Try to get TP mappings from PaddleFormers model class
        model_cls = type(self.model)
        if hasattr(model_cls, "_get_tensor_parallel_mappings"):
            try:
                # Call the classmethod with config
                mappings = model_cls._get_tensor_parallel_mappings(self.text_config, is_split=True)
                if mappings:
                    # Convert PaddleFormers mappings to our format
                    # mappings is like: {"model.layers.0.self_attn.q_proj.weight": partial(fn, is_column=True)}
                    # Extract layer name patterns and determine colwise/rowwise
                    colwise_layers = set()
                    rowwise_layers = set()

                    for key, func in mappings.items():
                        # Extract the layer suffix (e.g., "self_attn.q_proj.weight" -> "q_proj")
                        parts = key.split(".")
                        if len(parts) >= 2:
                            # Find the layer name (second to last before .weight/.bias)
                            for i, part in enumerate(parts):
                                if part.endswith("_proj") or part in (
                                    "up_proj",
                                    "gate_proj",
                                    "down_proj",
                                    "o_proj",
                                    "q_proj",
                                    "k_proj",
                                    "v_proj",
                                    "qkv_proj",
                                ):
                                    # Check is_column from partial func
                                    # (functools.partial exposes bound kwargs via .keywords)
                                    if hasattr(func, "keywords") and func.keywords.get("is_column", False):
                                        colwise_layers.add(part)
                                    else:
                                        rowwise_layers.add(part)

                    if colwise_layers or rowwise_layers:
                        # Handle QKV fusion: adjust layer names based on fusion setting
                        if self._use_fused_qkv:
                            # Using fused QKV: add qkv_proj, remove separate q/k/v_proj
                            colwise_layers.add("qkv_proj")
                            colwise_layers.discard("q_proj")
                            colwise_layers.discard("k_proj")
                            colwise_layers.discard("v_proj")
                        else:
                            # Not using fused QKV: ensure separate projections
                            colwise_layers.discard("qkv_proj")
                            colwise_layers.update(["q_proj", "k_proj", "v_proj"])

                        # Handle Gate+Up fusion: adjust layer names based on fusion setting
                        if self._use_fused_ffn:
                            # Using fused FFN: add up_gate_proj, remove separate gate/up_proj
                            colwise_layers.add("up_gate_proj")
                            colwise_layers.discard("gate_proj")
                            colwise_layers.discard("up_proj")
                        else:
                            # Not using fused FFN: ensure separate projections
                            colwise_layers.discard("up_gate_proj")
                            colwise_layers.update(["gate_proj", "up_proj"])

                        # Anchor each layer name at the end of the qualified
                        # module path (matches _get_linear_style's re.search).
                        converted_plan = {}
                        for layer in colwise_layers:
                            converted_plan[rf"\.{layer}$"] = "colwise"
                        for layer in rowwise_layers:
                            converted_plan[rf"\.{layer}$"] = "rowwise"
                        return converted_plan
            except Exception as e:
                # Best-effort: any failure falls back to the naming rules below.
                logger.warning(f"Failed to get PaddleFormers TP mappings: {e}, using default")

        # Default naming-based TP plan
        return {
            # Column Parallel (output dimension split)
            r"\.qkv_proj$": "colwise",  # Fused QKV projection
            r"\.up_gate_proj$": "colwise",  # Fused FFN projection
            r"\.q_proj$": "colwise",
            r"\.k_proj$": "colwise",
            r"\.v_proj$": "colwise",
            r"\.gate_proj$": "colwise",
            r"\.up_proj$": "colwise",
            # Row Parallel (input dimension split)
            r"\.o_proj$": "rowwise",
            r"\.down_proj$": "rowwise",
        }
|
|
|
|
def create_attention_instances(self) -> dict[int, Attention]:
|
|
"""Create FastDeploy attention instances for all layers.
|
|
|
|
These instances replace PaddleFormers' attention and are passed to model.forward().
|
|
For centralized deployment, create instances for all layers.
|
|
"""
|
|
num_layers = self.text_config.num_hidden_layers
|
|
|
|
layer_types = getattr(self.text_config, "layer_types", None)
|
|
sliding_window = getattr(self.text_config, "sliding_window", None)
|
|
|
|
if layer_types is None:
|
|
sliding_window_pattern = getattr(self.text_config, "sliding_window_pattern", None)
|
|
if sliding_window_pattern is not None and sliding_window is not None:
|
|
layer_types = [
|
|
"sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
|
|
for i in range(num_layers)
|
|
]
|
|
|
|
if layer_types is not None:
|
|
if not hasattr(self.fd_config.model_config, "layer_types"):
|
|
self.fd_config.model_config.layer_types = layer_types
|
|
if not hasattr(self.fd_config.model_config, "sliding_window") and sliding_window is not None:
|
|
self.fd_config.model_config.sliding_window = sliding_window
|
|
|
|
attention_instances = {}
|
|
for i in range(num_layers):
|
|
attention_instances[i] = Attention(
|
|
fd_config=self.fd_config,
|
|
layer_id=i,
|
|
)
|
|
|
|
return attention_instances
|
|
|
|
def embed_input_ids(self, input_ids: paddle.Tensor) -> paddle.Tensor:
|
|
"""Embed input_ids using the model's embedding layer."""
|
|
embedding_layer = self.model.get_input_embeddings()
|
|
inputs_embeds = embedding_layer(input_ids)
|
|
|
|
if hasattr(self, "embed_scale") and self.embed_scale is not None:
|
|
inputs_embeds *= self.embed_scale
|
|
return inputs_embeds
|
|
|
|
    @paddle.no_grad()
    def forward(
        self,
        ids_remove_padding: paddle.Tensor,
        forward_meta: ForwardMeta,
        **kwargs,
    ):
        """Full transformer forward: input_ids -> hidden_states.

        This method is the primary forward pass for the model, computing:
        1. Position IDs based on seq_lens_decoder (absolute positions for RoPE)
        2. Token embeddings via embed_input_ids
        3. Transformer layers via self.model()

        Args:
            ids_remove_padding: Flattened token ids with padding removed,
                shape [TotalTokens].
            forward_meta: Per-step scheduling metadata (batch mapping,
                decoder offsets, cumulative query lengths).
            **kwargs: Forwarded verbatim to ``self.model``.

        Returns:
            hidden_states: [TotalTokens, HiddenDim]
        """
        num_tokens = ids_remove_padding.shape[0]

        batch_id_per_token = forward_meta.batch_id_per_token  # [num_tokens]
        seq_lens_decoder = forward_meta.seq_lens_decoder  # [batch_size, 1]

        if batch_id_per_token is not None and seq_lens_decoder is not None:
            # Absolute position = tokens already decoded for the token's
            # request (decoder offset) + the token's offset inside the
            # current chunk.
            decoder_offsets = seq_lens_decoder.squeeze(-1)  # [batch_size]
            token_decoder_offsets = paddle.index_select(decoder_offsets, batch_id_per_token, axis=0)  # [num_tokens]

            cu_seqlens = forward_meta.cu_seqlens_q  # [batch_size + 1]
            if cu_seqlens is not None:
                # In-chunk offset: global token index minus the start index
                # of the token's request (gathered via batch_id_per_token).
                token_global_idx = paddle.arange(num_tokens, dtype="int64")
                request_start_idx = paddle.index_select(cu_seqlens[:-1], batch_id_per_token, axis=0)
                relative_positions = token_global_idx - request_start_idx.astype("int64")
            else:
                # Without cu_seqlens the chunk starts are unknown; fall back
                # to the decoder offsets alone (presumably the one-token-per-
                # request decode case — confirm against the scheduler).
                relative_positions = paddle.zeros([num_tokens], dtype="int64")
            position_ids = token_decoder_offsets.astype("int64") + relative_positions
        else:
            # No batch mapping: treat the input as one contiguous sequence,
            # shifted by the first request's decoder offset when available.
            position_ids = paddle.arange(num_tokens, dtype="int64")
            if seq_lens_decoder is not None:
                position_ids = position_ids + seq_lens_decoder[0, 0].astype("int64")

        # The PaddleFormers model expects a leading batch dimension.
        inputs_embeds = self.embed_input_ids(ids_remove_padding).unsqueeze(0)

        if getattr(self.text_config, "uses_mrope", False):
            # NOTE(review): mrope path apparently expects [num_tokens, 1]
            # positions rather than [1, num_tokens] — verify against the
            # consuming rotary implementation.
            position_ids = position_ids.unsqueeze(1)
        else:
            position_ids = position_ids.unsqueeze(0)

        # RoPE is applied inside the PaddleFormers layers; flag it so the
        # FastDeploy attention backend does not apply it a second time.
        forward_meta.rope_already_applied = True
        # Stash forward_meta on the config so inner layers can reach it.
        self.paddleformers_config.forward_meta = forward_meta

        model_output = self.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            return_dict=False,
            **kwargs,
        )

        hidden_states = model_output[0][0, ...]  # Remove batch dim

        return hidden_states
    @paddle.no_grad()
    def load_weights(self, weights: Iterable[tuple[str, paddle.Tensor]]):
        """Load weights from checkpoint into model parameters.

        Handles three complications of the PaddleFormers fallback path:
        checkpoint name aliases (``model.`` / model-type prefixes), stacked
        parameters (fused FFN gate/up, embedding/lm_head layout differences),
        and fused QKV loading from either split ``q/k/v_proj`` shards or a
        pre-fused ``qkv_proj`` tensor.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            RuntimeError: If a fused QKV parameter received only a subset of
                its q/k/v shards by the end of loading.
        """
        from fastdeploy.model_executor.utils import (
            default_weight_loader,
            process_weights_after_loading,
        )

        sublayers_dict = dict(self.named_sublayers())
        # process_fn finalizes a layer (post-load processing such as
        # repacking) once all of its shards have been written.
        process_fn = process_weights_after_loading(sublayers_dict, self.fd_config)
        params_dict = dict(self.named_parameters())

        # === Checkpoint prefix alias handling ===
        # Checkpoints may prefix names with the model type (spelled with
        # hyphens, underscores, or neither); accept all three variants.
        model_type = str(getattr(self.paddleformers_config, "model_type", "") or "").lower()
        ckpt_prefix_aliases = {model_type, model_type.replace("-", "_"), model_type.replace("_", "")} - {""}
        # Substrings that mark a name as a transformer weight, so unknown
        # leading prefixes on such names can be treated as aliases.
        ckpt_alias_markers = (".layers.", ".embed_tokens.", ".lm_head.", ".norm.", ".final_layernorm.", ".rotary_emb.")

        def resolve_param_name(weight_name: str) -> str | None:
            # Map a checkpoint weight name to the matching parameter name in
            # params_dict, or return None when no candidate matches.
            # Collect prefix aliases dynamically.
            if "." in weight_name:
                prefix = weight_name.split(".", 1)[0]
                if prefix not in {"model", "lm_head"} and any(m in weight_name for m in ckpt_alias_markers):
                    ckpt_prefix_aliases.add(prefix)

            # Generate candidate parameter names.
            candidates = [weight_name]
            candidates.append(weight_name[6:] if weight_name.startswith("model.") else "model." + weight_name)
            if "." in weight_name:
                prefix, rest = weight_name.split(".", 1)
                if prefix in ckpt_prefix_aliases:
                    candidates.extend([rest, "model." + rest])

            return next((c for c in candidates if c in params_dict), None)

        # === Stacked parameter mapping config ===
        # Triples of (parameter substring, checkpoint substring, shard_id).
        stacked_params_mapping = [
            ("embed_tokens.embeddings", "embed_tokens", None),
            ("lm_head.linear", "lm_head", None),
        ]
        if self._use_fused_ffn:
            # Both gate_proj and up_proj checkpoint shards land in the fused
            # up_gate_proj parameter, distinguished by shard_id.
            stacked_params_mapping += [("up_gate_proj", "gate_proj", "gate"), ("up_gate_proj", "up_proj", "up")]

        # === QKV loading helpers ===
        # Layer keys that received at least one split q/k/v shard; used to
        # suppress a duplicate pre-fused qkv_proj tensor for the same layer.
        qkv_split_layers: set[str] = set()
        # Pre-fused qkv_proj tensors, deferred until after the main loop so
        # split shards (when present) take precedence.
        qkv_direct_pending: dict[tuple[str, bool], tuple[str, paddle.Tensor]] = {}

        def parse_qkv_shard_name(name: str) -> tuple[str, str, str] | None:
            # Recognize split q/k/v checkpoint names. Returns
            # (layer_key, shard_id, fused_qkv_param_name) or None.
            shard_suffixes = (
                (".q_proj.weight", "q"),
                (".k_proj.weight", "k"),
                (".v_proj.weight", "v"),
                (".q_proj.bias", "q"),
                (".k_proj.bias", "k"),
                (".v_proj.bias", "v"),
            )
            for suffix, shard_id in shard_suffixes:
                if name.endswith(suffix):
                    layer_key = name.replace(suffix, "")
                    qkv_param_name = name.replace(".q_proj.", ".qkv_proj.")
                    qkv_param_name = qkv_param_name.replace(".k_proj.", ".qkv_proj.")
                    qkv_param_name = qkv_param_name.replace(".v_proj.", ".qkv_proj.")
                    return layer_key, shard_id, qkv_param_name
            return None

        def parse_direct_qkv_name(name: str) -> tuple[str, bool] | None:
            # Recognize pre-fused qkv_proj names. Returns (layer_key, is_bias)
            # or None.
            if name.endswith(".qkv_proj.weight"):
                return name.replace(".qkv_proj.weight", ""), False
            if name.endswith(".qkv_proj.bias"):
                return name.replace(".qkv_proj.bias", ""), True
            return None

        # === Helper functions ===
        def load_param(name: str, tensor: paddle.Tensor, shard_id=None):
            # Write one tensor into the named parameter via its weight_loader
            # and run post-load processing once the parameter is complete.
            # Returns False when a q/k/v shard left the fused parameter
            # still waiting for its sibling shards.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
            weight_loader(param, tensor, shard_id)
            if shard_id in {"q", "k", "v"} and bool(getattr(param, "_pf_qkv_pending", False)):
                return False
            process_fn(re.sub(r"\.(weight|bias)$", "", name), param)
            return True

        # === Main loading loop ===
        loaded_count = skipped_count = 0

        for weight_name, weight in weights:
            # 1. Handle fused QKV path in a unified split-shard style.
            if self._use_fused_qkv:
                if qkv_info := parse_qkv_shard_name(weight_name):
                    layer_key, proj_type, qkv_param_name = qkv_info
                    qkv_split_layers.add(layer_key)
                    resolved = resolve_param_name(qkv_param_name)
                    if resolved:
                        try:
                            load_param(resolved, weight, shard_id=proj_type)
                            loaded_count += 1
                        except Exception as e:
                            logger.warning(f"Failed to load qkv shard {weight_name} -> {resolved}: {e}")
                            skipped_count += 1
                    else:
                        logger.warning(f"QKV shard mapping not found: {weight_name} -> {qkv_param_name}")
                        skipped_count += 1
                    continue

                if direct_qkv_info := parse_direct_qkv_name(weight_name):
                    layer_key, is_bias = direct_qkv_info
                    # Defer: split shards for the same layer may still arrive.
                    qkv_direct_pending[(layer_key, is_bias)] = (weight_name, weight)
                    continue

            # 2. Stacked params mapping
            for param_name, src_name, shard_id in stacked_params_mapping:
                if src_name in weight_name:
                    resolved = resolve_param_name(weight_name.replace(src_name, param_name))
                    if resolved:
                        load_param(resolved, weight, shard_id)
                        loaded_count += 1
                    else:
                        logger.warning(f"Stacked mapping: {weight_name} -> NOT FOUND")
                    break
            else:
                # 3. Direct load (no stacked mapping matched).
                resolved = resolve_param_name(weight_name)
                if resolved:
                    try:
                        load_param(resolved, weight)
                        loaded_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to load {resolved}: {e}")
                        skipped_count += 1
                else:
                    skipped_count += 1

        # 4. Handle direct qkv_proj.* only when split q/k/v is absent for that layer.
        if self._use_fused_qkv and qkv_direct_pending:
            for (layer_key, is_bias), (weight_name, weight) in qkv_direct_pending.items():
                if layer_key in qkv_split_layers:
                    logger.info(
                        f"Skip direct qkv {'bias' if is_bias else 'weight'} for {layer_key}: "
                        "split q/k/v shards are present."
                    )
                    continue

                resolved = resolve_param_name(weight_name)
                if resolved:
                    try:
                        load_param(resolved, weight)
                        loaded_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to load direct fused qkv {weight_name} -> {resolved}: {e}")
                        skipped_count += 1
                else:
                    logger.warning(f"Direct fused qkv param not found: {weight_name}")
                    skipped_count += 1

        # Fail loudly if any fused QKV parameter saw some but not all of its
        # q/k/v shards (the weight would otherwise be partially initialized).
        if self._use_fused_qkv:
            pending_qkv_params = [
                name for name, param in params_dict.items() if bool(getattr(param, "_pf_qkv_pending", False))
            ]
            if pending_qkv_params:
                raise RuntimeError(
                    "Incomplete QKV shard loading detected for parameters: " + ", ".join(sorted(pending_qkv_params))
                )

        logger.info(f"Weight loading: {loaded_count} loaded, {skipped_count} skipped")

        # === tie_word_embeddings handling ===
        # Copy (transposed) embedding weights into lm_head when weights are
        # tied, since the checkpoint carries only one of the two tensors.
        if hasattr(self, "lm_head") and getattr(self, "tie_word_embeddings", False):
            embed = self.model.get_input_embeddings()
            if hasattr(embed, "embeddings") and hasattr(embed.embeddings, "weight"):
                self.lm_head.linear.weight.set_value(embed.embeddings.weight.T)
            else:
                logger.warning("tie_word_embeddings=True but embed_tokens.embeddings.weight not found!")