mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
router support divided rollout (#6150)
This commit is contained in:
@@ -73,6 +73,11 @@ void SwapCacheImpLayout(
|
||||
copy_kind,
|
||||
stream);
|
||||
|
||||
PADDLE_ENFORCE_EQ(status,
|
||||
cudaSuccess,
|
||||
phi::errors::External("cudaMemcpyAsync failed: %s",
|
||||
cudaGetErrorString(status)));
|
||||
|
||||
#ifdef SWAP_DEBUG
|
||||
cudaStreamSynchronize(stream);
|
||||
std::cout << "mode:" << mode << ", layer_idx:" << layer_idx
|
||||
@@ -81,7 +86,11 @@ void SwapCacheImpLayout(
|
||||
#endif
|
||||
}
|
||||
}
|
||||
cudaStreamSynchronize(stream);
|
||||
cudaError_t sync_status = cudaStreamSynchronize(stream);
|
||||
PADDLE_ENFORCE_EQ(sync_status,
|
||||
cudaSuccess,
|
||||
phi::errors::External("cudaStreamSynchronize failed: %s",
|
||||
cudaGetErrorString(sync_status)));
|
||||
}
|
||||
|
||||
void SwapCacheLayout(
|
||||
|
||||
+132
-1
@@ -5,7 +5,11 @@ This module references the router implementation of sglang and vllm.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from uuid import uuid4
|
||||
@@ -166,7 +170,10 @@ class Router:
|
||||
mixed_server = await self.select_mixed()
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await self._generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
|
||||
if request_data.get("divided_stream", int(os.environ.get("DIVIDED_STREAM", "0")) == 1):
|
||||
return await self._divided_generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
|
||||
else:
|
||||
return await self._generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
|
||||
else:
|
||||
return await self._generate(request_data, [mixed_server.url()], endpoint=endpoint_name)
|
||||
|
||||
@@ -241,6 +248,130 @@ class Router:
|
||||
|
||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||
|
||||
async def _divided_generate_stream(
    self,
    modified_request,
    urls,
    return_result_url_index=-1,
    endpoint="v1/chat/completions",
):
    """Stream a generation by splitting it into multiple short rounds.

    Instead of one request for ``max_tokens`` tokens, this issues repeated
    requests of at most ``step_max_tokens`` tokens each.  After the first
    round, the accumulated prompt + output token ids are resent via
    ``prompt_token_ids`` (with the chat template disabled) so the server
    continues from where the previous round stopped.  Chunks are re-emitted
    to the caller as a single continuous SSE stream.

    NOTE: Used for debugging, not used in production

    Args:
        modified_request: the (already preprocessed) request body dict;
            ``max_tokens`` / ``step_max_tokens`` are read from it.
        urls: candidate server base URLs.
        return_result_url_index: index into ``urls`` of the server that is
            actually queried (default -1, i.e. the last one).
        endpoint: path appended to the chosen URL.

    Returns:
        StreamingResponse: an SSE (``text/event-stream``) response that
        yields the merged chunks from all rounds.

    Raises (inside the stream):
        RuntimeError: on a non-200 HTTP status, or when the generated token
            count drifts from the expected per-round total.
    """

    async def stream_results():
        total_max_tokens = modified_request.get("max_tokens", 0)
        step_max_tokens = modified_request.get("step_max_tokens", 10)
        timeout = aiohttp.ClientTimeout(total=self.timeout)

        # Cross-round state: how many completion tokens we have produced so
        # far, and the token ids used to rebuild the prompt for round > 0.
        round_idx = -1
        generated_tokens = 0
        input_ids = []
        output_ids = []

        async with aiohttp.ClientSession(timeout=timeout) as session:
            while generated_tokens < total_max_tokens:
                round_idx += 1
                remain_tokens = total_max_tokens - generated_tokens
                cur_max_tokens = min(step_max_tokens, remain_tokens)
                # Last round iff the remaining budget fits in one step; a
                # "length" finish there is a genuine completion.
                is_last_round = remain_tokens <= step_max_tokens

                cur_request = copy.deepcopy(modified_request)
                cur_request["max_tokens"] = cur_max_tokens
                # return_token_ids so we can splice rounds together; one
                # token per SSE chunk keeps the accounting exact.
                cur_request["return_token_ids"] = True
                cur_request["max_streaming_response_tokens"] = 1
                if round_idx == 0:
                    cur_request["disable_chat_template"] = False
                else:
                    # Continue from raw token ids; the chat template must
                    # not be applied a second time.
                    cur_request["messages"] = []
                    cur_request["prompt_token_ids"] = input_ids + output_ids
                    cur_request["disable_chat_template"] = True

                logger.debug(f"_divided_generate_stream, cur_request={cur_request}")

                # NOTE(review): resp is not wrapped in `async with`, so the
                # connection is only released when the session closes —
                # acceptable for a debug path; confirm if promoted.
                resp = await session.post(
                    f"{urls[return_result_url_index]}/{endpoint}",
                    json=cur_request,
                )
                if resp.status != 200:
                    text = await resp.text()
                    raise RuntimeError(f"Request failed: {resp.status}, body={text}")

                buffer = b""
                chunk_idx = -1
                is_real_finished = False
                async for raw_chunk in resp.content.iter_chunked(64 * 1024):
                    try:
                        buffer += raw_chunk

                        # SSE events are separated by a blank line; buffer
                        # until a complete "\n\n"-terminated event arrives.
                        while b"\n\n" in buffer:
                            event_bytes, buffer = buffer.split(b"\n\n", 1)
                            event_str = event_bytes.decode("utf-8")

                            for chunk in event_str.splitlines():
                                logger.debug(f"receive response chunk: {chunk}")
                                if not chunk:
                                    continue

                                chunk_idx += 1
                                # Skip the first chunk of every follow-up
                                # round — presumably a role/prompt echo
                                # already forwarded in round 0 (TODO confirm
                                # against the server's chunk layout).
                                if round_idx > 0 and chunk_idx == 0:
                                    continue

                                assert chunk.startswith("data: "), f"Invalid response chunk: {chunk}"
                                if chunk.startswith("data: [DONE]"):
                                    # Only forward [DONE] when generation
                                    # truly ended, not at each round break.
                                    if is_real_finished:
                                        yield chunk + "\n\n"
                                else:
                                    payload = json.loads(chunk[5:])
                                    choices = payload.get("choices", [])
                                    if not choices:
                                        continue
                                    delta = payload["choices"][0]["delta"]
                                    finish_reason = payload["choices"][0].get("finish_reason")

                                    # Capture prompt token ids once, from the
                                    # first chunk that carries them.
                                    if not input_ids and len(delta["prompt_token_ids"]) > 0:
                                        input_ids = delta["prompt_token_ids"]

                                    # "stop" always ends generation; "length"
                                    # ends it only on the final round.
                                    if finish_reason == "stop" or (is_last_round and finish_reason == "length"):
                                        is_real_finished = True

                                    token_ids = delta.get("completion_token_ids")
                                    if (
                                        token_ids
                                        and isinstance(token_ids, list)
                                        and (finish_reason is None or is_real_finished)
                                    ):
                                        output_ids.extend(token_ids)
                                        generated_tokens += len(token_ids)

                                    # Suppress the finishing chunk of an
                                    # intermediate round; the next round
                                    # continues seamlessly.
                                    if finish_reason is None or is_real_finished:
                                        yield chunk + "\n\n"

                    except Exception as e:
                        # Best-effort: log a malformed chunk and keep
                        # consuming the stream rather than aborting.
                        logger.error(
                            f"Error decoding response chunk: {raw_chunk}, round_idx: {round_idx}, "
                            f"chunk_idx: {chunk_idx}, error: {e}, traceback:{traceback.format_exc()}"
                        )
                        pass

                if not is_real_finished:
                    # Sanity check for this debug path: each unfinished round
                    # is expected to contribute step_max_tokens - 1 tokens
                    # (one chunk per round is skipped above).
                    expected_tokens = (step_max_tokens - 1) * (round_idx + 1)
                    if generated_tokens != expected_tokens:
                        err_msg = (
                            f"Generated tokens mismatch: generated_tokens is {generated_tokens}, "
                            f"expected is {expected_tokens}"
                        )
                        logger.error(err_msg)
                        raise RuntimeError(err_msg)

                if is_real_finished:
                    break

    return StreamingResponse(
        stream_results(),
        media_type="text/event-stream",
    )
|
||||
|
||||
async def monitor_instance_health(self, interval_secs: float = 5.0):
|
||||
"""
|
||||
Continuously check the health of prefill, decode, and mixed instances and remove unhealthy ones.
|
||||
|
||||
Reference in New Issue
Block a user