[V1 Loader] support weight_only (#3413)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled

* support wint4/wint8

* delete SMoE case

* update ci

* print log
This commit is contained in:
bukejiyu
2025-08-23 13:13:41 +08:00
committed by GitHub
parent 93e1b63200
commit 77514e3e1e
24 changed files with 1055 additions and 524 deletions
+81 -117
View File
@@ -23,7 +23,7 @@ from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.models.utils import (
from fastdeploy.model_executor.utils import (
default_weight_loader,
set_weight_attrs,
slice_fn,
@@ -39,6 +39,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
extra_weight_attrs is a dictionary that may include parameters like:
- split_axis: axis along which to split the tensor in a distributed environment
- output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
- weight_loader: a callable or method responsible for loading the weight data
"""
@@ -48,12 +49,16 @@ class UnquantizedLinearMethod(QuantMethodBase):
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
split_axis = extra_weight_attrs.get("split_axis")
if hasattr(layer, "nranks") and layer.nranks > 0:
_set_var_distributed(layer.weight, split_axis=split_axis)
set_weight_attrs(
layer.weight,
{"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
{
**extra_weight_attrs,
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
},
)
if hasattr(layer, "nranks") and layer.nranks > 1:
set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})
def process_loaded_weights(self, layer, weights) -> None:
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
@@ -340,7 +345,6 @@ class ColumnParallelLinear(LinearBase):
),
)
if self.nranks > 0:
_set_var_distributed(self.weight, split_axis=1)
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=1)
@@ -399,28 +403,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
output_dim = getattr(param, "output_dim", None)
shard_dim = -1 if output_dim else 0
output_size = param.shape[shard_dim]
if loaded_shard_id is None:
# Loaded weight is already fused on disk.
if self.nranks != 1:
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, self.output_size * self.nranks // 2),
("up", self.output_size * self.nranks // 2, self.output_size * self.nranks // 2),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
loaded_weight = get_tensor(loaded_weight)
param.copy_(loaded_weight, False)
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, output_size * self.nranks // 2),
("up", output_size * self.nranks // 2, output_size * self.nranks // 2),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
# 1.fused gate_up in disk
# 2.split gate up
# split gate up
assert loaded_shard_id in ["gate", "up"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
if self.nranks != 1:
dim = -1 if output_dim else 0
if isinstance(loaded_weight, np.ndarray):
size = loaded_weight.shape[dim]
else:
@@ -428,15 +431,20 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
loaded_weight = get_tensor(loaded_weight)
if not param._is_initialized():
param.initialize()
param_shard_size = output_size // 2
if loaded_shard_id == "gate":
param = param[:, : self.output_size // 2]
elif loaded_shard_id == "up":
param = param[:, self.output_size // 2 :]
param_shard_offset = 0
else:
# loaded_shard_id == "up"
param_shard_offset = param_shard_size
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
@@ -513,30 +521,25 @@ class QKVParallelLinear(ColumnParallelLinear):
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
output_dim = getattr(param, "output_dim", None)
head_dim = param.shape[output_dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
if loaded_shard_id is None:
# Loaded weight is already fused on disk
if self.nranks != 1:
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.num_heads * self.head_dim),
("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim),
("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.head_dim),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
loaded_weight = get_tensor(loaded_weight)
split_loaded_weight = loaded_weight
param.copy_(split_loaded_weight, False)
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.num_heads * head_dim),
("k", self.num_heads * head_dim, self.kv_num_heads * head_dim),
("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
# 1.fused qkv in disk
# 2.split q k v
# split q k v
assert loaded_shard_id in ["q", "k", "v"]
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
if self.nranks != 1:
dim = -1 if output_dim else 0
if isinstance(loaded_weight, np.ndarray):
size = loaded_weight.shape[dim]
@@ -545,20 +548,25 @@ class QKVParallelLinear(ColumnParallelLinear):
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
loaded_weight = get_tensor(loaded_weight)
if not param._is_initialized():
param.initialize()
if loaded_shard_id == "q":
param_shard_offset = 0
param_shard_size = self.num_heads_per_rank * self.head_dim
param_shard_size = self.num_heads_per_rank * head_dim
elif loaded_shard_id == "k":
param_shard_offset = self.num_heads_per_rank * self.head_dim
param_shard_size = self.kv_num_heads_per_rank * self.head_dim
param_shard_offset = self.num_heads_per_rank * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
else:
# loaded_shard_id == "v"
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim
param_shard_size = self.kv_num_heads_per_rank * self.head_dim
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
@@ -706,7 +714,6 @@ class RowParallelLinear(LinearBase):
),
)
if self.nranks > 0:
_set_var_distributed(self.weight, split_axis=0)
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=0)
@@ -732,7 +739,7 @@ class RowParallelLinear(LinearBase):
return out
class KVBatchLinear(LinearBase):
class KVBatchLinear(nn.Layer):
"""
KVBatchLinear Layer for handling combined KV projections with bmm.
"""
@@ -740,13 +747,12 @@ class KVBatchLinear(LinearBase):
def __init__(
self,
fd_config: FDConfig,
kv_b_proj: nn.Layer,
prefix: str = "",
kv_lora_rank: int = None,
num_attention_heads: int = None,
qk_nope_head_dim: int = None,
v_head_dim: int = None,
with_bias: bool = False,
skip_quant: bool = False,
):
"""
Initializes a KV batch linear layer that internally splits into K and V projections.
@@ -761,6 +767,7 @@ class KVBatchLinear(LinearBase):
with_bias (bool): Whether to include bias or not. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
super().__init__()
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.kv_lora_rank = kv_lora_rank
self.num_attention_heads = num_attention_heads
@@ -770,69 +777,27 @@ class KVBatchLinear(LinearBase):
self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
# Initialize parent with combined dimensions
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=None, # Will be determined from weight shape
output_size=None, # Will be determined from weight shape
with_bias=with_bias,
add_bias=False,
skip_quant=skip_quant,
)
self.weight_dtype = self._dtype
self.kv_b_proj = kv_b_proj
self.weight_dtype = self._helper.get_default_dtype()
# Override weight keys to use the combined kv_b_proj
self.weight_key = f"{prefix}.weight" # e.g., "kv_b_proj.weight"
self.k_weight_key = f"{prefix.replace('kv_b_proj', 'k_b_proj')}.weight"
self.v_weight_key = f"{prefix.replace('kv_b_proj', 'v_b_proj')}.weight"
self.k_b_proj_weight = self.create_parameter(
shape=[self.num_heads_per_partition, self.qk_nope_head_dim, self.kv_lora_rank],
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
def process_weights_after_loading(self):
self.v_b_proj_weight = self.create_parameter(
shape=[self.num_heads_per_partition, self.kv_lora_rank, self.v_head_dim],
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
w = self.kv_b_proj.weight.reshape(
[
self.kv_lora_rank,
self.num_heads_per_partition,
-1,
]
).transpose(perm=[1, 2, 0])
self.kv_b_proj = None
set_weight_attrs(
self.k_b_proj_weight,
{"weight_loader": self.weight_loader},
)
if w.dtype != self.weight_dtype:
w = w.cast(self.weight_dtype)
if self.nranks > 0:
_set_var_distributed(self.k_b_proj_weight, split_axis=1)
set_weight_attrs(self.k_b_proj_weight, {"output_dim": True})
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]
w = (
get_tensor(loaded_weight)
.reshape(
[
self.kv_lora_rank,
self.num_heads_per_partition,
-1,
]
)
.transpose(perm=[1, 2, 0])
)
if param.dtype != w.dtype:
w = w.cast(param.dtype)
# Split into K and V weights
# wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank]
wk_b = w[:, : self.qk_nope_head_dim, :]
@@ -840,9 +805,8 @@ class KVBatchLinear(LinearBase):
raise ValueError("self.v_head_dim should not be None")
# wv_b: [num_heads, kv_lora_rank, v_head_dim]
wv_b = w[:, -self.v_head_dim :, :].transpose(perm=[0, 2, 1])
self.k_b_proj_weight.set_value(wk_b)
self.v_b_proj_weight.set_value(wv_b)
self.k_b_proj_weight = wk_b
self.v_b_proj_weight = wv_b
def load_state_dict(self, state_dict: dict):
"""
@@ -916,7 +880,7 @@ class KVBatchLinear(LinearBase):
out = paddle.bmm(x, self.v_b_proj_weight)
return out
def forward_cuda(self, x: paddle.Tensor, proj_type: str = "k") -> paddle.Tensor:
def forward(self, x: paddle.Tensor, proj_type: str = "k") -> paddle.Tensor:
"""
Forward function that can handle both K and V projections