refactor pt loading (#4532)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

This commit is contained in:
bukejiyu
2025-11-11 21:30:39 +08:00
committed by GitHub
parent 4c911ecb74
commit b09ebb2813
35 changed files with 1094 additions and 797 deletions
+54 -26
View File
@@ -25,6 +25,8 @@ from fastdeploy.distributed.communication import tensor_model_parallel_all_reduc
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.utils import (
default_weight_loader,
h2d_copy,
process_weight_transpose,
set_weight_attrs,
slice_fn,
)
@@ -43,8 +45,13 @@ class UnquantizedLinearMethod(QuantMethodBase):
- output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
- weight_loader: a callable or method responsible for loading the weight data
"""
self.model_format = extra_weight_attrs.get("model_format")
self.weight_shape = (
layer.weight_shape[::-1] if extra_weight_attrs.get("model_format") == "torch" else layer.weight_shape
)
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
shape=self.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
@@ -52,15 +59,22 @@ class UnquantizedLinearMethod(QuantMethodBase):
split_axis = extra_weight_attrs.get("split_axis")
if hasattr(layer, "nranks") and layer.nranks > 0:
_set_var_distributed(layer.weight, split_axis=split_axis)
if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
set_weight_attrs(
layer.weight,
{
**extra_weight_attrs,
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
},
)
# Post-load hook: calls process_weight_transpose on the layer's "weight" parameter,
# but only when self.model_format is "torch". self.model_format is read from
# extra_weight_attrs["model_format"] at the point where layer.weight is created.
# NOTE(review): for the "torch" format the parameter is created with
# layer.weight_shape reversed, so this presumably restores the regular layout
# once loading is done — confirm against process_weight_transpose.
def process_weights_after_loading(self, layer):
# Any other model format leaves the loaded weight untouched.
if self.model_format == "torch":
process_weight_transpose(layer, "weight")
def process_loaded_weights(self, layer, weights) -> None:
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
if layer.weight.dtype != weights.dtype:
@@ -165,7 +179,7 @@ class LinearBase(nn.Layer):
if self.with_bias:
self.bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
dtype=self.weight_dtype,
is_bias=True,
)
setattr(
@@ -262,6 +276,7 @@ class ReplicatedLinear(LinearBase):
skip_quant: bool = False,
weight_dtype: str = "",
weight_key: str = "",
model_format: Optional[str] = None,
):
"""
Initializes a replicated linear layer.
@@ -296,7 +311,7 @@ class ReplicatedLinear(LinearBase):
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
model_format=fd_config.model_config.model_format,
model_format=fd_config.model_config.model_format if model_format is None else model_format,
)
@@ -344,10 +359,8 @@ class MergedReplicatedLinear(ReplicatedLinear):
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
weight_need_transpose = getattr(param, "weight_need_transpose", False)
loaded_weight = get_tensor(loaded_weight)
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
loaded_weight = get_tensor(loaded_weight).transpose([1, 0])
assert loaded_shard_id in ["q_a", "kv_a"]
if not param._is_initialized():
@@ -373,7 +386,9 @@ class MergedReplicatedLinear(ReplicatedLinear):
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)
# NOTE(bukejiyu): After this fix, the early H2D (host-to-device) copy for non-GPU devices is no longer needed and can be safely removed.
loaded_weight = get_tensor(loaded_weight)
h2d_copy(param, loaded_weight)
class ColumnParallelLinear(LinearBase):
@@ -393,7 +408,7 @@ class ColumnParallelLinear(LinearBase):
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
weight_dtype="",
weight_dtype: str = "",
):
"""
Initializes a linear layer and provides additional parameters required for inference and quantization.
@@ -493,6 +508,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# for xpu and other backend
weight_need_transpose = getattr(param, "weight_need_transpose", False)
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
@@ -522,7 +538,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.nranks != 1:
if self.nranks > 1 and output_dim is not None:
dim = -1 if output_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[dim]
@@ -532,7 +548,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
loaded_weight = get_tensor(loaded_weight)
if not param._is_initialized():
param.initialize()
param_shard_size = output_size // 2
@@ -553,7 +568,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)
h2d_copy(param, loaded_weight)
def load_state_dict(self, state_dict: dict):
"""
@@ -589,7 +605,19 @@ class QKVParallelLinear(ColumnParallelLinear):
QKVParallelLinear Layer.
"""
def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
def __init__(
self,
fd_config,
prefix,
with_bias=False,
add_bias=True,
num_heads: Optional[int] = None,
kv_num_heads: Optional[int] = None,
hidden_size: Optional[int] = None,
head_dim: Optional[int] = None,
skip_quant: bool = False,
weight_dtype: str = "",
):
"""
Initialize the QKV Linear layer with given parameters.
@@ -599,11 +627,15 @@ class QKVParallelLinear(ColumnParallelLinear):
Can be arbitrarily named.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to True.
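num_heads (Optional[int]): Number of attention heads in the model. Defaults to fd_config.model_config.num_attention_heads when None.
kv_num_heads (Optional[int]): Number of key/value heads, used for multi-query or grouped-query attention. Defaults to fd_config.model_config.num_key_value_heads when None.
hidden_size (Optional[int]): Total hidden layer dimension, typically the embedding size. Defaults to fd_config.model_config.hidden_size when None.
head_dim (Optional[int]): Size of each attention head, usually computed as hidden_size divided by num_heads. Defaults to fd_config.model_config.head_dim when None.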
"""
self.num_heads = fd_config.model_config.num_attention_heads
self.kv_num_heads = fd_config.model_config.num_key_value_heads
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.num_heads = fd_config.model_config.num_attention_heads if num_heads is None else num_heads
self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
@@ -623,6 +655,8 @@ class QKVParallelLinear(ColumnParallelLinear):
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant,
weight_dtype=weight_dtype,
)
def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
@@ -664,15 +698,13 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.nranks != 1:
if self.nranks > 1 and output_dim is not None:
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size
shard_size = block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)
loaded_weight = get_tensor(loaded_weight)
if not param._is_initialized():
param.initialize()
@@ -700,7 +732,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)
h2d_copy(param, loaded_weight)
def load_weight(self, state_dict: dict):
"""
@@ -798,7 +830,7 @@ class RowParallelLinear(LinearBase):
add_bias: bool = False,
reduce_results: bool = True,
skip_quant: bool = False,
weight_dtype="",
weight_dtype: str = "",
layer_id: int = -1,
):
"""
@@ -857,10 +889,6 @@ class RowParallelLinear(LinearBase):
),
model_format=fd_config.model_config.model_format,
)
if self.nranks > 0:
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=0)
self.reduce_results = reduce_results