diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 74e3f89c9d..d5058d6b3c 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -740,6 +740,8 @@ class SpeculativeConfig: self.num_extra_cache_layer = 0 + self.enable_draft_logprob: bool = False + for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 0a76c0157f..96edebdda5 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -660,7 +660,11 @@ class SpeculativeSampler(nn.Layer): top_p_logprob = None top_p_token_mask = None - if top_p_normalized_logprobs is not None and share_inputs is not None: + if ( + top_p_normalized_logprobs is not None + and share_inputs is not None + and sampling_metadata.top_p_normalized_logprobs_flag + ): real_token_top_p = ( sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1) ) @@ -837,9 +841,9 @@ class SpeculativeSampler(nn.Layer): logprobs_tensors = None token_ids = share_inputs["accept_tokens"] if num_logprobs is not None: - token_ids = paddle.concat( - [share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)] - ) + idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32") + mask = idx < share_inputs["accept_num"].unsqueeze(1) + token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) sampler_output = SamplerOutput( @@ -950,6 +954,7 @@ class MTPSampler(nn.Layer): else: raise NotImplementedError self.logprobs_mode = fd_config.model_config.logprobs_mode + self.enable_draft_logprob = fd_config.speculative_config.enable_draft_logprob def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" @@ -1002,7 +1007,11 @@ class MTPSampler(nn.Layer): top_p_logprob = None top_p_token_mask = None - if top_p_normalized_logprobs is not None and share_inputs is not None: + if ( + top_p_normalized_logprobs is not None + and share_inputs is not None + and sampling_metadata.top_p_normalized_logprobs_flag + ): real_token_top_p = ( sampling_metadata.top_p[:real_bsz] .squeeze(1) @@ -1079,7 +1088,7 @@ class MTPSampler(nn.Layer): """ """ num_logprobs = sampling_metadata.max_num_logprobs real_bsz = share_inputs["seq_lens_this_time"].shape[0] - if num_logprobs is not None and share_inputs["substep"] == 0: + if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0: real_token_num = share_inputs["batch_token_num"][:real_bsz].sum() if self.logprobs_mode == "raw_logprobs": raw_logprobs = self.compute_logprobs( @@ -1110,7 +1119,7 @@ class MTPSampler(nn.Layer): token_ids = None logprobs_tensors = None - if num_logprobs is not None and share_inputs["substep"] == 0: + if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0: token_ids = paddle.empty(real_token_num, dtype="int64") speculate_insert_first_token( token_ids, diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 1e81b167f7..578fe583c1 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -77,6 +77,7 @@ class TokenProcessor: self.speculative_decoding = self.cfg.speculative_config.method is not None self.use_logprobs = self.cfg.model_config.enable_logprob + self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob if self.speculative_decoding: if self.use_logprobs: @@ -446,7 +447,7 @@ class TokenProcessor: batch_result (list): batch results """ try: - if self.cfg.speculative_config.method and self.use_logprobs: + if self.cfg.speculative_config.method and self.use_logprobs and self.enable_draft_logprob: if mtype == 3: # target finished_batch_result, unfinished_batch_result = [], [] for r in batch_result: diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index c8252ed896..94dd928ef3 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -94,6 +94,7 @@ class MTPProposer(Proposer): self.mtp_strategy = self.speculative_config.mtp_strategy self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps self.enable_logprob = self.model_config.enable_logprob + self.enable_draft_logprob = self.speculative_config.enable_draft_logprob # [mixed, prefill, decoder] self.role = self.scheduler_config.splitwise_role @@ -943,7 +944,7 @@ class MTPProposer(Proposer): # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states) - if self.enable_logprob and substep == 0: + if self.enable_logprob and self.enable_draft_logprob and substep == 0: first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"]) speculate_get_logits( diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 77d28aba04..fe0e05e3b0 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -122,6 +122,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase): cfg.speculative_config.method = "mtp" if speculative_decoding else None cfg.speculative_config.num_speculative_tokens = 1 cfg.model_config.enable_logprob = use_logprobs + cfg.speculative_config.enable_draft_logprob = True processor = TokenProcessor.__new__(TokenProcessor) processor.cfg = cfg @@ -139,6 +140,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase): processor.number_of_output_tokens = 0 processor.prefill_result_status = {} processor.use_logprobs = use_logprobs + processor.enable_draft_logprob = cfg.speculative_config.enable_draft_logprob processor.num_draft_tokens = 0 processor.num_accepted_tokens = 0 processor.num_emitted_tokens = 0 diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index 466c5fc20d..8d620cf7b1 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -53,6 +53,7 @@ class _DummyCfg: self.speculative_config = types.SimpleNamespace( method=speculative_method, num_speculative_tokens=2, + enable_draft_logprob=True, ) self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob) self.scheduler_config = types.SimpleNamespace(name="default", splitwise_role="decode")