mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Optimization] Optimization for gather_logprob by 10GB (#5817)
* opt logprobs gather_logprob,reduce device memory usage by 10GB when token_num=8k
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def count_greater_kernel(
|
||||
x_ptr, # [num_tokens, n_elements]
|
||||
y_ptr, # [num_tokens, 1]
|
||||
out_ptr, # [num_tokens, 1]
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
b = tl.program_id(0)
|
||||
sum_val = 0.0
|
||||
y = tl.load(y_ptr + b * 1 + 0)
|
||||
for col_start_idx in range(0, tl.cdiv(n_elements, BLOCK_SIZE)):
|
||||
col_ids = col_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
col_mask = col_ids < n_elements
|
||||
x = tl.load(x_ptr + b * n_elements + col_ids, mask=col_mask, other=-float("inf"))
|
||||
compare_mask = x >= y
|
||||
cmp_mask = tl.where(compare_mask & col_mask, 1, 0)
|
||||
sum_val += tl.sum(cmp_mask, axis=0)
|
||||
tl.store(out_ptr + b, sum_val.to(tl.int64))
|
||||
|
||||
|
||||
def batched_count_greater_than(x: paddle.Tensor, y: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
Triton implementation: (x >= y).sum(-1)
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): 2D tensor,shape [num_tokens, n_elements],float32。
|
||||
y (paddle.Tensor): 2D tensor,shape [num_tokens, 1],float32。
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: 1D tensor,shape [num_tokens].
|
||||
"""
|
||||
assert x.dim() == 2, f"x must be 2D, got {x.dim()}D"
|
||||
assert y.dim() == 2 and y.shape[1] == 1, f"y must be 2D with shape [num_tokens, 1], got {y.shape}"
|
||||
assert x.shape[0] == y.shape[0], f"batch size mismatch: x has {x.shape[0]}, y has {y.shape[0]}"
|
||||
assert x.dtype == y.dtype, f"dtype mismatch: x is {x.dtype}, y is {y.dtype}"
|
||||
|
||||
num_tokens, n_elements = x.shape
|
||||
dtype = paddle.int64
|
||||
|
||||
out = paddle.empty([num_tokens], dtype=dtype, device=x.place)
|
||||
|
||||
config = {"BLOCK_SIZE": 4096, "num_warps": 16}
|
||||
grid = (num_tokens,)
|
||||
|
||||
count_greater_kernel[grid](
|
||||
x_ptr=x,
|
||||
y_ptr=y,
|
||||
out_ptr=out,
|
||||
n_elements=n_elements,
|
||||
BLOCK_SIZE=config["BLOCK_SIZE"],
|
||||
num_warps=config["num_warps"],
|
||||
)
|
||||
|
||||
return out
|
||||
Reference in New Issue
Block a user