[Optimization] Optimization for gather_logprob by 10GB (#5817)

* opt logprobs gather_logprob,reduce device memory usage by 10GB when token_num=8k
This commit is contained in:
chen
2025-12-30 15:33:34 +08:00
committed by GitHub
parent 98519ee2e9
commit 0bcf924e10
3 changed files with 126 additions and 3 deletions
@@ -0,0 +1,76 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
import triton
import triton.language as tl
@triton.jit
def count_greater_kernel(
x_ptr, # [num_tokens, n_elements]
y_ptr, # [num_tokens, 1]
out_ptr, # [num_tokens, 1]
n_elements,
BLOCK_SIZE: tl.constexpr,
):
b = tl.program_id(0)
sum_val = 0.0
y = tl.load(y_ptr + b * 1 + 0)
for col_start_idx in range(0, tl.cdiv(n_elements, BLOCK_SIZE)):
col_ids = col_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
col_mask = col_ids < n_elements
x = tl.load(x_ptr + b * n_elements + col_ids, mask=col_mask, other=-float("inf"))
compare_mask = x >= y
cmp_mask = tl.where(compare_mask & col_mask, 1, 0)
sum_val += tl.sum(cmp_mask, axis=0)
tl.store(out_ptr + b, sum_val.to(tl.int64))
def batched_count_greater_than(x: paddle.Tensor, y: paddle.Tensor) -> paddle.Tensor:
"""
Triton implementation: (x >= y).sum(-1)
Args:
x (paddle.Tensor): 2D tensorshape [num_tokens, n_elements]float32。
y (paddle.Tensor): 2D tensorshape [num_tokens, 1]float32。
Returns:
paddle.Tensor: 1D tensorshape [num_tokens].
"""
assert x.dim() == 2, f"x must be 2D, got {x.dim()}D"
assert y.dim() == 2 and y.shape[1] == 1, f"y must be 2D with shape [num_tokens, 1], got {y.shape}"
assert x.shape[0] == y.shape[0], f"batch size mismatch: x has {x.shape[0]}, y has {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: x is {x.dtype}, y is {y.dtype}"
num_tokens, n_elements = x.shape
dtype = paddle.int64
out = paddle.empty([num_tokens], dtype=dtype, device=x.place)
config = {"BLOCK_SIZE": 4096, "num_warps": 16}
grid = (num_tokens,)
count_greater_kernel[grid](
x_ptr=x,
y_ptr=y,
out_ptr=out,
n_elements=n_elements,
BLOCK_SIZE=config["BLOCK_SIZE"],
num_warps=config["num_warps"],
)
return out
@@ -30,6 +30,7 @@ from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
from fastdeploy.model_executor.layers.sample.early_stopper import (
get_early_stopper_cls_from_stragegy,
)
from fastdeploy.model_executor.layers.sample.logprobs import batched_count_greater_than
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
@@ -466,7 +467,7 @@ class Sampler(nn.Layer):
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
if num_logprobs >= 1:
# Find the topK values.
@@ -709,7 +710,7 @@ class SpeculativeSampler(nn.Layer):
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
if num_logprobs >= 1:
# Find the topK values.
@@ -1055,7 +1056,7 @@ class MTPSampler(nn.Layer):
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
if num_logprobs >= 1:
# Find the topK values.
@@ -0,0 +1,46 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.layers.sample.logprobs import batched_count_greater_than
class TestBatchedCountGreaterThan(unittest.TestCase):
def setUp(self) -> None:
pass
def naive_impl(self, x, y):
return (x >= y).sum(-1)
def test_batched_count_greater_than(self):
vocab_size_list = [151552, 566]
test_token_nums = [1, 32, 128, 1024, 8192]
for idx, num_tokens in enumerate(test_token_nums):
for vocab_size in vocab_size_list:
x = paddle.randn([num_tokens, vocab_size], dtype="float32")
y = paddle.randn([num_tokens, 1], dtype="float32")
x[0, 0] = -float("inf")
y[0, 0] = -float("inf")
out = self.naive_impl(x, y)
out_triton = batched_count_greater_than(x, y)
self.assertTrue(np.allclose(out.numpy(), out_triton.numpy()))
return out
if __name__ == "__main__":
unittest.main()