add qwen-2.5-7B-PRM/ernie-rm (#4319)

This commit is contained in:
bukejiyu
2025-10-20 15:31:03 +08:00
committed by GitHub
parent 47595a2480
commit de2eaf4f81
10 changed files with 352 additions and 24 deletions
+4 -6
View File
@@ -393,6 +393,7 @@ class ColumnParallelLinear(LinearBase):
with_bias: bool = False,
add_bias: bool = False,
skip_quant: bool = False,
weight_dtype="",
):
"""
Initializes a linear layer and provides additional parameters required for inference and quantization.
@@ -421,6 +422,7 @@ class ColumnParallelLinear(LinearBase):
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant,
weight_dtype=weight_dtype,
)
assert self.quant_method is not None
@@ -796,6 +798,7 @@ class RowParallelLinear(LinearBase):
add_bias: bool = False,
reduce_results: bool = True,
skip_quant: bool = False,
weight_dtype="",
):
"""
Initialize a linear layer with additional parameters for inference and quantization.
@@ -830,6 +833,7 @@ class RowParallelLinear(LinearBase):
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant,
weight_dtype=weight_dtype,
)
if add_bias:
assert with_bias, "with_bias must be True when add_bias is True."
@@ -847,12 +851,6 @@ class RowParallelLinear(LinearBase):
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=0)
set_weight_attrs(
self.bias,
{
"output_dim": False,
},
)
self.reduce_results = reduce_results