[RL] Support SM100 FP8 quantization in RL (#6601)

* RL SM100 Fix

* update
Author: bukejiyu
Date: 2026-03-04 20:55:04 +08:00
Committed by: GitHub
Parent: 1256fd3806
Commit: 598cce8545
3 changed files with 34 additions and 12 deletions
@@ -1695,8 +1695,19 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
                process_weight_transpose(layer, down_proj_weight_name)
                process_weight_transpose(layer, up_gate_proj_scale_name)
                process_weight_transpose(layer, down_proj_scale_name)
            else:
                return
        if self.quant_config.deepgemm_scale_ue8m0:
            up_gate_proj_scale = getattr(layer, self.added_scale_attrs[0])
            new_up_gate_proj_scale = paddle.empty(
                up_gate_proj_scale.shape[:1] + up_gate_proj_scale.shape[1:][::-1], dtype=up_gate_proj_scale.dtype
            )
            new_up_gate_proj_scale = new_up_gate_proj_scale.transpose([0, 2, 1])
            getattr(layer, self.added_scale_attrs[0]).data = new_up_gate_proj_scale
            down_proj_scale = getattr(layer, self.added_scale_attrs[1])
            new_down_proj_scale = paddle.empty(
                down_proj_scale.shape[:1] + down_proj_scale.shape[1:][::-1], dtype=down_proj_scale.dtype
            )
            new_down_proj_scale = new_down_proj_scale.transpose([0, 2, 1])
            getattr(layer, self.added_scale_attrs[1]).data = new_down_proj_scale

    def process_loaded_weights(self, layer: nn.Layer, state_dict):
        """