mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[New][RL] Support Rollout Routing Replay (#5405)
* [RL] Support Rollout Routing Replay
* add routing indices cache
* fix config bug and moe forward bug
* R3 Support GLM
* support eb4.5
* fix merge bug
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* add routing replay ci
* support glm topk
* support other top_k
* fix ci bug
* pre-commit
* only support chatcmpl
* Revert "Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)"
This reverts commit c45e064f3d.
* Fix XPU and NPU bug
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
@@ -26,6 +27,9 @@ from fastdeploy.distributed.communication import (
|
||||
tensor_model_parallel_all_reduce_custom,
|
||||
)
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
|
||||
save_routing_to_buffer,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.utils import h2d_copy, slice_fn
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -226,7 +230,7 @@ class FusedMoE(nn.Layer):
|
||||
self.is_rearrange = False
|
||||
if self.ep_size > 1:
|
||||
self.quant_method.init_ep(self)
|
||||
|
||||
self.enable_routing_replay = fd_config.routing_replay_config.enable_routing_replay
|
||||
# Merge normal and RL build model
|
||||
if gate_correction_bias is not None:
|
||||
self.gate_correction_bias = gate_correction_bias
|
||||
@@ -600,7 +604,7 @@ class FusedMoE(nn.Layer):
|
||||
else:
|
||||
self.quant_method.process_loaded_weights(self, state_dict)
|
||||
|
||||
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
|
||||
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None):
|
||||
"""
|
||||
Forward split allgather function.
|
||||
"""
|
||||
@@ -615,14 +619,14 @@ class FusedMoE(nn.Layer):
|
||||
if end_offset > token_num:
|
||||
end_offset = token_num
|
||||
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
|
||||
out = self.quant_method.apply(self, part_x, gate)
|
||||
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
|
||||
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
|
||||
out = multi_outs[:token_num, :]
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
|
||||
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
|
||||
"""
|
||||
Defines the forward computation of the moe layer.
|
||||
|
||||
@@ -633,6 +637,21 @@ class FusedMoE(nn.Layer):
|
||||
Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
topk_ids_hookfunc = None
|
||||
if self.enable_routing_replay:
|
||||
if forward_meta is not None: # forward_meta is None when execute empty_input_forward
|
||||
topk_ids_hookfunc = partial(
|
||||
save_routing_to_buffer,
|
||||
routing_replay_table=forward_meta.routing_replay_table,
|
||||
batch_id_per_token=forward_meta.batch_id_per_token,
|
||||
seq_lens_decoder=forward_meta.seq_lens_decoder,
|
||||
cu_seqlens_q=forward_meta.cu_seqlens_q,
|
||||
layer_idx=self.layer_idx,
|
||||
tp_size=self.fd_config.parallel_config.tensor_parallel_size,
|
||||
ep_size=self.fd_config.parallel_config.expert_parallel_size,
|
||||
tp_group=self.fd_config.parallel_config.tp_group,
|
||||
)
|
||||
|
||||
token_num = x.shape[0]
|
||||
if (
|
||||
self.ep_size > 1
|
||||
@@ -640,11 +659,16 @@ class FusedMoE(nn.Layer):
|
||||
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
|
||||
and token_num >= self.attn_tp_size
|
||||
):
|
||||
out = self.forward_split_allgather(x, gate)
|
||||
out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
|
||||
out = self.forward_chunked_moe(x, gate, forward_meta)
|
||||
out = self.forward_chunked_moe(
|
||||
x,
|
||||
gate,
|
||||
forward_meta,
|
||||
topk_ids_hookfunc=topk_ids_hookfunc,
|
||||
)
|
||||
else:
|
||||
out = self.forward_normal(x, gate)
|
||||
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
if current_platform.is_intel_hpu():
|
||||
@@ -653,7 +677,9 @@ class FusedMoE(nn.Layer):
|
||||
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||
return out
|
||||
|
||||
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
|
||||
def forward_chunked_moe(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
|
||||
):
|
||||
"""
|
||||
Split input to multi chunk to reduce the memory usage of moe.
|
||||
|
||||
@@ -677,21 +703,25 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
for i in range(forward_meta.max_moe_num_chunk):
|
||||
if i < forward_meta.moe_num_chunk:
|
||||
out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
|
||||
out_split_list[i] = self.quant_method.apply(
|
||||
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc
|
||||
)
|
||||
else:
|
||||
# just need to use real data to infer max_moe_num_chunk times.
|
||||
self.quant_method.apply(self, fake_x, gate)
|
||||
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
|
||||
out = paddle.concat(out_split_list, axis=0)
|
||||
else:
|
||||
# when only one chunk, just need to use real data to infer once.
|
||||
out = self.quant_method.apply(self, x, gate)
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
for i in range(forward_meta.max_moe_num_chunk - 1):
|
||||
self.quant_method.apply(self, fake_x, gate)
|
||||
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
|
||||
return out
|
||||
|
||||
def forward_normal(self, x: paddle.Tensor, gate: nn.Layer):
|
||||
def forward_normal(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
|
||||
):
|
||||
"""
|
||||
Normal mode of forward.
|
||||
|
||||
@@ -702,5 +732,5 @@ class FusedMoE(nn.Layer):
|
||||
Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
out = self.quant_method.apply(self, x, gate)
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user