Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
co63oc
2025-08-28 14:42:24 +08:00
committed by GitHub
parent c294fc8139
commit d4fc893fe3
3 changed files with 11 additions and 11 deletions
+4 -4
View File
@@ -109,11 +109,11 @@ class TestWFp8Afp8SparseGemm(unittest.TestCase):
TokenPadding = 0
tokens = [tokens_per_group] * BATCH
-tokens_perfix_sum = np.cumsum(tokens)
-tokens_perfix_sum = np.insert(tokens_perfix_sum, 0, 0)
+tokens_prefix_sum = np.cumsum(tokens)
+tokens_prefix_sum = np.insert(tokens_prefix_sum, 0, 0)
tokens = paddle.to_tensor(tokens, dtype="int32")
-tokens_perfix_sum = paddle.to_tensor(tokens_perfix_sum, dtype="int32")
+tokens_prefix_sum = paddle.to_tensor(tokens_prefix_sum, dtype="int32")
all_tokens = int(tokens.sum())
@@ -148,7 +148,7 @@ class TestWFp8Afp8SparseGemm(unittest.TestCase):
input_fp8,
convert_sparse_idx,
pack_weight.reshape([BATCH, N, K // 2]),
-tokens_perfix_sum if TokenPadding == 0 else tokens,
+tokens_prefix_sum if TokenPadding == 0 else tokens,
1 / weight_scale,
out_pd,
int(TokenPadding),