router support divided rollout (#6150)

This commit is contained in:
jc
2026-01-22 10:39:39 +08:00
committed by GitHub
parent 9c4db0ac3f
commit 309c7d9764
2 changed files with 142 additions and 2 deletions
+10 -1
View File
@@ -73,6 +73,11 @@ void SwapCacheImpLayout(
copy_kind,
stream);
PADDLE_ENFORCE_EQ(status,
cudaSuccess,
phi::errors::External("cudaMemcpyAsync failed: %s",
cudaGetErrorString(status)));
#ifdef SWAP_DEBUG
cudaStreamSynchronize(stream);
std::cout << "mode:" << mode << ", layer_idx:" << layer_idx
@@ -81,7 +86,11 @@ void SwapCacheImpLayout(
#endif
}
}
cudaStreamSynchronize(stream);
cudaError_t sync_status = cudaStreamSynchronize(stream);
PADDLE_ENFORCE_EQ(sync_status,
cudaSuccess,
phi::errors::External("cudaStreamSynchronize failed: %s",
cudaGetErrorString(sync_status)));
}
void SwapCacheLayout(
+132 -1
View File
@@ -5,7 +5,11 @@ This module references the router implementation of sglang and vllm.
"""
import asyncio
import copy
import json
import os
import random
import traceback
from dataclasses import dataclass
from itertools import chain
from uuid import uuid4
@@ -166,7 +170,10 @@ class Router:
mixed_server = await self.select_mixed()
if request_data.get("stream", False):
return await self._generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
if request_data.get("divided_stream", int(os.environ.get("DIVIDED_STREAM", "0")) == 1):
return await self._divided_generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
else:
return await self._generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
else:
return await self._generate(request_data, [mixed_server.url()], endpoint=endpoint_name)
@@ -241,6 +248,130 @@ class Router:
return StreamingResponse(stream_results(), media_type="text/event-stream")
async def _divided_generate_stream(
    self,
    modified_request,
    urls,
    return_result_url_index=-1,
    endpoint="v1/chat/completions",
):
    """Stream a generation split into multiple bounded rounds.

    The total budget (``max_tokens``) is divided into rounds of at most
    ``step_max_tokens`` tokens each. After every round, the accumulated
    token history (input ids + output ids so far) is replayed as
    ``prompt_token_ids`` for the next round, until the backend reports a
    real finish (``stop``, or ``length`` on the final round).

    NOTE: Used for debugging, not used in production.

    Args:
        modified_request: Request payload dict. ``max_tokens`` bounds the
            total generation; ``step_max_tokens`` (default 10) bounds each
            round.
        urls: Candidate backend base URLs.
        return_result_url_index: Index into ``urls`` of the backend to
            query. Defaults to -1 (the last one).
        endpoint: API path appended to the chosen base URL.

    Returns:
        StreamingResponse forwarding SSE chunks to the client.

    Raises:
        RuntimeError: On a non-200 backend response, or when an unfinished
            round produced an unexpected number of tokens.
    """

    async def stream_results():
        total_max_tokens = modified_request.get("max_tokens", 0)
        step_max_tokens = modified_request.get("step_max_tokens", 10)
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        round_idx = -1
        generated_tokens = 0
        input_ids = []   # prompt token ids captured from the first round's delta
        output_ids = []  # completion token ids accumulated across all rounds
        async with aiohttp.ClientSession(timeout=timeout) as session:
            while generated_tokens < total_max_tokens:
                round_idx += 1
                remain_tokens = total_max_tokens - generated_tokens
                cur_max_tokens = min(step_max_tokens, remain_tokens)
                is_last_round = remain_tokens <= step_max_tokens
                cur_request = copy.deepcopy(modified_request)
                cur_request["max_tokens"] = cur_max_tokens
                cur_request["return_token_ids"] = True
                cur_request["max_streaming_response_tokens"] = 1
                if round_idx == 0:
                    cur_request["disable_chat_template"] = False
                else:
                    # Follow-up rounds replay the full token history as the
                    # prompt, so the chat template must be bypassed.
                    cur_request["messages"] = []
                    cur_request["prompt_token_ids"] = input_ids + output_ids
                    cur_request["disable_chat_template"] = True
                logger.debug(f"_divided_generate_stream, cur_request={cur_request}")
                # FIX: use a context manager so the response connection is
                # released back to the pool at the end of every round; the
                # previous code never closed the per-round response.
                async with session.post(
                    f"{urls[return_result_url_index]}/{endpoint}",
                    json=cur_request,
                ) as resp:
                    if resp.status != 200:
                        text = await resp.text()
                        raise RuntimeError(f"Request failed: {resp.status}, body={text}")
                    buffer = b""
                    chunk_idx = -1
                    is_real_finished = False
                    async for raw_chunk in resp.content.iter_chunked(64 * 1024):
                        try:
                            buffer += raw_chunk
                            # SSE events are delimited by a blank line.
                            while b"\n\n" in buffer:
                                event_bytes, buffer = buffer.split(b"\n\n", 1)
                                event_str = event_bytes.decode("utf-8")
                                for chunk in event_str.splitlines():
                                    logger.debug(f"receive response chunk: {chunk}")
                                    if not chunk:
                                        continue
                                    chunk_idx += 1
                                    if round_idx > 0 and chunk_idx == 0:
                                        # Skip the duplicated first chunk of
                                        # follow-up rounds.
                                        continue
                                    assert chunk.startswith("data: "), f"Invalid response chunk: {chunk}"
                                    if chunk.startswith("data: [DONE]"):
                                        # Only forward the terminal marker once
                                        # generation is truly finished.
                                        if is_real_finished:
                                            yield chunk + "\n\n"
                                    else:
                                        payload = json.loads(chunk[5:])
                                        choices = payload.get("choices", [])
                                        if not choices:
                                            continue
                                        delta = payload["choices"][0]["delta"]
                                        finish_reason = payload["choices"][0].get("finish_reason")
                                        if not input_ids and len(delta["prompt_token_ids"]) > 0:
                                            input_ids = delta["prompt_token_ids"]
                                        if finish_reason == "stop" or (is_last_round and finish_reason == "length"):
                                            is_real_finished = True
                                        token_ids = delta.get("completion_token_ids")
                                        if (
                                            token_ids
                                            and isinstance(token_ids, list)
                                            and (finish_reason is None or is_real_finished)
                                        ):
                                            output_ids.extend(token_ids)
                                            generated_tokens += len(token_ids)
                                        if finish_reason is None or is_real_finished:
                                            yield chunk + "\n\n"
                        except Exception as e:
                            # Best-effort: log the bad chunk and keep
                            # consuming the stream.
                            logger.error(
                                f"Error decoding response chunk: {raw_chunk}, round_idx: {round_idx}, "
                                f"chunk_idx: {chunk_idx}, error: {e}, traceback:{traceback.format_exc()}"
                            )
                if not is_real_finished:
                    # Sanity check: each unfinished round is expected to
                    # contribute exactly step_max_tokens - 1 tokens; any
                    # deviation means chunks were lost or duplicated.
                    expected_tokens = (step_max_tokens - 1) * (round_idx + 1)
                    if generated_tokens != expected_tokens:
                        err_msg = (
                            f"Generated tokens mismatch: generated_tokens is {generated_tokens}, "
                            f"expected is {expected_tokens}"
                        )
                        logger.error(err_msg)
                        raise RuntimeError(err_msg)
                if is_real_finished:
                    break

    return StreamingResponse(
        stream_results(),
        media_type="text/event-stream",
    )
async def monitor_instance_health(self, interval_secs: float = 5.0):
"""
Continuously check the health of prefill, decode, and mixed instances and remove unhealthy ones.