Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)

This reverts commit 96d2d4877b.
This commit is contained in:
Jiang-Jia-Jun
2025-12-05 20:19:39 +08:00
committed by GitHub
parent 94c57e4175
commit c45e064f3d
24 changed files with 24 additions and 592 deletions
+15 -45
View File
@@ -14,8 +14,7 @@
# limitations under the License.
"""
from functools import partial
from typing import Callable, Optional
from typing import Optional
import paddle
from paddle import nn
@@ -27,9 +26,6 @@ from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce_custom,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
save_routing_to_buffer,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import h2d_copy, slice_fn
from fastdeploy.platforms import current_platform
@@ -230,7 +226,7 @@ class FusedMoE(nn.Layer):
self.is_rearrange = False
if self.ep_size > 1:
self.quant_method.init_ep(self)
self.enable_routing_replay = fd_config.routing_replay_config.enable_routing_replay
# Merge normal and RL build model
if gate_correction_bias is not None:
self.gate_correction_bias = gate_correction_bias
@@ -604,7 +600,7 @@ class FusedMoE(nn.Layer):
else:
self.quant_method.process_loaded_weights(self, state_dict)
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None):
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
"""
Forward split allgather function.
"""
@@ -619,14 +615,14 @@ class FusedMoE(nn.Layer):
if end_offset > token_num:
end_offset = token_num
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.quant_method.apply(self, part_x, gate)
multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
"""
Defines the forward computation of the moe layer.
@@ -637,21 +633,6 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.s
"""
topk_ids_hookfunc = None
if self.enable_routing_replay:
if forward_meta is not None: # forward_meta is None when execute empty_input_forward
topk_ids_hookfunc = partial(
save_routing_to_buffer,
routing_replay_table=forward_meta.routing_replay_table,
batch_id_per_token=forward_meta.batch_id_per_token,
seq_lens_decoder=forward_meta.seq_lens_decoder,
cu_seqlens_q=forward_meta.cu_seqlens_q,
layer_idx=self.layer_idx,
tp_size=self.fd_config.parallel_config.tensor_parallel_size,
ep_size=self.fd_config.parallel_config.expert_parallel_size,
tp_group=self.fd_config.parallel_config.tp_group,
)
token_num = x.shape[0]
if (
self.ep_size > 1
@@ -659,16 +640,11 @@ class FusedMoE(nn.Layer):
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
and token_num >= self.attn_tp_size
):
out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.forward_split_allgather(x, gate)
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
out = self.forward_chunked_moe(
x,
gate,
forward_meta,
topk_ids_hookfunc=topk_ids_hookfunc,
)
out = self.forward_chunked_moe(x, gate, forward_meta)
else:
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.forward_normal(x, gate)
if self.reduce_results and self.tp_size > 1:
if current_platform.is_intel_hpu():
@@ -677,9 +653,7 @@ class FusedMoE(nn.Layer):
out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out
def forward_chunked_moe(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
):
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
"""
Split input to multi chunk to reduce the memory usage of moe.
@@ -703,25 +677,21 @@ class FusedMoE(nn.Layer):
for i in range(forward_meta.max_moe_num_chunk):
if i < forward_meta.moe_num_chunk:
out_split_list[i] = self.quant_method.apply(
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc
)
out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
else:
# just need to use real data to infer max_moe_num_chunk times.
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
self.quant_method.apply(self, fake_x, gate)
out = paddle.concat(out_split_list, axis=0)
else:
# when only one chunk, just need to use real data to infer once.
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.quant_method.apply(self, x, gate)
for i in range(forward_meta.max_moe_num_chunk - 1):
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
self.quant_method.apply(self, fake_x, gate)
return out
def forward_normal(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
):
def forward_normal(self, x: paddle.Tensor, gate: nn.Layer):
"""
Normal mode of forward.
@@ -732,5 +702,5 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.s
"""
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
out = self.quant_method.apply(self, x, gate)
return out