mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
refactor pt loading (#4532)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
This commit is contained in:
@@ -25,6 +25,8 @@ from fastdeploy.distributed.communication import tensor_model_parallel_all_reduc
|
||||
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
|
||||
from fastdeploy.model_executor.utils import (
|
||||
default_weight_loader,
|
||||
h2d_copy,
|
||||
process_weight_transpose,
|
||||
set_weight_attrs,
|
||||
slice_fn,
|
||||
)
|
||||
@@ -43,8 +45,13 @@ class UnquantizedLinearMethod(QuantMethodBase):
|
||||
- output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
|
||||
- weight_loader: a callable or method responsible for loading the weight data
|
||||
"""
|
||||
self.model_format = extra_weight_attrs.get("model_format")
|
||||
self.weight_shape = (
|
||||
layer.weight_shape[::-1] if extra_weight_attrs.get("model_format") == "torch" else layer.weight_shape
|
||||
)
|
||||
|
||||
layer.weight = layer.create_parameter(
|
||||
shape=layer.weight_shape,
|
||||
shape=self.weight_shape,
|
||||
dtype=layer.weight_dtype,
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
@@ -52,15 +59,22 @@ class UnquantizedLinearMethod(QuantMethodBase):
|
||||
split_axis = extra_weight_attrs.get("split_axis")
|
||||
if hasattr(layer, "nranks") and layer.nranks > 0:
|
||||
_set_var_distributed(layer.weight, split_axis=split_axis)
|
||||
|
||||
if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
|
||||
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
|
||||
|
||||
set_weight_attrs(
|
||||
layer.weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
|
||||
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
|
||||
},
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.model_format == "torch":
|
||||
process_weight_transpose(layer, "weight")
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
|
||||
if layer.weight.dtype != weights.dtype:
|
||||
@@ -165,7 +179,7 @@ class LinearBase(nn.Layer):
|
||||
if self.with_bias:
|
||||
self.bias = self.create_parameter(
|
||||
shape=[self.output_size],
|
||||
dtype=self._dtype,
|
||||
dtype=self.weight_dtype,
|
||||
is_bias=True,
|
||||
)
|
||||
setattr(
|
||||
@@ -262,6 +276,7 @@ class ReplicatedLinear(LinearBase):
|
||||
skip_quant: bool = False,
|
||||
weight_dtype: str = "",
|
||||
weight_key: str = "",
|
||||
model_format: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a replicated linear layer.
|
||||
@@ -296,7 +311,7 @@ class ReplicatedLinear(LinearBase):
|
||||
weight_loader=(
|
||||
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
|
||||
),
|
||||
model_format=fd_config.model_config.model_format,
|
||||
model_format=fd_config.model_config.model_format if model_format is None else model_format,
|
||||
)
|
||||
|
||||
|
||||
@@ -344,10 +359,8 @@ class MergedReplicatedLinear(ReplicatedLinear):
|
||||
|
||||
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
|
||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
|
||||
if weight_need_transpose:
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
loaded_weight = get_tensor(loaded_weight).transpose([1, 0])
|
||||
|
||||
assert loaded_shard_id in ["q_a", "kv_a"]
|
||||
if not param._is_initialized():
|
||||
@@ -373,7 +386,9 @@ class MergedReplicatedLinear(ReplicatedLinear):
|
||||
loaded_weight = loaded_weight.view(param.dtype)
|
||||
else:
|
||||
loaded_weight = loaded_weight.cast(param.dtype)
|
||||
param.copy_(loaded_weight, False)
|
||||
# (bukejiyu) After this fix, the early H2D copy for non-GPU devices is no longer needed and can be safely removed.
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
h2d_copy(param, loaded_weight)
|
||||
|
||||
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
@@ -393,7 +408,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
with_bias: bool = False,
|
||||
add_bias: bool = False,
|
||||
skip_quant: bool = False,
|
||||
weight_dtype="",
|
||||
weight_dtype: str = "",
|
||||
):
|
||||
"""
|
||||
Initializes a linear layer and provides additional parameters required for inference and quantization.
|
||||
@@ -493,6 +508,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
)
|
||||
|
||||
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
|
||||
# for xpu and other backend
|
||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
assert output_dim is not None
|
||||
@@ -522,7 +538,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
# Tensor parallelism splits the weight along the output_dim
|
||||
if self.nranks != 1:
|
||||
if self.nranks > 1 and output_dim is not None:
|
||||
dim = -1 if output_dim else 0
|
||||
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
|
||||
size = loaded_weight.shape[dim]
|
||||
@@ -532,7 +548,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = self.local_rank * block_size
|
||||
shard_size = (self.local_rank + 1) * block_size
|
||||
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
if not param._is_initialized():
|
||||
param.initialize()
|
||||
param_shard_size = output_size // 2
|
||||
@@ -553,7 +568,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = loaded_weight.view(param.dtype)
|
||||
else:
|
||||
loaded_weight = loaded_weight.cast(param.dtype)
|
||||
param.copy_(loaded_weight, False)
|
||||
|
||||
h2d_copy(param, loaded_weight)
|
||||
|
||||
def load_state_dict(self, state_dict: dict):
|
||||
"""
|
||||
@@ -589,7 +605,19 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
QKVParallelLinear Layer.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
prefix,
|
||||
with_bias=False,
|
||||
add_bias=True,
|
||||
num_heads: Optional[int] = None,
|
||||
kv_num_heads: Optional[int] = None,
|
||||
hidden_size: Optional[int] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
skip_quant: bool = False,
|
||||
weight_dtype: str = "",
|
||||
):
|
||||
"""
|
||||
Initialize the QKV Linear layer with given parameters.
|
||||
|
||||
@@ -599,11 +627,15 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
Can be arbitrarily named.
|
||||
with_bias (bool): Whether to include bias or not. Defaults to False.
|
||||
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to True.
|
||||
num_heads (Optional[int]): Number of attention heads in the model.
|
||||
kv_num_heads (Optional[int]): Number of key/value heads, used for multi-query or grouped-query attention.
|
||||
hidden_size (Optional[int]): Total hidden layer dimension, typically the embedding size.
|
||||
head_dim (Optional[int]): Size of each attention head, usually computed as hidden_size divided by num_heads.
|
||||
"""
|
||||
self.num_heads = fd_config.model_config.num_attention_heads
|
||||
self.kv_num_heads = fd_config.model_config.num_key_value_heads
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.num_heads = fd_config.model_config.num_attention_heads if num_heads is None else num_heads
|
||||
self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
|
||||
self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
|
||||
@@ -623,6 +655,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
output_size=output_size,
|
||||
with_bias=with_bias,
|
||||
add_bias=add_bias,
|
||||
skip_quant=skip_quant,
|
||||
weight_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
|
||||
@@ -664,15 +698,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
# Tensor parallelism splits the weight along the output_dim
|
||||
if self.nranks != 1:
|
||||
if self.nranks > 1 and output_dim is not None:
|
||||
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
|
||||
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
|
||||
shard_offset = shard_id * block_size
|
||||
shard_size = block_size
|
||||
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)
|
||||
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
|
||||
if not param._is_initialized():
|
||||
param.initialize()
|
||||
|
||||
@@ -700,7 +732,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = loaded_weight.view(param.dtype)
|
||||
else:
|
||||
loaded_weight = loaded_weight.cast(param.dtype)
|
||||
param.copy_(loaded_weight, False)
|
||||
h2d_copy(param, loaded_weight)
|
||||
|
||||
def load_weight(self, state_dict: dict):
|
||||
"""
|
||||
@@ -798,7 +830,7 @@ class RowParallelLinear(LinearBase):
|
||||
add_bias: bool = False,
|
||||
reduce_results: bool = True,
|
||||
skip_quant: bool = False,
|
||||
weight_dtype="",
|
||||
weight_dtype: str = "",
|
||||
layer_id: int = -1,
|
||||
):
|
||||
"""
|
||||
@@ -857,10 +889,6 @@ class RowParallelLinear(LinearBase):
|
||||
),
|
||||
model_format=fd_config.model_config.model_format,
|
||||
)
|
||||
if self.nranks > 0:
|
||||
if self.with_bias:
|
||||
# col parallel
|
||||
_set_var_distributed(self.bias, split_axis=0)
|
||||
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user