mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[TBO] Apply tbo to gpu_model_runner (#7165)
* apply tbo in gpu_model_runner * fix
This commit is contained in:
@@ -56,6 +56,7 @@ from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.spec_decode import SpecMethod
|
||||
from fastdeploy.utils import print_gpu_memory_use
|
||||
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
|
||||
from fastdeploy.worker.tbo import GLOBAL_ATTN_BUFFERS
|
||||
|
||||
if current_platform.is_iluvatar():
|
||||
from fastdeploy.model_executor.ops.iluvatar import (
|
||||
@@ -1530,7 +1531,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if envs.FD_DETERMINISTIC_MODE:
|
||||
decoder_block_shape_q = envs.FD_DETERMINISTIC_SPLIT_KV_SIZE
|
||||
|
||||
res_buffer = allocate_launch_related_buffer(
|
||||
buffer_kwargs = dict(
|
||||
max_batch_size=self.scheduler_config.max_num_seqs,
|
||||
max_model_len=self.model_config.max_model_len,
|
||||
encoder_block_shape_q=encoder_block_shape_q,
|
||||
@@ -1540,8 +1541,13 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
kv_num_heads=self.model_config.kv_num_heads,
|
||||
block_size=self.fd_config.cache_config.block_size,
|
||||
)
|
||||
res_buffer = allocate_launch_related_buffer(**buffer_kwargs)
|
||||
self.share_inputs.update(res_buffer)
|
||||
|
||||
if int(os.getenv("USE_TBO", "0")) == 1:
|
||||
for j in range(2):
|
||||
GLOBAL_ATTN_BUFFERS[j] = allocate_launch_related_buffer(**buffer_kwargs)
|
||||
|
||||
# Get the attention backend
|
||||
attn_cls = get_attention_backend()
|
||||
attn_backend = attn_cls(
|
||||
|
||||
@@ -114,8 +114,6 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta, fd_config):
|
||||
end_bs += 1
|
||||
|
||||
if len(forward_meta.rotary_embs.shape) == 6:
|
||||
max_bs = forward_meta.rotary_embs.shape[0]
|
||||
assert max_bs == forward_meta.block_tables.shape[0]
|
||||
assert forward_meta.rotary_embs.shape[1:3] == [2, 1]
|
||||
assert forward_meta.rotary_embs.shape[4] == 1
|
||||
res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
|
||||
|
||||
Reference in New Issue
Block a user