From e78e22ebd5aee2683a8d8aada06b7b208cd224c1 Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Tue, 30 Dec 2025 12:44:29 +0800 Subject: [PATCH] [BugFix] Fix entropy bugs (#5818) * fix entropy bugs * fix ut * fix --- fastdeploy/model_executor/entropy_utils.py | 12 ++++++++++-- fastdeploy/worker/gpu_model_runner.py | 2 +- scripts/calculate_avg_entropy.py | 14 +++++++++++++- tests/model_executor/test_entropy_utils.py | 12 ++++++------ 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/fastdeploy/model_executor/entropy_utils.py b/fastdeploy/model_executor/entropy_utils.py index c9fc431b44..2794e5b722 100644 --- a/fastdeploy/model_executor/entropy_utils.py +++ b/fastdeploy/model_executor/entropy_utils.py @@ -46,7 +46,11 @@ def calculate_logits_entropy(logits, share_inputs, temperature): for i in range(real_bsz): for _ in range(real_seq_lens[i]): share_inputs["entropy_list"][i].append(entropy.pop(0)) - if share_inputs["stop_flags"][i] and len(share_inputs["entropy_list"][i]) != 0: + if ( + share_inputs["stop_flags"][i] + and share_inputs["seq_lens_decoder"][i] != 0 + and len(share_inputs["entropy_list"][i]) != 0 + ): data_processor_logger.info( f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}" ) @@ -92,7 +96,11 @@ def speculate_calculate_logits_entropy(logits, share_inputs, temperature): for i in range(real_bsz): for _ in range(share_inputs["accept_num"][i]): share_inputs["entropy_list"][i].append(entropy.pop(0)) - if share_inputs["stop_flags"][i] and len(share_inputs["entropy_list"][i]) != 0: + if ( + share_inputs["stop_flags"][i] + and share_inputs["seq_lens_decoder"][i] != 0 + and len(share_inputs["entropy_list"][i]) != 0 + ): data_processor_logger.info( f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}" ) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 954da85c68..41eac4b6b4 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -645,7 +645,6 @@ class GPUModelRunner(ModelRunnerBase): request = req_dicts[i] # assert isinstance(request, Request) idx = request.idx - self.share_inputs["req_ids"][idx] = str(request.request_id) if hasattr(request, "pooling_params") and request.pooling_params is not None: batch_pooling_params.append(request.pooling_params) @@ -653,6 +652,7 @@ class GPUModelRunner(ModelRunnerBase): logits_info = None prefill_tokens = [] if request.task_type.value == RequestType.PREFILL.value: # prefill task + self.share_inputs["req_ids"][idx] = str(request.request_id) # guided decoding if ( request.guided_json is not None diff --git a/scripts/calculate_avg_entropy.py b/scripts/calculate_avg_entropy.py index f24c976cd5..ab7d4d989e 100644 --- a/scripts/calculate_avg_entropy.py +++ b/scripts/calculate_avg_entropy.py @@ -1,4 +1,5 @@ import argparse +import glob import os import re from typing import List, Optional @@ -40,8 +41,19 @@ def main(): parser.add_argument("--log-dir", type=str, required=True) parser.add_argument("--drop-ratio", "-d", type=float, default=0.1) parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--start-id", "-s", type=int) + parser.add_argument("--end-id", "-e", type=int) args = parser.parse_args() - entropy_values = extract_entropy_values(os.path.join(args.log_dir, "data_processor.log")) + log_files = glob.glob(os.path.join(args.log_dir, "data_processor.log.*")) + if not log_files: + print(f"No log files found in {args.log_dir}") + return + + entropy_values = [] + for log_file in log_files: + entropy_values.extend(extract_entropy_values(log_file)) + if args.start_id and args.end_id: + entropy_values = entropy_values[args.start_id : args.end_id] average_entropy, filtered_vals = calculate_average(entropy_values, args.drop_ratio) print(f"{len(entropy_values)} entropy values were found") diff --git a/tests/model_executor/test_entropy_utils.py b/tests/model_executor/test_entropy_utils.py index 1135a77f5a..18fd8b1a8b 100644 --- a/tests/model_executor/test_entropy_utils.py +++ b/tests/model_executor/test_entropy_utils.py @@ -28,6 +28,7 @@ class TestCalculateLogitsEntropy(unittest.TestCase): share_inputs = { "seq_lens_this_time": paddle.to_tensor([[1], [0], [15]], dtype="int32"), "seq_lens_encoder": paddle.to_tensor([[0], [0], [15]], dtype="int32"), + "seq_lens_decoder": paddle.to_tensor([[30], [0], [15]], dtype="int32"), "entropy_list": [[], [], []], "stop_flags": paddle.to_tensor([[False], [True], [False]], dtype="bool"), "req_ids": ["req_1", "req_2", "req_3"], @@ -55,6 +56,7 @@ class TestCalculateLogitsEntropy(unittest.TestCase): share_inputs = { "seq_lens_this_time": paddle.to_tensor([[1], [0], [15]], dtype="int32"), "seq_lens_encoder": paddle.to_tensor([[0], [0], [15]], dtype="int32"), + "seq_lens_decoder": paddle.to_tensor([[30], [0], [15]], dtype="int32"), "entropy_list": [[], [], []], "stop_flags": paddle.to_tensor([[False], [True], [False]], dtype="bool"), "req_ids": ["req_1", "req_2", "req_3"], @@ -82,6 +84,7 @@ class TestCalculateLogitsEntropy(unittest.TestCase): share_inputs = { "seq_lens_this_time": paddle.to_tensor([[1], [0], [15]], dtype="int32"), "seq_lens_encoder": paddle.to_tensor([[0], [0], [15]], dtype="int32"), + "seq_lens_decoder": paddle.to_tensor([[30], [0], [15]], dtype="int32"), "entropy_list": [[], [], []], "stop_flags": paddle.to_tensor([[True], [True], [False]], dtype="bool"), "req_ids": ["req_1", "req_2", "req_3"], @@ -111,6 +114,7 @@ class TestSpeculateCalculateLogitsEntropy(unittest.TestCase): share_inputs = { "seq_lens_this_time": paddle.to_tensor([[2], [2], [0], [15]], dtype="int32"), "seq_lens_encoder": paddle.to_tensor([[0], [0], [0], [15]], dtype="int32"), + "seq_lens_decoder": paddle.to_tensor([[30], [30], [0], [15]], dtype="int32"), "entropy_list": [[], [], [], []], "stop_flags": paddle.to_tensor([[False], [False], [True], [False]], dtype="bool"), "req_ids": ["req_1", "req_2", "req_3", "req_4"], @@ -130,8 +134,6 @@ class TestSpeculateCalculateLogitsEntropy(unittest.TestCase): speculate_calculate_logits_entropy(logits, share_inputs, temperature) - print(share_inputs["entropy_list"]) - self.assertEqual(len(share_inputs["entropy_list"][0]), 2) self.assertEqual(len(share_inputs["entropy_list"][1]), 1) self.assertEqual(len(share_inputs["entropy_list"][2]), 0) @@ -145,6 +147,7 @@ class TestSpeculateCalculateLogitsEntropy(unittest.TestCase): share_inputs = { "seq_lens_this_time": paddle.to_tensor([[2], [2], [0], [15]], dtype="int32"), "seq_lens_encoder": paddle.to_tensor([[0], [0], [0], [15]], dtype="int32"), + "seq_lens_decoder": paddle.to_tensor([[30], [30], [0], [15]], dtype="int32"), "entropy_list": [[], [], [], []], "stop_flags": paddle.to_tensor([[False], [False], [True], [False]], dtype="bool"), "req_ids": ["req_1", "req_2", "req_3", "req_4"], @@ -164,8 +167,6 @@ class TestSpeculateCalculateLogitsEntropy(unittest.TestCase): speculate_calculate_logits_entropy(logits, share_inputs, temperature) - print(share_inputs["entropy_list"]) - self.assertEqual(len(share_inputs["entropy_list"][0]), 2) self.assertEqual(len(share_inputs["entropy_list"][1]), 1) self.assertEqual(len(share_inputs["entropy_list"][2]), 0) @@ -179,6 +180,7 @@ class TestSpeculateCalculateLogitsEntropy(unittest.TestCase): share_inputs = { "seq_lens_this_time": paddle.to_tensor([[2], [2], [0], [15]], dtype="int32"), "seq_lens_encoder": paddle.to_tensor([[0], [0], [0], [15]], dtype="int32"), + "seq_lens_decoder": paddle.to_tensor([[30], [30], [0], [15]], dtype="int32"), "entropy_list": [[], [], [], []], "stop_flags": paddle.to_tensor([[True], [False], [True], [False]], dtype="bool"), "req_ids": ["req_1", "req_2", "req_3", "req_4"], @@ -198,8 +200,6 @@ class TestSpeculateCalculateLogitsEntropy(unittest.TestCase): speculate_calculate_logits_entropy(logits, share_inputs, temperature) - print(share_inputs["entropy_list"]) - self.assertEqual(len(share_inputs["entropy_list"][0]), 0) self.assertEqual(len(share_inputs["entropy_list"][1]), 1) self.assertEqual(len(share_inputs["entropy_list"][2]), 0)