mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Unify Spec and non-spec branch (#6685)
* optimize spec-inference architecture * delete debug log * optimize spec_method usage && fix unit_test * add claude unit-test skill * fix some ugly bug * enhance robustness and bounds check * unify method & spec_method to method to avoid bug * activate CI * fix unit test * Unify logprobs computation for naive and speculative decoding, fix CUDA kernel * fix logprob bug && optimize verify kernel * fix exist_decode() judge
This commit is contained in:
@@ -14,13 +14,17 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.ops.gpu import ngram_match
|
||||
|
||||
from .base import Proposer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
|
||||
class NgramProposer(Proposer):
|
||||
"""
|
||||
@@ -29,7 +33,7 @@ class NgramProposer(Proposer):
|
||||
Matching corresponding tokens in input and output as draft tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
def __init__(self, fd_config: "FDConfig"):
|
||||
super().__init__(fd_config)
|
||||
self.max_ngram_size = self.speculative_config.max_ngram_size
|
||||
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
|
||||
|
||||
Reference in New Issue
Block a user