[Models] Add forward_meta to moe models' forward function (#5138)

* [Models] Add forward_meta to moe models' forward function

* fix missing param

* fix

* fix

* fix forward_meta

* fix test and remove chunked MoE related fields from config

* fix test

* fix

* fix
This commit is contained in:
Longzhi Wang
2025-12-04 13:26:58 +08:00
committed by GitHub
parent f5bdb36e9b
commit 5cd17fd662
21 changed files with 131 additions and 87 deletions
+16 -11
View File
@@ -28,6 +28,13 @@ class MockStructuredOutputsConfig:
logits_processors = []
class MockForwardMeta:
def __init__(self):
# chunked MoE related.
self.moe_num_chunk = 1
self.max_moe_num_chunk = 1
class MockModelConfig:
max_model_len = 10
pad_token_id = 0
@@ -60,8 +67,6 @@ class MockFDConfig:
enable_expert_parallel = True
enable_chunked_moe = True
chunked_moe_size = 2
max_moe_num_chunk = 1
moe_num_chunk = 1
use_ep = True
use_sequence_parallel_moe = False
@@ -148,19 +153,19 @@ class TestChunkedMoE(unittest.TestCase):
def run_model_runner(self):
self.model_runner.initialize_forward_meta()
assert self.model_runner.fd_config.parallel_config.max_moe_num_chunk == 5, (
assert self.model_runner.forward_meta.max_moe_num_chunk == 5, (
f"chunk size is 2, max token_num is 10, max_moe_num_chunk should be 5, "
f"but got {self.model_runner.fd_config.parallel_config.max_moe_num_chunk}"
f"but got {self.model_runner.forward_meta.max_moe_num_chunk}"
)
if dist.get_rank() == 0:
assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 5, (
f"chunk size is 2, token_num is 10, moe_num_chunk in rank 0 should be 5"
f"but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}"
assert self.model_runner.forward_meta.moe_num_chunk == 5, (
f"chunk size is 2, token_num is 10, moe_num_chunk in rank 0 should be 5, "
f"but got {self.model_runner.forward_meta.moe_num_chunk}"
)
else:
assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 1, (
f"chunk size is 2, token_num is 1, moe_num_chunk in rank 1 should be 1"
f", but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}"
assert self.model_runner.forward_meta.moe_num_chunk == 1, (
f"chunk size is 2, token_num is 1, moe_num_chunk in rank 1 should be 1, "
f"but got {self.model_runner.forward_meta.moe_num_chunk}"
)
def run_fused_moe(self):
@@ -170,7 +175,7 @@ class TestChunkedMoE(unittest.TestCase):
else:
x = paddle.ones([1])
out = self.fused_moe.forward(x, gate)
out = self.fused_moe.forward(x, gate, MockForwardMeta())
assert out.shape == x.shape
def test_case(self):