diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index e3af67a304..ba3ce6ceb5 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2910,12 +2910,19 @@ class GPUModelRunner(ModelRunnerBase): # Clear CUDAGraph if self.use_cudagraph: self.model.clear_graph_opt_backend() + if ( + self.speculative_decoding + and self.spec_method == SpecMethod.MTP + and self.graph_opt_config.draft_model_use_cudagraph + ): + self.proposer.model.clear_graph_opt_backend() # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) - if self.spec_method == SpecMethod.MTP: - self.proposer.model.clear_graph_opt_backend() + + # NOTE(wangyanpeng): MTP cache must be cleared before clearing the main KV cache + if self.speculative_decoding and self.spec_method == SpecMethod.MTP: self.proposer.clear_mtp_cache() self.clear_cache() paddle.device.cuda.empty_cache()