[TBO] Apply TBO to gpu_model_runner (#7165)

* Apply TBO in gpu_model_runner

* fix
This commit is contained in:
RichardWooSJTU
2026-04-08 16:55:17 +08:00
committed by GitHub
parent 4cd574cf90
commit 771d42c90b
2 changed files with 7 additions and 3 deletions
+7 -1
View File
@@ -56,6 +56,7 @@ from fastdeploy.platforms import current_platform
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.utils import print_gpu_memory_use
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
from fastdeploy.worker.tbo import GLOBAL_ATTN_BUFFERS
if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
@@ -1530,7 +1531,7 @@ class GPUModelRunner(ModelRunnerBase):
if envs.FD_DETERMINISTIC_MODE:
decoder_block_shape_q = envs.FD_DETERMINISTIC_SPLIT_KV_SIZE
res_buffer = allocate_launch_related_buffer(
buffer_kwargs = dict(
max_batch_size=self.scheduler_config.max_num_seqs,
max_model_len=self.model_config.max_model_len,
encoder_block_shape_q=encoder_block_shape_q,
@@ -1540,8 +1541,13 @@ class GPUModelRunner(ModelRunnerBase):
kv_num_heads=self.model_config.kv_num_heads,
block_size=self.fd_config.cache_config.block_size,
)
res_buffer = allocate_launch_related_buffer(**buffer_kwargs)
self.share_inputs.update(res_buffer)
if int(os.getenv("USE_TBO", "0")) == 1:
for j in range(2):
GLOBAL_ATTN_BUFFERS[j] = allocate_launch_related_buffer(**buffer_kwargs)
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
-2
View File
@@ -114,8 +114,6 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta, fd_config):
end_bs += 1
if len(forward_meta.rotary_embs.shape) == 6:
max_bs = forward_meta.rotary_embs.shape[0]
assert max_bs == forward_meta.block_tables.shape[0]
assert forward_meta.rotary_embs.shape[1:3] == [2, 1]
assert forward_meta.rotary_embs.shape[4] == 1
res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]