[Optimize] Support and robust for tpN for PD (#4595)

* [Optimize] Support and robust for tpN for PD

* fix

* fix

* support dpM tpN for cache messager

* fix

* fix token counter

* fix bug for merge develop

* fix bug

* robust cache messager for v0
This commit is contained in:
chenjian
2025-11-03 15:38:31 +08:00
committed by GitHub
parent 7b35488779
commit 25498efcf3
9 changed files with 452 additions and 197 deletions
+14 -32
View File
@@ -143,38 +143,7 @@ class DPLocalScheduler(LocalScheduler):
requests: List[Request] = []
with self.requests_not_empty:
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(
request.prompt_tokens_ids_len, block_size
)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
else:
required_total_blocks = 0
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
@@ -183,11 +152,24 @@ class DPLocalScheduler(LocalScheduler):
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
if batch_ids:
if len(batch_ids) > 0 and len(requests) == 0: