mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
fix paddleformers fallback (#6465)
This commit is contained in:
@@ -33,6 +33,7 @@ from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
@@ -108,37 +109,85 @@ def fastdeploy_append_attention_forward(
|
||||
if scaling is not None:
|
||||
self_attn.scale = float(scaling)
|
||||
|
||||
# query shape is either [1, H, S, D] or [S, H, D]
|
||||
seq_len = query.shape[-2] if query.ndim == 4 else query.shape[0]
|
||||
# 统一获取 heads 信息
|
||||
num_heads = (
|
||||
getattr(module, "num_heads", None)
|
||||
or getattr(config, "num_attention_heads", None)
|
||||
or getattr(self_attn, "num_heads", None)
|
||||
)
|
||||
num_kv_heads = (
|
||||
getattr(module, "num_key_value_heads", None)
|
||||
or getattr(config, "num_key_value_heads", None)
|
||||
or getattr(self_attn, "num_key_value_heads", None)
|
||||
or getattr(self_attn, "kv_num_heads", None)
|
||||
or num_heads
|
||||
)
|
||||
num_heads = int(num_heads) if num_heads is not None else None
|
||||
num_kv_heads = int(num_kv_heads) if num_kv_heads is not None else None
|
||||
|
||||
def flatten_to_sd(t: paddle.Tensor, name: str) -> paddle.Tensor:
|
||||
"""[B, H, S, D] -> [S, H*D] for FD attention"""
|
||||
# 仅支持 3D(HSD/SHD) 或 4D(BHSD/BSHD, 且 B=1) 输入
|
||||
def squeeze_to_3d(t: paddle.Tensor, name: str) -> paddle.Tensor:
|
||||
if t.ndim == 4:
|
||||
if int(t.shape[0]) != 1:
|
||||
raise ValueError(f"{name} batch size {int(t.shape[0])} not supported")
|
||||
return t.squeeze(0)
|
||||
if t.ndim == 3:
|
||||
return t.reshape([t.shape[0], -1])
|
||||
if t.ndim != 4:
|
||||
raise ValueError(f"{name} has unexpected dims {t.ndim}, expect 3 or 4")
|
||||
return t
|
||||
raise ValueError(f"{name} has unexpected dims {t.ndim}, expect 3 or 4")
|
||||
|
||||
batch, dim1, dim2, dim3 = t.shape
|
||||
if batch != 1:
|
||||
raise ValueError(f"{name} batch size {batch} not supported")
|
||||
q = squeeze_to_3d(query, "query")
|
||||
k = squeeze_to_3d(key, "key")
|
||||
v = squeeze_to_3d(value, "value")
|
||||
|
||||
squeezed = t.squeeze(0) # [dim1, dim2, dim3]
|
||||
def heads_match(actual_heads: int, expected_heads: int | None) -> bool:
|
||||
if expected_heads is None:
|
||||
return False
|
||||
if actual_heads == expected_heads:
|
||||
return True
|
||||
return actual_heads > 0 and expected_heads % actual_heads == 0
|
||||
|
||||
if dim2 == seq_len:
|
||||
# [H, S, D] -> transpose to [S, H, D] -> reshape [S, H*D]
|
||||
return squeezed.transpose([1, 0, 2]).reshape([seq_len, -1])
|
||||
elif dim1 == seq_len:
|
||||
# [S, H, D] -> reshape [S, H*D]
|
||||
return squeezed.reshape([seq_len, -1])
|
||||
else:
|
||||
# Fallback: assume [H, S, D] format
|
||||
return squeezed.transpose([1, 0, 2]).reshape([seq_len, -1])
|
||||
# 使用 Q/K 共同判断布局;歧义时默认 hsd(兼容 Paddle 常见路径)
|
||||
is_hsd = (
|
||||
heads_match(int(q.shape[0]), num_heads)
|
||||
and heads_match(int(k.shape[0]), num_kv_heads)
|
||||
and heads_match(int(v.shape[0]), num_kv_heads)
|
||||
)
|
||||
is_shd = (
|
||||
heads_match(int(q.shape[1]), num_heads)
|
||||
and heads_match(int(k.shape[1]), num_kv_heads)
|
||||
and heads_match(int(v.shape[1]), num_kv_heads)
|
||||
)
|
||||
|
||||
if is_hsd:
|
||||
q_flat = q.transpose([1, 0, 2]).reshape([int(q.shape[1]), -1])
|
||||
k_flat = k.transpose([1, 0, 2]).reshape([int(k.shape[1]), -1])
|
||||
v_flat = v.transpose([1, 0, 2]).reshape([int(v.shape[1]), -1])
|
||||
elif is_shd:
|
||||
q_flat = q.reshape([int(q.shape[0]), -1])
|
||||
k_flat = k.reshape([int(k.shape[0]), -1])
|
||||
v_flat = v.reshape([int(v.shape[0]), -1])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid attention layout: q={list(q.shape)}, k={list(k.shape)}, v={list(v.shape)}, "
|
||||
f"heads={num_heads}/{num_kv_heads}"
|
||||
)
|
||||
|
||||
# Q/K/V flatten 后序列长度必须一致
|
||||
q_seq, k_seq, v_seq = int(q_flat.shape[0]), int(k_flat.shape[0]), int(v_flat.shape[0])
|
||||
if not (q_seq == k_seq == v_seq):
|
||||
raise ValueError(
|
||||
f"Sequence length mismatch after flattening: Q={q_seq}, K={k_seq}, V={v_seq}, "
|
||||
f"raw query={list(query.shape)}, key={list(key.shape)}, value={list(value.shape)}."
|
||||
)
|
||||
|
||||
# 若 forward_meta 带了 ids_remove_padding,则强校验 Q 序列长度
|
||||
ids_remove_padding = getattr(forward_meta, "ids_remove_padding", None)
|
||||
if ids_remove_padding is not None:
|
||||
expected_seq = int(ids_remove_padding.shape[0])
|
||||
if q_seq != expected_seq:
|
||||
raise ValueError(f"Seq len mismatch: got {q_seq}, expect {expected_seq}")
|
||||
|
||||
q_flat = flatten_to_sd(query, "query")
|
||||
k_flat = flatten_to_sd(key, "key")
|
||||
v_flat = flatten_to_sd(value, "value")
|
||||
qkv = paddle.concat([q_flat, k_flat, v_flat], axis=-1)
|
||||
|
||||
output = self_attn.forward(qkv=qkv, forward_meta=forward_meta)
|
||||
|
||||
return output, None
|
||||
@@ -599,11 +648,7 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
|
||||
@paddle.no_grad()
|
||||
def load_weights(self, weights: Iterable[tuple[str, paddle.Tensor]]):
|
||||
"""Load weights from checkpoint into model parameters.
|
||||
|
||||
Using FD native pattern: iterate weights and use param.weight_loader()
|
||||
for each FD layer (handles shape conversion automatically).
|
||||
"""
|
||||
"""Load weights from checkpoint into model parameters."""
|
||||
from fastdeploy.model_executor.utils import (
|
||||
default_weight_loader,
|
||||
process_weights_after_loading,
|
||||
@@ -613,194 +658,176 @@ class PaddleFormersModelBase(nn.Layer):
|
||||
process_fn = process_weights_after_loading(sublayers_dict, self.fd_config)
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
# Weight name mapping: HF name -> FD param name + shard_id
|
||||
# === 前缀别名处理 ===
|
||||
model_type = str(getattr(self.paddleformers_config, "model_type", "") or "").lower()
|
||||
ckpt_prefix_aliases = {model_type, model_type.replace("-", "_"), model_type.replace("_", "")} - {""}
|
||||
ckpt_alias_markers = (".layers.", ".embed_tokens.", ".lm_head.", ".norm.", ".final_layernorm.", ".rotary_emb.")
|
||||
|
||||
def resolve_param_name(weight_name: str) -> str | None:
|
||||
# 动态收集前缀别名
|
||||
if "." in weight_name:
|
||||
prefix = weight_name.split(".", 1)[0]
|
||||
if prefix not in {"model", "lm_head"} and any(m in weight_name for m in ckpt_alias_markers):
|
||||
ckpt_prefix_aliases.add(prefix)
|
||||
|
||||
# 生成候选名称
|
||||
candidates = [weight_name]
|
||||
candidates.append(weight_name[6:] if weight_name.startswith("model.") else "model." + weight_name)
|
||||
if "." in weight_name:
|
||||
prefix, rest = weight_name.split(".", 1)
|
||||
if prefix in ckpt_prefix_aliases:
|
||||
candidates.extend([rest, "model." + rest])
|
||||
|
||||
return next((c for c in candidates if c in params_dict), None)
|
||||
|
||||
# === 权重映射配置 ===
|
||||
stacked_params_mapping = [
|
||||
# Embeddings and lm_head (same as native)
|
||||
("embed_tokens.embeddings", "embed_tokens", None),
|
||||
("lm_head.linear", "lm_head", None),
|
||||
]
|
||||
|
||||
# Add gate+up fusion mapping if enabled
|
||||
if self._use_fused_ffn:
|
||||
stacked_params_mapping.extend(
|
||||
[
|
||||
("up_gate_proj", "gate_proj", "gate"),
|
||||
("up_gate_proj", "up_proj", "up"),
|
||||
]
|
||||
stacked_params_mapping += [("up_gate_proj", "gate_proj", "gate"), ("up_gate_proj", "up_proj", "up")]
|
||||
|
||||
# === QKV 融合相关 ===
|
||||
mc = self.fd_config.model_config
|
||||
model_format = str(getattr(mc, "model_format", "") or "").lower()
|
||||
qkv_buffer, qkv_bias_buffer = {}, {}
|
||||
|
||||
def parse_qkv_name(name: str) -> tuple[str, str, str] | None:
|
||||
for proj, ptype in [("q_proj", "q"), ("k_proj", "k"), ("v_proj", "v")]:
|
||||
if proj in name:
|
||||
layer_key = name.replace(f".{proj}.weight", "").replace(f".{proj}.bias", "")
|
||||
return layer_key, ptype, name.replace(proj, "qkv_proj")
|
||||
return None
|
||||
|
||||
def fuse_qkv(q, k, v, is_bias: bool) -> paddle.Tensor:
|
||||
num_heads, num_kv_heads, head_dim = mc.num_attention_heads, mc.num_key_value_heads, mc.head_dim
|
||||
hidden_size, num_kv_groups = mc.hidden_size, num_heads // num_kv_heads
|
||||
q_out, kv_out = num_heads * head_dim, num_kv_heads * head_dim
|
||||
|
||||
if is_bias:
|
||||
# 校验 bias 维度和形状
|
||||
if q.ndim != 1 or k.ndim != 1 or v.ndim != 1:
|
||||
raise ValueError(f"Unexpected qkv bias dims: q={q.shape}, k={k.shape}, v={v.shape}; expected 1D.")
|
||||
if q.shape[0] != q_out or k.shape[0] != kv_out or v.shape[0] != kv_out:
|
||||
raise ValueError(
|
||||
f"Unexpected qkv bias shapes: q={q.shape}, k={k.shape}, v={v.shape}; "
|
||||
f"expected q=[{q_out}], k/v=[{kv_out}]."
|
||||
)
|
||||
return paddle.concat(
|
||||
[
|
||||
q.reshape([num_kv_heads, num_kv_groups, head_dim]),
|
||||
k.reshape([num_kv_heads, 1, head_dim]),
|
||||
v.reshape([num_kv_heads, 1, head_dim]),
|
||||
],
|
||||
axis=1,
|
||||
).reshape([-1])
|
||||
|
||||
# 校验 weight 形状和 model_format
|
||||
q_shape, k_shape, v_shape = [int(x) for x in q.shape], [int(x) for x in k.shape], [int(x) for x in v.shape]
|
||||
torch_layout = (
|
||||
q_shape == [q_out, hidden_size]
|
||||
and k_shape == [kv_out, hidden_size]
|
||||
and v_shape == [kv_out, hidden_size]
|
||||
)
|
||||
paddle_layout = (
|
||||
q_shape == [hidden_size, q_out]
|
||||
and k_shape == [hidden_size, kv_out]
|
||||
and v_shape == [hidden_size, kv_out]
|
||||
)
|
||||
|
||||
loaded_count = 0
|
||||
skipped_count = 0
|
||||
if model_format == "torch":
|
||||
if not torch_layout:
|
||||
raise ValueError(
|
||||
f"model_format=torch requires torch layout, got q={q_shape}, k={k_shape}, v={v_shape}."
|
||||
)
|
||||
q, k, v = q.T, k.T, v.T
|
||||
elif model_format == "paddle":
|
||||
if not paddle_layout:
|
||||
raise ValueError(
|
||||
f"model_format=paddle requires paddle layout, got q={q_shape}, k={k_shape}, v={v_shape}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_format: {model_format}. Expect 'torch' or 'paddle'.")
|
||||
|
||||
# QKV weight fusion buffer for fused QKV mode
|
||||
# Collect q/k/v weights per layer and fuse them into PaddleFormers' per-KV-head interleaved format
|
||||
qkv_buffer = {} # layer_key -> {"q": weight, "k": weight, "v": weight}
|
||||
|
||||
def parse_qkv_weight_name(weight_name):
|
||||
"""Parse q/k/v_proj weight name to extract layer key and proj type."""
|
||||
for proj, proj_type in [("q_proj", "q"), ("k_proj", "k"), ("v_proj", "v")]:
|
||||
if proj in weight_name:
|
||||
# Extract layer key (e.g., "model.layers.0.self_attn")
|
||||
layer_key = weight_name.replace(f".{proj}.weight", "")
|
||||
layer_key = layer_key.replace(f".{proj}.bias", "")
|
||||
qkv_param_name = weight_name.replace(proj, "qkv_proj")
|
||||
return layer_key, proj_type, qkv_param_name
|
||||
return None, None, None
|
||||
|
||||
def fuse_qkv_weights_for_paddleformers(q_weight, k_weight, v_weight):
|
||||
"""Fuse q/k/v weights to PaddleFormers' per-KV-head interleaved format.
|
||||
|
||||
PaddleFormers format: [Q_group0|K0|V0 | Q_group1|K1|V1 | ...]
|
||||
where Q_group has (num_heads // num_kv_heads) heads.
|
||||
|
||||
Note: Checkpoint weights may be stored as [out, in] (transposed) or [in, out].
|
||||
We detect this by checking dimensions against config.
|
||||
|
||||
Args:
|
||||
q_weight: [hidden_size, num_heads * head_dim] or transposed
|
||||
k_weight: [hidden_size, num_kv_heads * head_dim] or transposed
|
||||
v_weight: [hidden_size, num_kv_heads * head_dim] or transposed
|
||||
|
||||
Returns:
|
||||
fused_weight: [hidden_size, num_kv_heads * (num_kv_groups + 2) * head_dim]
|
||||
"""
|
||||
mc = self.fd_config.model_config
|
||||
hidden_size = mc.hidden_size
|
||||
num_heads = mc.num_attention_heads
|
||||
num_kv_heads = mc.num_key_value_heads
|
||||
head_dim = mc.head_dim
|
||||
num_kv_groups = num_heads // num_kv_heads
|
||||
|
||||
q_expected_out = num_heads * head_dim
|
||||
|
||||
# Detect and handle transposed weights (safetensors often stores [out, in])
|
||||
if q_weight.shape[0] == q_expected_out and q_weight.shape[1] == hidden_size:
|
||||
q_weight = q_weight.T
|
||||
k_weight = k_weight.T
|
||||
v_weight = v_weight.T
|
||||
elif q_weight.shape[0] != hidden_size or q_weight.shape[1] != q_expected_out:
|
||||
# 转置后校验
|
||||
if q.shape[0] != hidden_size or k.shape[0] != hidden_size or v.shape[0] != hidden_size:
|
||||
raise ValueError(
|
||||
f"Unexpected q_weight shape {q_weight.shape}, expected [{hidden_size}, {q_expected_out}] or [{q_expected_out}, {hidden_size}]"
|
||||
f"QKV shape mismatch after normalization: q={list(q.shape)}, k={list(k.shape)}, v={list(v.shape)}."
|
||||
)
|
||||
|
||||
# Reshape for GQA interleaving: Q [hidden, num_kv_heads, num_kv_groups, head_dim]
|
||||
q_reshaped = q_weight.reshape([hidden_size, num_kv_heads, num_kv_groups, head_dim])
|
||||
k_reshaped = k_weight.reshape([hidden_size, num_kv_heads, 1, head_dim])
|
||||
v_reshaped = v_weight.reshape([hidden_size, num_kv_heads, 1, head_dim])
|
||||
fused = paddle.concat(
|
||||
[
|
||||
q.reshape([hidden_size, num_kv_heads, num_kv_groups, head_dim]),
|
||||
k.reshape([hidden_size, num_kv_heads, 1, head_dim]),
|
||||
v.reshape([hidden_size, num_kv_heads, 1, head_dim]),
|
||||
],
|
||||
axis=2,
|
||||
).reshape([hidden_size, -1])
|
||||
|
||||
# Interleave to PaddleFormers format: [Q_group|K|V] per KV head
|
||||
fused = paddle.concat([q_reshaped, k_reshaped, v_reshaped], axis=2)
|
||||
fused = fused.reshape([hidden_size, -1])
|
||||
return fused
|
||||
return fused.T if model_format == "torch" else fused
|
||||
|
||||
for loaded_weight_name, loaded_weight in weights:
|
||||
# Handle QKV weight loading: collect q/k/v and fuse when all 3 are ready
|
||||
# Only when fused QKV is enabled for this model
|
||||
if self._use_fused_qkv:
|
||||
layer_key, proj_type, qkv_param_name = parse_qkv_weight_name(loaded_weight_name)
|
||||
if layer_key is not None and ".weight" in loaded_weight_name:
|
||||
# Collect this weight
|
||||
if layer_key not in qkv_buffer:
|
||||
qkv_buffer[layer_key] = {}
|
||||
qkv_buffer[layer_key][proj_type] = loaded_weight
|
||||
|
||||
# Check if all 3 (q, k, v) are collected
|
||||
if len(qkv_buffer[layer_key]) == 3:
|
||||
# Fuse and load
|
||||
fused_weight = fuse_qkv_weights_for_paddleformers(
|
||||
qkv_buffer[layer_key]["q"], qkv_buffer[layer_key]["k"], qkv_buffer[layer_key]["v"]
|
||||
)
|
||||
|
||||
# Find qkv_proj param
|
||||
if qkv_param_name not in params_dict:
|
||||
if "model." + qkv_param_name in params_dict:
|
||||
qkv_param_name = "model." + qkv_param_name
|
||||
elif qkv_param_name.startswith("model.") and qkv_param_name[6:] in params_dict:
|
||||
qkv_param_name = qkv_param_name[6:]
|
||||
|
||||
if qkv_param_name in params_dict:
|
||||
param = params_dict[qkv_param_name]
|
||||
|
||||
# Check if param is in torch format [out, in] vs paddle format [in, out]
|
||||
# Fused weight is [in=hidden, out=qkv_out], transpose if param is [out, in]
|
||||
if param.shape[0] != fused_weight.shape[0]:
|
||||
fused_weight = fused_weight.T
|
||||
|
||||
# Disable weight_need_transpose since we've already handled transpose
|
||||
if hasattr(param, "weight_need_transpose"):
|
||||
param.weight_need_transpose = False
|
||||
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
|
||||
weight_loader(param, fused_weight, None) # No shard_id, full fused weight
|
||||
|
||||
# Post-process the loaded weight (for torch format transpose, quantization, etc.)
|
||||
model_sublayer_name = re.sub(r"\.(weight|bias)$", "", qkv_param_name)
|
||||
process_fn(model_sublayer_name, param)
|
||||
|
||||
loaded_count += 3 # Count all 3
|
||||
else:
|
||||
logger.warning(f" QKV param {qkv_param_name} not found in params_dict")
|
||||
skipped_count += 3
|
||||
|
||||
# Clear buffer for this layer
|
||||
del qkv_buffer[layer_key]
|
||||
continue
|
||||
|
||||
# Try stacked params mapping first
|
||||
matched = False
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in loaded_weight_name:
|
||||
continue
|
||||
model_param_name = loaded_weight_name.replace(weight_name, param_name)
|
||||
if model_param_name not in params_dict:
|
||||
logger.warning(
|
||||
f" Stacked mapping: {loaded_weight_name} -> {model_param_name} NOT FOUND in params_dict!"
|
||||
)
|
||||
continue
|
||||
|
||||
param = params_dict[model_param_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
loaded_count += 1
|
||||
matched = True
|
||||
break
|
||||
|
||||
if matched:
|
||||
continue
|
||||
|
||||
# Direct mapping with "model." prefix normalization
|
||||
model_param_name = loaded_weight_name
|
||||
if model_param_name not in params_dict:
|
||||
model_param_name = "model." + loaded_weight_name
|
||||
if model_param_name not in params_dict and loaded_weight_name.startswith("model."):
|
||||
model_param_name = loaded_weight_name[6:]
|
||||
|
||||
if model_param_name not in params_dict:
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
param = params_dict[model_param_name]
|
||||
# === 辅助函数 ===
|
||||
def load_param(name: str, tensor: paddle.Tensor, shard_id=None, no_transpose: bool = False):
|
||||
param = params_dict[name]
|
||||
if no_transpose and hasattr(param, "weight_need_transpose"):
|
||||
param.weight_need_transpose = False
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
|
||||
weight_loader(param, tensor, shard_id)
|
||||
process_fn(re.sub(r"\.(weight|bias)$", "", name), param)
|
||||
|
||||
try:
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load {model_param_name}: {e}")
|
||||
skipped_count += 1
|
||||
# === 主循环 ===
|
||||
loaded_count = skipped_count = 0
|
||||
|
||||
# Post-process (for quantization etc)
|
||||
model_sublayer_name = re.sub(r"\.(weight|bias)$", "", model_param_name)
|
||||
process_fn(model_sublayer_name, param)
|
||||
for weight_name, weight in weights:
|
||||
# 1. QKV 融合处理
|
||||
if self._use_fused_qkv and (qkv_info := parse_qkv_name(weight_name)):
|
||||
layer_key, proj_type, qkv_param_name = qkv_info
|
||||
is_bias = ".bias" in weight_name
|
||||
buf = qkv_bias_buffer if is_bias else qkv_buffer
|
||||
buf.setdefault(layer_key, {})[proj_type] = weight
|
||||
|
||||
logger.info(f"Weight loading completed: {loaded_count} loaded, {skipped_count} skipped")
|
||||
if len(buf[layer_key]) == 3:
|
||||
resolved = resolve_param_name(qkv_param_name)
|
||||
if resolved:
|
||||
fused = fuse_qkv(buf[layer_key]["q"], buf[layer_key]["k"], buf[layer_key]["v"], is_bias)
|
||||
load_param(resolved, fused, no_transpose=not is_bias)
|
||||
loaded_count += 3
|
||||
else:
|
||||
logger.warning(f"QKV {'bias ' if is_bias else ''}param {qkv_param_name} not found")
|
||||
skipped_count += 3
|
||||
del buf[layer_key]
|
||||
continue
|
||||
|
||||
if hasattr(self, "lm_head"):
|
||||
if hasattr(self, "tie_word_embeddings") and self.tie_word_embeddings:
|
||||
embed_weight = self.model.get_input_embeddings()
|
||||
if hasattr(embed_weight, "embeddings") and hasattr(embed_weight.embeddings, "weight"):
|
||||
embed_tensor = embed_weight.embeddings.weight
|
||||
lm_head_weight = embed_tensor.T
|
||||
self.lm_head.linear.weight.set_value(lm_head_weight)
|
||||
# 2. Stacked params mapping
|
||||
for param_name, src_name, shard_id in stacked_params_mapping:
|
||||
if src_name in weight_name:
|
||||
resolved = resolve_param_name(weight_name.replace(src_name, param_name))
|
||||
if resolved:
|
||||
load_param(resolved, weight, shard_id)
|
||||
loaded_count += 1
|
||||
else:
|
||||
logger.warning(f"Stacked mapping: {weight_name} -> NOT FOUND")
|
||||
break
|
||||
else:
|
||||
# 3. 直接加载
|
||||
resolved = resolve_param_name(weight_name)
|
||||
if resolved:
|
||||
try:
|
||||
load_param(resolved, weight)
|
||||
loaded_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load {resolved}: {e}")
|
||||
skipped_count += 1
|
||||
else:
|
||||
logger.warning("tie_word_embeddings=True but embed_tokens.embeddings.weight not found!")
|
||||
skipped_count += 1
|
||||
|
||||
logger.info(f"Weight loading: {loaded_count} loaded, {skipped_count} skipped")
|
||||
|
||||
# === tie_word_embeddings 处理 ===
|
||||
if hasattr(self, "lm_head") and getattr(self, "tie_word_embeddings", False):
|
||||
embed = self.model.get_input_embeddings()
|
||||
if hasattr(embed, "embeddings") and hasattr(embed.embeddings, "weight"):
|
||||
self.lm_head.linear.weight.set_value(embed.embeddings.weight.T)
|
||||
else:
|
||||
logger.warning("tie_word_embeddings=True but embed_tokens.embeddings.weight not found!")
|
||||
|
||||
Reference in New Issue
Block a user