mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[RL] Support SM100 FP8 quantization in RL (#6601)
* RL SM100 Fix * update
This commit is contained in:
@@ -1695,8 +1695,19 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
process_weight_transpose(layer, down_proj_weight_name)
|
||||
process_weight_transpose(layer, up_gate_proj_scale_name)
|
||||
process_weight_transpose(layer, down_proj_scale_name)
|
||||
else:
|
||||
return
|
||||
if self.quant_config.deepgemm_scale_ue8m0:
|
||||
up_gate_proj_scale = getattr(layer, self.added_scale_attrs[0])
|
||||
new_up_gate_proj_scale = paddle.empty(
|
||||
up_gate_proj_scale.shape[:1] + up_gate_proj_scale.shape[1:][::-1], dtype=up_gate_proj_scale.dtype
|
||||
)
|
||||
new_up_gate_proj_scale = new_up_gate_proj_scale.transpose([0, 2, 1])
|
||||
getattr(layer, self.added_scale_attrs[0]).data = new_up_gate_proj_scale
|
||||
down_proj_scale = getattr(layer, self.added_scale_attrs[1])
|
||||
new_down_proj_scale = paddle.empty(
|
||||
down_proj_scale.shape[:1] + down_proj_scale.shape[1:][::-1], dtype=down_proj_scale.dtype
|
||||
)
|
||||
new_down_proj_scale = new_down_proj_scale.transpose([0, 2, 1])
|
||||
getattr(layer, self.added_scale_attrs[1]).data = new_down_proj_scale
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user