[Speculative Decoding] Add draft_logprobs Support for Speculative Decode MTP (#4467)

* feat: add draft_logprobs for Speculative Decode MTP

* feat: add draft_logprobs for Speculative Decode MTP

* feat: add draft_logprobs for Speculative Decode MTP

* fix: postprocess for speculative decode

* test: test_speculative_decoding_use_logprobs

* fix: test_completion_echo

* fix test_max_streaming_tokens

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
SunLei
2025-10-21 14:57:50 +08:00
committed by GitHub
parent 775edcc09a
commit ee915220bd
7 changed files with 422 additions and 48 deletions
@@ -94,43 +94,43 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
response_data = [
{
"request_id": "test_request_id_0",
"outputs": {"token_ids": [1], "text": "a", "top_logprobs": None},
"outputs": {"token_ids": [1], "text": "a", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"first_token_time": 0.1, "inference_start_time": 0.1},
"finished": False,
},
{
"request_id": "test_request_id_0",
"outputs": {"token_ids": [2], "text": "b", "top_logprobs": None},
"outputs": {"token_ids": [2], "text": "b", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"arrival_time": 0.2, "first_token_time": None},
"finished": False,
},
{
"request_id": "test_request_id_0",
"outputs": {"token_ids": [3], "text": "c", "top_logprobs": None},
"outputs": {"token_ids": [3], "text": "c", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"arrival_time": 0.3, "first_token_time": None},
"finished": False,
},
{
"request_id": "test_request_id_0",
"outputs": {"token_ids": [4], "text": "d", "top_logprobs": None},
"outputs": {"token_ids": [4], "text": "d", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"arrival_time": 0.4, "first_token_time": None},
"finished": False,
},
{
"request_id": "test_request_id_0",
"outputs": {"token_ids": [5], "text": "e", "top_logprobs": None},
"outputs": {"token_ids": [5], "text": "e", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"arrival_time": 0.5, "first_token_time": None},
"finished": False,
},
{
"request_id": "test_request_id_0",
"outputs": {"token_ids": [6], "text": "f", "top_logprobs": None},
"outputs": {"token_ids": [6], "text": "f", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"arrival_time": 0.6, "first_token_time": None},
"finished": False,
},
{
"request_id": "test_request_id_0",
"outputs": {"token_ids": [7], "text": "g", "top_logprobs": None},
"outputs": {"token_ids": [7], "text": "g", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1},
"finished": True,
},
@@ -190,9 +190,9 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
chunk_dict = json.loads(json_part)
parsed_chunks.append(chunk_dict)
except json.JSONDecodeError as e:
self.fail(f"Cannot parser {i+1} chunk, JSON: {e}\n origin string: {repr(chunk_str)}")
self.fail(f"Cannot parser {i + 1} chunk, JSON: {e}\n origin string: {repr(chunk_str)}")
else:
self.fail(f"{i+1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}")
self.fail(f"{i + 1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}")
for chunk_dict in parsed_chunks:
choices_list = chunk_dict["choices"]
if choices_list[-1].get("finish_reason") is not None:
@@ -209,13 +209,13 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
[
{
"request_id": "test-request-id_0",
"outputs": {"token_ids": [1], "text": "a", "top_logprobs": None},
"outputs": {"token_ids": [1], "text": "a", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"first_token_time": 0.1, "inference_start_time": 0.1},
"finished": False,
},
{
"request_id": "test-request-id_0",
"outputs": {"token_ids": [2], "text": "b", "top_logprobs": None},
"outputs": {"token_ids": [2], "text": "b", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"arrival_time": 0.2, "first_token_time": None},
"finished": False,
},
@@ -223,7 +223,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
[
{
"request_id": "test-request-id_0",
"outputs": {"token_ids": [7], "text": "g", "top_logprobs": None},
"outputs": {"token_ids": [7], "text": "g", "top_logprobs": None, "draft_top_logprobs": None},
"metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1},
"finished": True,
}
@@ -269,11 +269,12 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
chunk_dict = json.loads(json_part)
parsed_chunks.append(chunk_dict)
except json.JSONDecodeError as e:
self.fail(f"Cannot parser {i+1} chunk, JSON: {e}\n origin string: {repr(chunk_str)}")
self.fail(f"Cannot parser {i + 1} chunk, JSON: {e}\n origin string: {repr(chunk_str)}")
else:
self.fail(f"{i+1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}")
self.fail(f"{i + 1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}")
self.assertEqual(len(parsed_chunks), 1)
for chunk_dict in parsed_chunks:
print(f"======>{chunk_dict}")
choices_list = chunk_dict["choices"]
self.assertEqual(len(choices_list), 3, f"Chunk {chunk_dict} should has three choices")
self.assertEqual(