mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
7707be8384
* [Feature][KVCache] Support cache manager v1 architecture

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update cache manager and related modules

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: update cache_manager and related modules

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add node to evictable set in complete_swap_to_device

  When a node transitions from SWAP_TO_DEVICE to DEVICE via complete_swap_to_device, it was not being added to the _evictable_device set. This caused nodes with ref_count=0 to become "orphaned": not appearing in any evictable set despite having cache_status=DEVICE.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: update cache manager v1 and related modules

  - Add new cache_manager.py with cache management functionality
  - Add radix_tree.py for prefix caching
  - Update block_pool.py and metadata.py
  - Update request.py and resource_manager_v1.py for scheduling
  - Update gpu_model_runner.py for GPU model execution

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(cache): add cache controller v1 implementation

  - Add CacheController class for cache management
  - Update config.py with cache-related configurations
  - Refactor gpu_model_runner.py for improved cache handling

* feat(cache_manager): update cache manager v1

* fix(cache_manager): fix src/dst block_ids selection for the swap_cache H2D/D2H directions and clean up ForwardMeta

  ## Motivation
  Fix the bug in swap_cache_optimized.cu where the src/dst block_ids were swapped in the H2D direction, and remove the deprecated cache_controller field from ForwardMeta.

  ## Modifications
  - fix: in swap_cache_optimized.cu, select src/dst block_ids correctly based on the D2H template parameter, fixing the inverted src/dst bug in the H2D direction (in both SwapCachePerLayerImpl and SwapCacheAllLayersBatchImpl)
  - refactor: in cache_manager/v1/__init__.py, import LayerSwapTimeoutError from cache_utils (its correct source) instead of cache_controller
  - refactor: remove the deprecated cache_controller field from ForwardMeta
  - refactor: remove the corresponding cache_controller assignment from gpu_model_runner.py
  - test: add unit tests in tests/cache_manager/v1/test_swap_cache_ops.py

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(cache_manager): refactor cache manager v1 and optimize swap ops

  ## Motivation
  Refactor and optimize cache manager v1, streamlining the code structure and improving maintainability.

  ## Modifications
  - Refactor transfer_manager.py, substantially simplifying its logic
  - Optimize the swap_cache_optimized.cu GPU operator implementation
  - Adjust cache_manager.py and cache_controller.py logic; fix the missing free_device_blocks method
  - Update block_pool.py, cache_utils.py, metadata.py, radix_tree.py
  - Trim the related calls in gpu_model_runner.py, forward_meta.py, attention.py
  - Update the corresponding unit tests (test_cache_controller, test_swap_cache_ops, test_transfer_manager)
  - Adjust the related configuration items in config.py

* [KVCache][MTP] Support MTP KV Cache initialization and multimodal hashing under cache_manager_v1

  ## Motivation
  On the enable_cache_manager_v1 path, the KV Cache for MTP (speculative decode) should be managed uniformly by CacheController so that its swap/transfer capabilities can be reused; also fix block hashes not carrying multimodal extra_keys in multimodal scenarios.

  ## Modifications
  - `cache_controller.py`
    - Add `initialize_mtp_kv_cache`: initialize the MTP KV Cache through CacheController and register it in cache_kvs_map so that transfer_manager automatically covers the MTP layers
    - `num_layers` in `initialize_host_cache` now includes the extra MTP cache layers, ensuring the host cache also allocates enough space for MTP
    - Rename `_free_gpu_cache` to `free_gpu_cache` (externally callable)
  - `cache_utils.py`
    - Add `get_block_hash_extra_keys`: extract the multimodal hash information within a single block, aligning with PrefixCacheManager's multimodal extra_keys logic
    - `get_request_block_hasher` now passes extra_keys to hash_block_tokens, fixing inaccurate prefix cache hit rates in multimodal scenarios
  - `spec_decode/mtp.py`
    - `update_mtp_block_num` gains a `skip_cache_init` parameter to avoid re-initializing the MTP KV Cache on the v1 cache manager path
  - `gpu_model_runner.py`
    - `initialize_kv_cache` (v1) path: after main-model cache initialization, call `cache_controller.initialize_mtp_kv_cache` to create the MTP cache
    - `clear_cache` / `wakeup` / `reset` paths: respect the `enable_cache_manager_v1` flag and skip the redundant proposer.initialize_kv_cache call

  ## Usage or Command
  ```bash
  # Launch an inference service with MTP + cache_manager_v1 (example)
  bash run.sh
  ```

* fix(cache_manager): multi-GPU fix, mm hash boundary fix, and remove batch ops

  1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations in a cp.cuda.Device(device_id) context to ensure streams/events are bound to the correct device, preventing cross-device errors in multi-GPU setups.
  2. Remove cudaSetDevice from SwapCacheAllLayers (now handled by the cupy context).
  3. Remove the swap_cache_all_layers_batch op: simplify the implementation by removing the batch upload variant; all-layer transfers now use the standard swap_cache_all_layers with the cupy device context.
  4. Fix the mm hash boundary comparison in get_block_hash_extra_keys: change strict less-than (<) to less-than-or-equal (<=) so that multimodal items ending exactly at the block start are correctly excluded.
  5. Extract config fields to KVCacheBase: model_config, cache_config, quant_config, parallel_config are now set in the base class __init__ to avoid duplication in the CacheController and CacheManager subclasses.
  6. Translate metadata.py docstrings from Chinese to English for broader contributor accessibility.
  7. Add test_cache_utils.py: comprehensive unit tests for get_block_hash_extra_keys covering all boundary and overlap scenarios.
  8. Expand the test suite: test_request.py cache fields tests, test_radix_tree.py backup candidate tests, test_transfer_manager.py and test_cache_manager.py multi-GPU and concurrent operation tests.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix List import and move write_policy normalization to CacheManager

  ## Motivation
  Fix two issues:
  1. `List` is not imported in `fastdeploy/engine/request.py`, causing a pre-commit F821 error
  2. The `write_policy` normalization logic (`write_through` → `write_through_selective`) does not belong in `FDConfig`; move it into `CacheManager.__init__` so it only affects Cache Manager V1 internals

  ## Modifications
  - `fastdeploy/engine/request.py`: add `List` to the `typing` import and remove the duplicate `CacheSwapMetadata` TYPE_CHECKING import, fixing F821/F811
  - `fastdeploy/config.py`: remove the `write_policy` normalization logic
  - `fastdeploy/cache_manager/v1/cache_manager.py`: move the normalization into `CacheManager.__init__`

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix pre-commit code style issues

  ## Motivation
  Fix CI pre-commit code style check failures.

  ## Modifications
  - `fastdeploy/engine/common_engine.py`: black formatting
  - `fastdeploy/worker/worker_process.py`: black formatting + isort fix
  - `fastdeploy/cache_manager/v1/storage/__init__.py`: isort fix
  - `fastdeploy/worker/gpu_worker.py`: isort fix

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] update cache_manager_v1 modules

  ## Motivation
  Update the Cache Manager V1 modules: complete the copyright headers and improve module structure and maintainability.

  ## Modifications
  - `fastdeploy/cache_manager/v1/` modules: add copyright headers, improve code structure
  - `fastdeploy/config.py`: configuration updates
  - `fastdeploy/engine/sched/resource_manager_v1.py`: scheduling-related updates

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing

  ## Motivation
  Consolidate the duplicated task-parsing logic in worker_process into BatchRequest, reducing redundancy and improving maintainability.

  ## Modifications
  - `fastdeploy/engine/request.py`: add the `BatchRequest.from_tasks()` classmethod to uniformly classify task_queue tasks into inference requests and control requests
  - `fastdeploy/worker/worker_process.py`: replace the inline parsing logic with `BatchRequest.from_tasks()` and fix a duplicated control_reqs handling block

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add NUMA affinity for host cache and skip swap cache tests

  ## Motivation
  Improve the NUMA affinity of host cache memory allocation to reduce cross-NUMA access latency; also skip the swap cache ops tests (unsupported in the current environment).

  ## Modifications
  - `fastdeploy/cache_manager/v1/cache_controller.py`:
    - Add `_get_numa_node_for_gpu()`: obtain the NUMA node of a GPU via nvidia-smi or sysfs
    - Add `_bind_to_closest_numa_node()`: bind the current thread to the NUMA node closest to the GPU
    - Call the NUMA binding in `initialize_host_cache()` to improve H2D transfer performance
  - `tests/cache_manager/v1/test_swap_cache_ops.py`: skip all test classes (`TestSwapCacheAllLayersCorrectness`, `TestSwapCacheAllLayersPerformance`, `TestSwapCacheRandomBlockIndices`)

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix unittest failures for cache_manager_v1

  Three unit tests fail due to interface changes or mocking issues and need fixing:

  - tests/distributed/chunked_moe.py: `setup_model_runner` uses `__new__` to skip `__init__`; add `enable_cache_manager_v1 = False` to fix the `AttributeError`
  - tests/engine/test_resource_manager.py: `PrefixCacheManager` is imported locally, so the `patch` path is changed to its definition site `fastdeploy.cache_manager.prefix_cache_manager.PrefixCacheManager`
  - tests/v1/test_resource_manager_v1.py: the fourth parameter of `_trigger_preempt` changed from `list` to `BatchRequest`; update the test arguments and assertions

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] remove debug logging code

  ## Modifications
  - fastdeploy/engine/request.py: remove the debug logger and the debug logging in prompt_hashes
  - fastdeploy/worker/worker_process.py: remove the debug import and print statements from __main__

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix cupy device id caching and pickle for _match_result

  ## Motivation
  Fix two bugs:
  1. Calling `cp.cuda.runtime.getDevice()` on every call in `transfer_manager.py` is fragile; cache it as an instance variable at initialization so that subsequent operations use a consistent device ID.
  2. `__getstate__` in `request.py` does not skip `_match_result`, whose BlockNode tree contains parent/child circular references that trigger a `RecursionError` during pickling; also add `__setstate__` so the fields are restored to safe defaults after unpickling.

  ## Modifications
  - `transfer_manager.py`: call `cp.cuda.runtime.getDevice()` at initialization and cache it in `self._cupy_device_id`; subsequent `with cp.cuda.Device(...)` blocks and log messages use the cached value.
  - `request.py`:
    - Add `_match_result` to the `_SKIP_KEYS` skip set in `__getstate__`, avoiding pickle failures from circular references.
    - Add `__setstate__` to restore `_block_hasher` and `_match_result` to `None` after unpickling.

* fix(test): fix unit test errors for _trigger_preempt and wakeup with MTP

  - Use BatchRequest instead of list in test_trigger_preempt_records_tasks
  - Add the missing enable_cache_manager_v1 attr in TestSleepWakeupBehavior._make_runner

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning wrong block IDs

  ## Motivation
  The compatibility property `gpu_free_block_list` mistakenly used `list(range(N))`, passing the return value of `available_blocks()` to `range()` as an integer. It therefore returned the bogus list `[0, 1, ..., N-1]` instead of the real free block IDs.

  ## Modifications
  - `cache_manager/v1/cache_manager.py`: change `list(range(self._device_pool.available_blocks()))` to `list(self._device_pool.available_blocks())`

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix TypeError caused by gpu_free_block_list returning an int

  ## Motivation
  The gpu_free_block_list property calls BlockPool.available_blocks(), which returns an int (the number of free blocks); wrapping an int with list() raises TypeError: 'int' object is not iterable.

  ## Modifications
  Change list(self._device_pool.available_blocks()) to list(self._device_pool._free_blocks), returning the list of free block indices directly.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][CacheManager] adapt pause/sleep/free_cache operations to the V1 CacheManager

  ## Motivation
  The V1 CacheManager introduces a new reset_cache() interface; the pause and sleep operations need to adapt, and free_cache needs an optional clear_storage parameter.

  ## Modifications
  - cache_controller.py: free_cache gains a clear_storage parameter (default False); _clear_storage() is called only when clear_storage=True, avoiding unnecessary storage clearing
  - common_engine.py: in the pause and sleep operations, use cache_manager.reset_cache() instead of the old reset() and pause_transfer logic when ENABLE_V1_KVCACHE_MANAGER is set
  - gpu_model_runner.py: on sleep, only clear the MTP cache when not using the V1 cache manager

  ## Usage or Command
  ```bash
  # Launch the service (V1 CacheManager)
  python -m fastdeploy.entrypoints.openai.api_server \
      --enable-v1-kvcache-manager \
      ...
  ```

* [BugFix][KVCache] fix missing enable_cache_manager_v1 in test mocks and remove unused select_blocks_for_backup

  - Remove the unused `select_blocks_for_backup` method from radix_tree.py
  - Fix the `match_prefix` default param `skip_storage=True` and log order in cache_manager.py
  - Sync test_gpu_model_runner.py with upstream/develop (add TestInsertTasksV1SplitwiseSuffix)
  - Add `enable_cache_manager_v1=False` to all mock runners to fix AttributeError in CI

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] simplify _free_blocks in ResourceManagerV1 for non-v1 path

  Remove the redundant prefix_caching branch in the else path; always call recycle_gpu_blocks with the full block_tables for the non-cache-manager-v1 case.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][Optimization][BugFix] fix and optimize block_pool, cache_manager, transfer_manager, request

  ## Motivation
  Fix several code quality issues in cache_manager v1, improving performance and eliminating a latent type-inconsistency bug.

  ## Modifications
  1. **block_pool.py**: `BlockPool.allocate` replaces the per-element pop loop with a slice plus a batched set.update, eliminating Python loop overhead: O(n) → O(k) via C-level batch operations
  2. **cache_manager.py**: `match_prefix` writes an empty `MatchResult()` before returning early when prefix caching is disabled, preventing callers from crashing when dereferencing `_match_result=None`
  3. **transfer_manager.py**: `_build_device_layer_indices` also resets the four layer-index lists when `_cache_kvs_map` is empty, preventing stale tensors from being used by the swap operators
  4. **request.py**: `BatchRequest.append_swap_metadata` / `append_evict_metadata` construct `CacheSwapMetadata` with `src_type`/`dst_type` as `CacheLevel` enums instead of strings, matching the declared field types; add the `CacheLevel` import; correct the `match_result` property return annotation to `Optional[MatchResult]`
  5. **resource_manager_v1.py**: downgrade the `_allocate_gpu_blocks` log from `INFO` to `DEBUG`, eliminating log noise on the hot scheduling path
  6. **tests/engine/test_request.py**: update the `src_type`/`dst_type` assertions to `CacheLevel` enum values and add the `CacheLevel` import

  ## Usage or Command
  Unit tests:
  ```bash
  source .venv/py310/bin/activate
  cd baidu/FastDeploy
  python -m pytest tests/cache_manager/v1/test_cache_manager.py -v
  python -m pytest tests/cache_manager/v1/test_transfer_manager.py -v
  python -m pytest tests/engine/test_request.py -v
  ```

* [BugFix][KVCache] Fix BlockPool.allocate returns all blocks when num_blocks=0

  ## Motivation
  When `allocate(num_blocks=0)` is called, Python's negative-index trap causes a serious bug: since `-0 == 0`, `self._free_blocks[-0:]` is equivalent to `self._free_blocks[0:]`, which returns and clears the entire free block list instead of returning an empty list.

  ## Modifications
  Add an early check for `num_blocks == 0` in `BlockPool.allocate` that returns `[]` directly, avoiding the negative-index trap.

  ## Usage or Command
  ```bash
  # Run the related unit tests to verify the fix
  python -m pytest tests/cache_manager/v1/test_cache_manager.py -vv -s
  ```

* [KVCache][Test] add unit tests for cache_manager v1 modules

  ## Motivation
  Complete the unit test coverage of the cache_manager/v1 modules, ensuring that the core methods have full test protection.

  ## Modifications
  Added or extended the following test files; all 326 cases pass:
  - tests/cache_manager/v1/test_block_pool.py (new): covers BlockPool.get_metadata/set_metadata/resize and DeviceBlockPool/HostBlockPool
  - tests/cache_manager/v1/test_metadata.py (new): covers BlockNode, RadixTreeStats, MatchResult, CacheSwapMetadata, AsyncTaskHandler
  - tests/cache_manager/v1/test_cache_utils.py (extended): adds hash_block_tokens, get_request_block_hasher, LayerDoneCounter time tracking, and internal helper methods
  - tests/cache_manager/v1/test_radix_tree.py (extended): adds the TestCompleteSwapToDevice test class (6 cases)
  - tests/cache_manager/v1/test_cache_manager.py (extended): adds offload_to_host, load_from_host, the pending backup series, prepare_prefetch_metadata
  - tests/cache_manager/v1/test_transfer_manager.py (extended): adds the _swap_single_layer validation path, sync_input/output_stream, record_input_stream_event

  ## Usage or Command
  ```bash
  # Run all new unit tests
  source .venv/py310/bin/activate
  python -m pytest tests/cache_manager/v1/test_block_pool.py \
      tests/cache_manager/v1/test_metadata.py \
      tests/cache_manager/v1/test_cache_utils.py \
      tests/cache_manager/v1/test_radix_tree.py \
      tests/cache_manager/v1/test_cache_manager.py \
      tests/cache_manager/v1/test_transfer_manager.py -v
  # Expected result: 326 passed
  ```

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
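The `num_blocks=0` pitfall behind the `BlockPool.allocate` fix above can be reproduced in a few lines. The `allocate` function below is a simplified stand-in for illustration only, not the actual FastDeploy implementation:

```python
# The trap itself: -0 == 0 in Python, so a "take the last n" slice with
# n == 0 returns the WHOLE list rather than an empty one.
assert [10, 11, 12, 13][-0:] == [10, 11, 12, 13]


def allocate(free_blocks: list, num_blocks: int) -> list:
    """Pop `num_blocks` block IDs off the free list (simplified sketch)."""
    if num_blocks == 0:
        # Guard added by the fix: without it, free_blocks[-0:] below would
        # return (and the del would clear) the entire free list.
        return []
    allocated = free_blocks[-num_blocks:]
    del free_blocks[-num_blocks:]
    return allocated


free_blocks = [10, 11, 12, 13]
print(allocate(free_blocks, 0))  # [] -- free list untouched
print(allocate(free_blocks, 2))  # [12, 13]
print(free_blocks)               # [10, 11]
```

The same guard applies to any "slice the tail" allocation pattern: a count of zero must be special-cased before it is negated.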
1327 lines
63 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import os
import time
from typing import TYPE_CHECKING, List

import numpy as np
import paddle
from paddleformers.utils.log import logger

from fastdeploy import envs
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models import ModelForCasualLM
from fastdeploy.platforms import current_platform

if current_platform.is_xpu():
    from fastdeploy.model_executor.ops.xpu import (
        draft_model_postprocess,
        draft_model_preprocess,
        draft_model_update,
        eagle_get_hidden_states,
        eagle_get_self_hidden_states,
        mtp_save_first_token,
        mtp_step_paddle,
        set_data_ipc,
        share_external_data,
        update_attn_mask_offsets,
    )

    # temporary solution
    from fastdeploy.model_executor.xpu_pre_and_post_process import (
        async_set_value,
        xpu_pre_process,
        xpu_process_output,
    )
else:
    from fastdeploy.model_executor.ops.gpu import (
        draft_model_postprocess,
        draft_model_preprocess,
        draft_model_update,
        eagle_get_hidden_states,
        eagle_get_self_hidden_states,
        eagle_gather_hidden_states,
        hybrid_mtp_ngram,
        mtp_step_paddle,
        share_external_data,
        speculate_get_logits,
        speculate_save_output_topk,
        update_attn_mask_offsets,
        set_data_ipc,
        unset_data_ipc,
    )
    from fastdeploy.model_executor.pre_and_post_process import async_set_value, pre_process

from fastdeploy.worker.input_batch import (
    ProposerInputBatch,
    recover_batch_index_for_output,
    recover_batch_index_for_sampler_output,
    reorder_split_prefill_and_decode_form_index_to_batch_id,
)

from .base import Proposer

if TYPE_CHECKING:
    from fastdeploy.config import FDConfig

class MTPProposer(Proposer):
    """
    Proposer for Multi-Token-Prediction (MTP)
    """

    def __init__(
        self,
        fd_config: "FDConfig",
        main_model: ModelForCasualLM,
        local_rank: int,
        device_id: int,  # physical device id
        target_model_inputs,  # main model share inputs
    ):
        super().__init__(fd_config)
        self.num_main_model_layers = self.model_config.num_hidden_layers
        self.local_rank = local_rank
        self.device_id = device_id
        self.use_attn_mask_offset = self.enable_mm

        self._update_mtp_config(main_model)
        self._load_model()
        self.target_model_inputs = target_model_inputs
        self.mtp_strategy = self.speculative_config.mtp_strategy
        self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
        self.enable_logprob = self.model_config.enable_logprob
        self.enable_draft_logprob = self.speculative_config.enable_draft_logprob
        self.cache_kvs_map = {}

        # [mixed, prefill, decoder]
        self.role = self.scheduler_config.splitwise_role
        self.pd_disaggregation_mode = fd_config.parallel_config.pd_disaggregation_mode

        if current_platform.is_xpu():
            self._prepare_inputs = self._prepare_inputs_xpu
            self._propose = self._propose_xpu
        elif current_platform.is_cuda() or current_platform.is_maca():
            self._prepare_inputs = self._prepare_inputs_cuda
            self._propose = self._propose_cuda
        else:
            raise RuntimeError(
                f"Unsupported platform for MTP: {current_platform}. " f"Supported platforms: CUDA, MACA, XPU"
            )

        self.sampler = MTPSampler(fd_config)
        self.model_inputs = ProposerInputBatch(self.fd_config, self.target_model_inputs)
        self.model_inputs.init_share_inputs()

        # CUDA Graph
        self.draft_model_use_cudagraph = self.graph_opt_config.draft_model_use_cudagraph
        self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
        self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes

        self.attn_backends: list[AttentionBackend] = []
        self._initialize_attn_backend()

        # Forward meta stores the global meta information of the forward pass
        self.forward_meta = None
        self.exist_prefill_flag = False

    def _update_mtp_config(self, main_model):
        """
        Update config for MTP from global config
        """
        self.forward_meta: ForwardMeta = None
        self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
        self.speculative_config.sharing_model = main_model
        # TODO (wangyanpeng): The number of MTP layers should be read from model config
        self.model_config.num_hidden_layers = 1
        self.model_config.model = self.speculative_config.model
        if "Ernie" in self.model_config.architectures[0]:
            self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
            self.model_config.prefix_layer_name = "mtp_block"
        if self.speculative_config.quantization != "":
            self.model_config.quantization = self.speculative_config.quantization
        self.model_config.start_layer_index = self.num_main_model_layers
        self.speculative_config.model_type = "mtp"
        if not self.use_attn_mask_offset:
            self.model_config.causal = True

    def _load_model(self):
        """
        Load MTP Layer
        """
        model_loader = get_model_loader(load_config=self.fd_config.load_config)
        self.model = model_loader.load_model(fd_config=self.fd_config)

    def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
        """Set dummy prefill inputs to model_inputs"""
        max_dec_len = expected_decode_len + 1

        input_length = min(
            num_tokens // batch_size,
            self.model_config.max_model_len - max_dec_len,
        )

        # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
        if self.fd_config.parallel_config.enable_expert_parallel:
            input_length = min(input_length, 32)

        block_num = (
            input_length + self.cache_config.block_size - 1
        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

        for i in range(batch_size):
            idx = i
            self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
            self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
            self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = input_length
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
            self.model_inputs["step_idx"][idx : idx + 1] = 0
            self.model_inputs["max_dec_len"][idx : idx + 1] = max_dec_len
            self.model_inputs["stop_flags"][idx : idx + 1] = False

            self.model_inputs["encoder_block_lens"][idx : idx + 1] = block_num
            self.model_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                idx * block_num, (idx + 1) * block_num, 1
            )
        self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]

    def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
        """
        Initialize kv cache
        """
        self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio)
        self.cache_kvs = {}

        # Get kv cache dtype
        cache_type = self.model_config.dtype
        kv_cache_quant_type = None
        if (
            self.quant_config
            and hasattr(self.quant_config, "kv_cache_quant_type")
            and self.quant_config.kv_cache_quant_type is not None
        ):
            cache_type = self._get_cache_type()
            kv_cache_quant_type = self.quant_config.kv_cache_quant_type

        # Get kv cache shape
        key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape(
            max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
        )
        if kv_cache_quant_type == "block_wise_fp8":
            kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]
        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

        cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
        cache_ready_signal = IPCSignal(
            name="cache_ready_signal",
            array=cache_ready_signal_data,
            dtype=np.int32,
            suffix=self.parallel_config.local_engine_worker_queue_port,
            create=False,
        )

        # Check whether the gpu runner needs to create the kv cache:
        # 1. During profiling, it creates its own kv cache.
        # 2. Without profiling, create the kv cache only if no cache managers exist.
        create_cache_tensor = profile or not (
            self.fd_config.cache_config.num_cpu_blocks > 0
            or self.fd_config.cache_config.kvcache_storage_backend
            or self.fd_config.scheduler_config.splitwise_role != "mixed"
        )

        if not create_cache_tensor:
            logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
            while cache_ready_signal.value[local_rank] != 1:
                time.sleep(1)
            logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")

        logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")

        if not create_cache_tensor:
            cache_kvs_list = []
            for i in range(
                self.num_main_model_layers,
                self.num_main_model_layers + self.model_config.num_hidden_layers,
            ):
                logger.info(
                    f"..attaching kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}"
                )
                key_cache = paddle.empty(shape=[], dtype=cache_type)
                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                key_cache = self._share_external_data(key_cache, key_cache_name, key_cache_shape)
                self.cache_kvs_map[key_cache_name] = key_cache
                cache_kvs_list.append(key_cache)
                value_cache = paddle.empty(shape=[], dtype=cache_type)
                value_cache = self._share_external_data(value_cache, val_cache_name, value_cache_shape)
                self.cache_kvs_map[val_cache_name] = value_cache
                cache_kvs_list.append(value_cache)

                if kv_cache_quant_type == "block_wise_fp8":
                    scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
                    scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
                    key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
                    key_scale_cache = self._share_external_data(
                        key_scale_cache, scale_key_cache_name, kv_cache_scale_shape
                    )
                    self.cache_kvs_map[scale_key_cache_name] = key_scale_cache
                    cache_kvs_list.append(key_scale_cache)
                    value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
                    value_scale_cache = self._share_external_data(
                        value_scale_cache, scale_val_cache_name, kv_cache_scale_shape
                    )
                    self.cache_kvs_map[scale_val_cache_name] = value_scale_cache
                    cache_kvs_list.append(value_scale_cache)

            self.model_inputs["caches"] = cache_kvs_list
        else:
            cache_kvs_list = []
            for i in range(
                self.num_main_model_layers,
                self.num_main_model_layers + self.model_config.num_hidden_layers,
            ):
                logger.info(f"..creating kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
                key_cache = paddle.full(
                    shape=key_cache_shape,
                    fill_value=0,
                    dtype=cache_type,
                )
                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
                set_data_ipc(key_cache, key_cache_name)
                self.cache_kvs_map[key_cache_name] = key_cache
                cache_kvs_list.append(key_cache)

                val_cache = paddle.full(
                    shape=value_cache_shape,
                    fill_value=0,
                    dtype=cache_type,
                )
                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                set_data_ipc(val_cache, val_cache_name)
                self.cache_kvs_map[val_cache_name] = val_cache
                cache_kvs_list.append(val_cache)

                if kv_cache_quant_type == "block_wise_fp8":
                    key_cache_scales = paddle.full(
                        shape=kv_cache_scale_shape,
                        fill_value=0,
                        dtype=paddle.get_default_dtype(),
                    )
                    key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
                    set_data_ipc(key_cache_scales, key_cache_scales_name)
                    self.cache_kvs_map[key_cache_scales_name] = key_cache_scales
                    cache_kvs_list.append(key_cache_scales)

                    val_cache_scales = paddle.full(
                        shape=kv_cache_scale_shape,
                        fill_value=0,
                        dtype=paddle.get_default_dtype(),
                    )
                    val_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
                    set_data_ipc(val_cache_scales, val_cache_scales_name)
                    self.cache_kvs_map[val_cache_scales_name] = val_cache_scales
                    cache_kvs_list.append(val_cache_scales)

            self.model_inputs["caches"] = cache_kvs_list

        self._empty_cache()

    def _initialize_attn_backend(
        self,
    ) -> None:
        """
        Initialize attention backends and forward metadata
        """
        assert len(self.attn_backends) == 0

        num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
        self.model_config.kv_num_heads = max(
            1,
            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
        )
        head_dim = self.model_config.head_dim

        # Initialize AttentionBackend buffers
        encoder_block_shape_q = 64
        decoder_block_shape_q = 16

        self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"])
        self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
            self.target_model_inputs["decoder_tile_ids_per_batch"]
        )
        if current_platform.is_xpu() or current_platform.is_maca():
            self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
                self.target_model_inputs["decoder_num_blocks_cpu"]
            ).cpu()
        else:
            self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
                self.target_model_inputs["decoder_num_blocks_cpu"]
            ).pin_memory()
        self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like(
            self.target_model_inputs["decoder_num_blocks_device"]
        )
        self.model_inputs["decoder_chunk_size_device"] = paddle.zeros_like(
            self.target_model_inputs["decoder_chunk_size_device"]
        )
        self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
            self.target_model_inputs["max_len_tensor_cpu"]
        ).cpu()

        self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"])
        self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like(
            self.target_model_inputs["encoder_tile_ids_per_batch"]
        )
        self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like(
            self.target_model_inputs["encoder_num_blocks_x_cpu"]
        ).cpu()
        self.model_inputs["kv_batch_ids"] = paddle.zeros_like(self.target_model_inputs["kv_batch_ids"])
        self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like(
            self.target_model_inputs["kv_tile_ids_per_batch"]
        )
        self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like(
            self.target_model_inputs["kv_num_blocks_x_cpu"]
        ).cpu()

        # Get the attention backend
        attn_cls = get_attention_backend()
        attn_backend = attn_cls(
            self.fd_config,
            kv_num_heads=self.model_config.kv_num_heads,
            num_heads=num_heads,
            head_dim=head_dim,
            encoder_block_shape_q=encoder_block_shape_q,
            decoder_block_shape_q=decoder_block_shape_q,
        )
        if attn_backend is None:
            raise NotImplementedError(
                "The attention backend you specified is not supported; please set FD_ATTENTION_BACKEND correctly."
            )
        self.attn_backends.append(attn_backend)

    def clear_mtp_cache(self, profile=False):
        """
        Clear allocated cacheKV
        """
        create_cache_tensor = profile or not (
            self.fd_config.cache_config.num_cpu_blocks > 0
            or self.fd_config.cache_config.kvcache_storage_backend
            or self.fd_config.scheduler_config.splitwise_role != "mixed"
        )
        if not create_cache_tensor:
            for name, tensor in self.cache_kvs_map.items():
                unset_data_ipc(tensor, name, True, False)
        self.cache_kvs_map.clear()
        del self.model_inputs["caches"]
        if self.forward_meta is not None:
            del self.forward_meta.caches

    def update_mtp_block_num(self, num_gpu_blocks, skip_cache_init: bool = False) -> None:
        """
        Update MTP block num by theoretical calculation.

        Args:
            num_gpu_blocks: Main model GPU block count.
            skip_cache_init: When True, skip the internal initialize_kv_cache call.
                Set this when the caller (e.g. gpu_model_runner with enable_cache_manager_v1)
                has already re-created the MTP cache via cache_controller.
        """
        # Reset block table and kv cache with global block num
        self.main_model_num_gpu_blocks = num_gpu_blocks
        if not skip_cache_init:
            self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks)

        # Reset free list
        free_list = list(
            range(
                self.num_gpu_blocks - 1,
                int(self.main_model_num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
        self.free_list_len = len(free_list)
        self.model_inputs.update(
            {
                "free_list": paddle.to_tensor(free_list, dtype="int32"),
                "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
            }
        )

    def insert_tasks_v1(
        self, req_dicts: List[Request], num_running_requests: int, target_model_index_to_batch_id: dict = {}
    ):

        if "caches" not in self.model_inputs:
            self.initialize_kv_cache()
        req_len = len(req_dicts)
        self.model_inputs["num_running_requests"] = num_running_requests
        self.model_inputs["running_requests_ids"] = range(num_running_requests)
        if target_model_index_to_batch_id:
            self.model_inputs.index_to_batch_id = dict(target_model_index_to_batch_id)
        for i in range(req_len):
            request = req_dicts[i]
            logger.debug(f"{i}th request-{request.request_id}: {request}")
            idx = self.model_inputs.get_index_by_batch_id(request.idx)
            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                prefill_start_index = request.prefill_start_index
                prefill_end_index = request.prefill_end_index
                length = prefill_end_index - prefill_start_index

                input_ids = request.prompt_token_ids + request.output_token_ids

                self.model_inputs["input_ids_len"][idx] = length - 1
                async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1)
                self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
                    idx : idx + 1, 1:length
                ]
                # TODO: use token_all_ids to replace input_ids_cpu
                if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs:
                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                        "input_ids"
                    ][idx : idx + 1, 1:length].cpu()
                encoder_block_num = len(request.block_tables)
                async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
                async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
                async_set_value(
                    self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
                )

                async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], False)
                async_set_value(self.model_inputs["batch_drop"][idx : idx + 1], False)

                async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], length)
                self.exist_prefill_flag = True
                async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index)
                async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length)
                async_set_value(
                    self.model_inputs["step_idx"][idx : idx + 1],
                    len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0,
                )
                if self.use_attn_mask_offset:
                    inputs = request.multimodal_inputs
                    self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = (
                        paddle.to_tensor(
                            inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], dtype="int32"
                        )
                    )
                    # NOTE: the XPU backend needs the decoder attention mask offset; the GPU backend does not use it
                    if current_platform.is_xpu():
                        self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = (
                            inputs["attention_mask_offset"][prefill_end_index - 1] + 1
                        )
                if (
                    self.fd_config.scheduler_config.splitwise_role == "decode"
                ):  # In PD, we continue to decode after P generates the first token
                    async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0)
                    self.exist_prefill_flag = False
                    async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length + 1)
                    # NOTE(liuzichang):
                    # extra 1: the P-D split needs to roll back one step

                    async_set_value(self.model_inputs["recompute_token_num"][idx : idx + 1], 0)
                    async_set_value(self.model_inputs["mask_rollback"][idx : idx + 1], 1)
                # has_prefill_task = True
            elif request.task_type.value == RequestType.DECODE.value:  # decode task
                encoder_block_num = len(request.block_tables)
                async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
                async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
                if current_platform.is_cuda():
                    async_set_value(
                        self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
                    )
                else:
                    self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                        request.block_tables, dtype="int32"
                    )
            else:
                async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
                async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], True)
                async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0)
                async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], 0)
                async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0)
                async_set_value(self.model_inputs["is_block_step"][idx : idx + 1], False)
                continue

        # TODO(liuzichang): Solve the splitwise-P bug to restore:
        # self.model_inputs["seq_lens_this_time"] = self.model_inputs["seq_lens_this_time_buffer"][:num_running_requests]
        self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]

    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
        """
        Process inputs for prefill tasks and insert them into the model_inputs buffer.
        """
        # TODO: Init role in the initialization process
        if req_dicts[-1].disaggregate_info is not None:
            if req_dicts[-1].disaggregate_info["role"] == "prefill":
                self.role = "prefill"
                os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
            elif req_dicts[-1].disaggregate_info["role"] == "decode":
                self.role = "decode"
        else:
            self.role = "mixed"

        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            length = len(request.prompt_token_ids)
            self.model_inputs.input_ids_len[idx] = length - 1

            if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
                length = len(request.prompt_token_ids)
                if length > 1:
                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                        "input_ids"
                    ][idx : idx + 1, 1:length]
                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
                        request.prompt_token_ids
                    )[1:]
                self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
                prefill_token_num = self.max_draft_token_num + 1
                self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
                    request.draft_token_ids[1:2], dtype="int64"
                )

                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = length
                self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = prefill_token_num

                self.model_inputs["stop_flags"][idx : idx + 1] = False
                self.model_inputs["batch_drop"][idx : idx + 1] = False
                self.model_inputs["step_idx"][idx : idx + 1] = 1
                encoder_block_num = len(request.block_tables)

                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )

            else:
                length = len(request.prompt_token_ids)

                if length > 1:
                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                        "input_ids"
                    ][idx : idx + 1, 1:length]
                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
                        request.prompt_token_ids
                    )[1:]
                self.model_inputs["pre_ids"][idx : idx + 1] = -1
                self.model_inputs["step_idx"][idx : idx + 1] = 0
                if self.cache_config.enable_chunked_prefill:
                    token_chunk_size = request.prefill_chunk_info[0]
                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                    self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = token_chunk_size
                else:
                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
                    self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length

                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                self.model_inputs["stop_flags"][idx : idx + 1] = False
                self.model_inputs["batch_drop"][idx : idx + 1] = False

                encoder_block_num = len(request.get("block_tables"))
                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.get("block_tables"), dtype="int32"
                )
        self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]

    def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
        """
        Initialize forward meta and attention meta data
        """
        # Initialize forward meta
        self.forward_meta = ForwardMeta(
            ids_remove_padding=self.model_inputs["ids_remove_padding"],
            rotary_embs=self.model_inputs["rope_emb"],
            attn_backend=self.attn_backends[0],
            decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
            decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
            decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
            decoder_num_blocks_device=self.model_inputs["decoder_num_blocks_device"],
            decoder_chunk_size_device=self.model_inputs["decoder_chunk_size_device"],
            max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
            seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
            seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
            seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
            batch_id_per_token=self.model_inputs["batch_id_per_token"],
            cu_seqlens_q=self.model_inputs["cu_seqlens_q"],
            cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
            block_tables=self.model_inputs["block_tables"],
            caches=self.model_inputs["caches"],
            encoder_batch_ids=self.model_inputs["encoder_batch_ids"],
            encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"],
            encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"],
            kv_batch_ids=self.model_inputs["kv_batch_ids"],
            kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"],
            kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"],
            attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None,
        )

        # Initialize attention meta data
        for attn_backend in self.attn_backends:
            attn_backend.init_attention_metadata(self.forward_meta)

        # Notes(liuzichang):
        # 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
        # 2. In multi-step execution, only the first step should be captured.
        self.forward_meta.step_use_cudagraph = (
            step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
        )

    def _initialize_forward_meta_xpu(self):

        self.forward_meta.decoder_batch_ids = self.model_inputs["decoder_batch_ids"]
        self.forward_meta.decoder_tile_ids_per_batch = self.model_inputs["decoder_tile_ids_per_batch"]
        self.forward_meta.decoder_num_blocks_cpu = self.model_inputs["decoder_num_blocks_cpu"]
        self.forward_meta.decoder_num_blocks_device = self.model_inputs["decoder_num_blocks_device"]
        self.forward_meta.decoder_chunk_size_device = self.model_inputs["decoder_chunk_size_device"]
        self.forward_meta.max_len_tensor_cpu = self.model_inputs["max_len_tensor_cpu"]

        self.forward_meta.encoder_batch_ids = self.model_inputs["encoder_batch_ids"]
        self.forward_meta.encoder_tile_ids_per_batch = self.model_inputs["encoder_tile_ids_per_batch"]
        self.forward_meta.encoder_num_blocks_x_cpu = self.model_inputs["encoder_num_blocks_x_cpu"]
        self.forward_meta.kv_batch_ids = self.model_inputs["kv_batch_ids"]
        self.forward_meta.kv_tile_ids_per_batch = self.model_inputs["kv_tile_ids_per_batch"]
        self.forward_meta.kv_num_blocks_x_cpu = self.model_inputs["kv_num_blocks_x_cpu"]
        self.forward_meta.attn_backend = self.attn_backends[0]
        if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query":
            self.forward_meta.kv_signal_sender = self.target_model_inputs["kv_signal_sender"]

        self.forward_meta.is_draft = True

        # Initialize attention meta data
        for attn_backend in self.attn_backends:
            attn_backend.init_attention_metadata(self.forward_meta)

    def exist_prefill(self):
        """
        Check whether a prefill stage exists.
        """
        return self.exist_prefill_flag

    def _prepare_inputs_cuda(self, full_hidden_states):
        """
        Prepare MTP inputs

        MTP state (seq_lens_decoder, step_idx) is "shadow state":
        - Initialized from target model state each round
        - Used for MTP forward, but not committed until verify
        - No rollback needed since it's always re-initialized
        """

        draft_model_preprocess(
            self.model_inputs["draft_tokens"],
            self.model_inputs["input_ids"],
            self.model_inputs["stop_flags"],
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["step_idx"],
            self.model_inputs["not_need_stop_device"],
            self.model_inputs["pre_ids"],
            self.target_model_inputs["accept_tokens"],
            self.target_model_inputs["accept_num"],
            self.target_model_inputs["seq_lens_encoder"],
            self.target_model_inputs["seq_lens_decoder"],
            self.target_model_inputs["step_idx"],
            self.target_model_inputs["stop_flags"],
            self.model_inputs["max_dec_len"],
            self.target_model_inputs["draft_tokens"],
            self.num_model_steps,
            self.role == "prefill",  # is_splitwise_prefill
        )

        target_hidden_states, _ = eagle_get_hidden_states(
            full_hidden_states,
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["stop_flags"],
            self.target_model_inputs["accept_num"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.num_model_steps,
        )

        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

    def _prepare_inputs_xpu(self, full_hidden_states):
        use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER)
        draft_model_preprocess(
            self.model_inputs["draft_tokens"],
            self.model_inputs["input_ids"],
            self.model_inputs["stop_flags"],
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["step_idx"],
            self.model_inputs["not_need_stop"],
            self.model_inputs["batch_drop"],
            self.model_inputs["is_block_step"],
            self.model_inputs["pre_ids"],
            self.model_inputs["mask_rollback"],
            self.model_inputs["recompute_token_num"],
            self.target_model_inputs["accept_tokens"],
            self.target_model_inputs["accept_num"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.target_model_inputs["seq_lens_decoder"],
            self.target_model_inputs["step_idx"],
            self.target_model_inputs["stop_flags"],
            self.target_model_inputs["is_block_step"],
            self.target_model_inputs["draft_tokens"],
            self.num_model_steps,
            True,
            self.role == "prefill",
            use_v1_cache_scheduler,
        )

        target_hidden_states = eagle_get_hidden_states(
            full_hidden_states,
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["stop_flags"],
            self.target_model_inputs["accept_num"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.num_model_steps,
        )
        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

    def _post_process(self, sampled_token_ids):
        """
        Post-process for generation.
        """
        draft_model_update(
            sampled_token_ids,
            self.model_inputs["draft_tokens"],
            self.model_inputs["pre_ids"],
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["step_idx"],
            # Note(ZKK): the XPU backend should also drop the legacy `output_cum_offsets` name,
            # as done in https://github.com/PaddlePaddle/FastDeploy/pull/6358
            self.model_inputs["cu_seqlens_q_output"],
            self.model_inputs["stop_flags"],
            (
                self.model_inputs["not_need_stop_device"]
                if current_platform.is_cuda()
                else self.model_inputs["not_need_stop"]
            ),
            self.model_inputs["max_dec_len"],
            self.model_inputs["eos_token_id"],
            self.model_inputs["base_model_draft_tokens"],
            self.max_model_len,
            self.model_inputs["substep"],
        )

        if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
            if current_platform.is_xpu():
                # Note(wangyanpeng): mtp_save_first_token for GPU platforms has been moved to model_runner.
                # Only the XPU platform is retained here.
                skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
                recover_model_output_map = recover_batch_index_for_output(
                    self.model_inputs,
                    self.model_inputs.index_to_batch_id,
                    self.model_inputs.enable_pd_reorder,
                    ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"],
                )
                mtp_save_first_token(
                    recover_model_output_map["base_model_draft_tokens"],
                    self.model_inputs["not_need_stop"],
                    recover_model_output_map["seq_lens_decoder"],
                    recover_model_output_map["prompt_lens"],
                    recover_model_output_map["step_idx"],
                    self.local_rank,
                    self.parallel_config.use_ep,
                    skip_save,
                )
            # Ensure the first token is only saved once.
            paddle.assign(
                paddle.where(
                    self.model_inputs["stop_flags"],
                    paddle.zeros_like(self.model_inputs["step_idx"]),
                    self.model_inputs["step_idx"],
                ),
                self.model_inputs["step_idx"],
            )

    def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
        """
        Main process for MTP inference.

        Args:
            step_use_cudagraph: bool
                Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
        """
        is_blocking = (
            (not self.fd_config.scheduler_config.enable_overlap_schedule)
            or is_dummy_run
            or self.exist_prefill()
            or real_bsz == 0
        )
        for substep in range(self.num_model_steps):
            if is_blocking:
                token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
            else:
                if substep == 0:
                    token_num_cpu = real_bsz * (self.max_draft_token_num + 1)
                else:
                    token_num_cpu = real_bsz
            if token_num_cpu > 0:
                self.model_inputs["substep"] = substep
                # Remove padding
                (
                    ids_remove_padding,
                    batch_id_per_token,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    cu_seqlens_q_output,
                    batch_id_per_token_output,
                    real_output_token_num,
                ) = pre_process(
                    token_num_cpu,
                    self.model_inputs["input_ids"],
                    self.model_inputs["seq_lens_this_time"],
                    True,
                    self.model_inputs["draft_tokens"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["seq_lens_decoder"],
                )

                if self.use_attn_mask_offset:
                    attn_mask_offsets = update_attn_mask_offsets(
                        ids_remove_padding,
                        getattr(
                            self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"]
                        ),
                        self.model_inputs["seq_lens_encoder"],
                        self.model_inputs["seq_lens_decoder"],
                        cu_seqlens_q,
                        self.model_inputs["attn_mask_offsets_full"],
                        self.model_inputs["is_block_step"],
                        self.model_inputs["decode_states"],
                    )
                    self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)

                # Initialize forward meta data
                self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
                self.model_inputs["batch_id_per_token"][:] = -1
                self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
                self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)

                # For speculative decoding
                self.model_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
                self.model_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)

                # Initialize forward meta data
                self._initialize_forward_meta(
                    step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
                )
                self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
                self.forward_meta.real_bsz = real_bsz

                # Padding inputs for cuda graph
                self.padding_cudagraph_inputs()

                # Get sampling metadata
                self.sampling_metadata = SamplingMetadata(
                    temperature=self.model_inputs["temperature"],
                    top_p=self.model_inputs["top_p"],
                    top_k=self.model_inputs["top_k"],
                    seed=self.model_inputs["infer_seed"],
                    step_idx=self.model_inputs["step_idx"],
                    token_ids_all=self.model_inputs["token_ids_all"],
                    pre_token_ids=self.model_inputs["pre_ids"],
                    prompt_lens=self.model_inputs["prompt_lens"],
                    fake_prompt_lens=self.model_inputs["fake_prompt_lens"],
                    frequency_penalties=self.model_inputs["frequency_score"],
                    presence_penalties=self.model_inputs["presence_score"],
                    repetition_penalties=self.model_inputs["penalty_score"],
                    min_dec_lens=self.model_inputs["min_dec_len"],
                    bad_words_token_ids=self.model_inputs["bad_tokens"],
                    bad_words_token_len=self.model_inputs["bad_tokens_len"],
                    eos_token_ids=self.model_inputs["eos_token_id"],
                    max_num_logprobs=20 if self.enable_logprob else None,
                    temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
                    top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                    share_inputs=self.model_inputs,
                )

                real_num = self.model_inputs["ids_remove_padding"].shape[0]
                target_hidden_states = self.model_inputs["target_hidden_states"][:real_num]
                model_output = self.model(
                    ids_remove_padding=self.model_inputs["ids_remove_padding"],
                    previous_hidden_states=target_hidden_states,
                    forward_meta=self.forward_meta,
                )
                if self.forward_meta.step_use_cudagraph:
                    model_output = model_output[: self.real_token_num]

                hidden_states, _ = eagle_gather_hidden_states(
                    model_output,
                    self.model_inputs["cu_seqlens_q"],
                    self.model_inputs["seq_lens_this_time"],
                    self.model_inputs["seq_lens_decoder"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["batch_id_per_token_output"],
                    self.model_inputs["cu_seqlens_q_output"],
                    real_output_token_num,
                )

                # 4. Compute logits, Sample
                logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
                if self.enable_logprob and self.enable_draft_logprob and substep == 0:
                    first_token_logits = self.model.compute_logits(
                        self.model_inputs["first_token_hidden_states"], forward_meta=self.forward_meta
                    )

                    speculate_get_logits(
                        self.model_inputs["draft_logits"],
                        self.model_inputs["next_token_num"],
                        self.model_inputs["batch_token_num"],
                        self.model_inputs["cu_next_token_offset"],
                        self.model_inputs["cu_batch_token_offset"],
                        logits,
                        first_token_logits,
                        self.model_inputs["seq_lens_this_time"],
                        self.model_inputs["seq_lens_encoder"],
                    )

                sampled_token_ids, sampler_output = self.sampler(
                    logits,
                    self.sampling_metadata,
                    self.max_model_len,
                    self.model_inputs,
                )

                if (
                    not is_dummy_run
                    and self.parallel_config.tensor_parallel_rank == 0
                    and substep == 0
                    and sampler_output.logprobs_tensors is not None
                ):
                    real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                    recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id)
                    recover_model_output_map = recover_batch_index_for_output(
                        self.model_inputs,
                        self.model_inputs.index_to_batch_id,
                        self.model_inputs.enable_pd_reorder,
                        ["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"],
                    )
                    speculate_save_output_topk(
                        sampler_output.sampled_token_ids,
                        sampler_output.logprobs_tensors.logprob_token_ids,
                        sampler_output.logprobs_tensors.logprobs,
                        sampler_output.logprobs_tensors.selected_token_ranks,
                        recover_model_output_map["batch_token_num"][:real_bsz],
                        recover_model_output_map["cu_batch_token_offset"][:real_bsz],
                        self.model_inputs["not_need_stop"],
                        recover_model_output_map["seq_lens_decoder"],
                        recover_model_output_map["prompt_lens"],
                        4,  # mtype
                        self.local_rank,
                        self.parallel_config.use_ep,
                    )

                if self.parallel_config.tensor_parallel_size > 1:
                    paddle.distributed.broadcast(
                        sampled_token_ids,
                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
                        group=self.parallel_config.tp_group,
                    )

                self._post_process(sampled_token_ids)
                self.model_inputs["target_hidden_states"].copy_(hidden_states, False)
            else:
                if hasattr(self.model, "empty_input_forward") and not is_dummy_run:
                    self.model.empty_input_forward(forward_meta=self.forward_meta)
        self.exist_prefill_flag = False

    def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
        """
        Main process for MTP inference.

        Args:
            step_use_cudagraph: bool
                Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
        """
        for substep in range(self.num_model_steps):
            if self.model_inputs["not_need_stop"]:
                self.model_inputs["substep"] = substep
                # Remove padding
                self.forward_meta = xpu_pre_process(
                    self.model_inputs["input_ids"],
                    self.model_inputs["seq_lens_this_time"],
                    self.model_inputs,
                    True,
                    self.cache_config.block_size,
                    self.model_inputs["draft_tokens"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["seq_lens_decoder"],
                    num_speculative_tokens=self.speculative_config.num_speculative_tokens,
                )

                if self.enable_mm:
                    attn_mask_offsets = update_attn_mask_offsets(
                        self.model_inputs["ids_remove_padding"],
                        getattr(
                            self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"]
                        ),
                        self.model_inputs["seq_lens_encoder"],
                        self.model_inputs["seq_lens_decoder"],
                        self.model_inputs["cu_seqlens_q"],
                        self.model_inputs["attn_mask_offsets_full"],
                        self.model_inputs["attn_mask_offsets_decoder"],
                        self.model_inputs["is_block_step"],
                        self.model_inputs["decode_states"],
                        self.model_inputs["mask_rollback"],
                    )
                    self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)

                self._initialize_forward_meta_xpu()
                # Get sampling metadata
                self.sampling_metadata = SamplingMetadata(
                    temperature=self.model_inputs["temperature"],
                    top_p=self.model_inputs["top_p"],
                    top_k=self.model_inputs["top_k"],
                    seed=self.model_inputs["infer_seed"],
                    step_idx=self.model_inputs["step_idx"],
                    pre_token_ids=self.model_inputs["pre_ids"],
                    frequency_penalties=self.model_inputs["frequency_score"],
                    presence_penalties=self.model_inputs["presence_score"],
                    repetition_penalties=self.model_inputs["penalty_score"],
                    min_dec_lens=self.model_inputs["min_dec_len"],
                    bad_words_token_ids=self.model_inputs["bad_tokens"],
                    eos_token_ids=self.model_inputs["eos_token_id"],
                    max_num_logprobs=20 if self.enable_logprob else None,
                    temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
                    top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                    share_inputs=self.model_inputs,
                )

                if self.num_model_steps > 1:
                    self.model_inputs.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])

                model_output = self.model(
                    ids_remove_padding=self.model_inputs["ids_remove_padding"],
                    previous_hidden_states=self.model_inputs["target_hidden_states"],
                    forward_meta=self.forward_meta,
                )
                hidden_states = xpu_process_output(model_output, self.forward_meta, self.model_inputs)
                # 4. Compute logits, Sample
                logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
                sampled_token_ids, sampler_output = self.sampler(
                    logits,
                    self.sampling_metadata,
                    self.max_model_len,
                    self.model_inputs,
                )

                if substep == 0 and sampler_output.logprobs_tensors is not None:
                    real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                    recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id)
                    recover_model_output_map = recover_batch_index_for_output(
                        self.model_inputs,
                        self.model_inputs.index_to_batch_id,
                        self.model_inputs.enable_pd_reorder,
                        ["batch_token_num", "cu_batch_token_offset"],
                    )
                    speculate_save_output_topk(
                        sampler_output.sampled_token_ids,
                        sampler_output.logprobs_tensors.logprob_token_ids,
                        sampler_output.logprobs_tensors.logprobs,
                        sampler_output.logprobs_tensors.selected_token_ranks,
                        recover_model_output_map["batch_token_num"][:real_bsz],
                        recover_model_output_map["cu_batch_token_offset"][:real_bsz],
                        self.model_inputs["not_need_stop"],
                        4,  # mtype
                        self.local_rank,
                    )

                if self.parallel_config.tensor_parallel_size > 1:
                    paddle.distributed.broadcast(
                        sampled_token_ids,
                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
                        group=self.parallel_config.tp_group,
                    )

                self._post_process(sampled_token_ids)
                if substep != self.num_model_steps - 1:
                    self._get_self_hidden_states_xpu(hidden_states)
            else:
                if hasattr(self.model, "empty_input_forward") and not is_dummy_run:
                    self.model.empty_input_forward(self.forward_meta)

    def _get_self_hidden_states_xpu(self, hidden_states):
        target_hidden_states = eagle_get_self_hidden_states(
            hidden_states,
            self.model_inputs.last_seq_lens_this_time,
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["step_idx"],
        )
        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

    def update_task_chunk_prefill(self, task):
        """
        Update a single task's chunked-prefill info.
        """
        idx = self.model_inputs.get_index_by_batch_id(task.idx)
        start_idx = sum(task.prefill_chunk_info[: task.chunk_idx])

        if task.chunk_idx == len(task.prefill_chunk_info):
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
            self.model_inputs["step_idx"][idx : idx + 1] = 1
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
        else:
            token_chunk_size = task.prefill_chunk_info[task.chunk_idx]

            if task.chunk_idx < len(task.prefill_chunk_info) - 1:
                self.model_inputs["input_ids"][idx, :token_chunk_size] = np.array(
                    task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size + 1]
                )
            # Last prefill chunk
            else:
                self.model_inputs["input_ids"][idx, : token_chunk_size - 1] = np.array(
                    task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size]
                )

            self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
            self.model_inputs["step_idx"][idx : idx + 1] = 0
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)

    def _update_status(self):
        """
        Update the main model's forward info for the next step.
        Allocate/free blocks for MTP.
        """
        draft_model_postprocess(
            self.target_model_inputs["draft_tokens"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.target_model_inputs["stop_flags"],
        )
        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            mtp_step_paddle(
                self.target_model_inputs["stop_flags"],
                self.model_inputs["stop_flags"],
                self.model_inputs["batch_drop"],
                self.model_inputs["seq_lens_this_time"],
                self.model_inputs["seq_lens_encoder"],
                self.model_inputs["seq_lens_decoder"],
                self.model_inputs["block_tables"],
                self.model_inputs["encoder_block_lens"],
                self.model_inputs["used_list_len"],
                self.model_inputs["free_list"],
                self.model_inputs["free_list_len"],
                self.cache_config.block_size,
                self.max_draft_token_num,
            )

|
    def _extend_draft_token_with_ngram_match(self):
        # TODO: replace with gpu tensor
        hybrid_mtp_ngram(
            self.model_inputs["input_ids_cpu"].cuda(),
            self.model_inputs["input_ids_len"].cuda(),
            self.model_inputs["pre_ids"],
            self.model_inputs["step_idx"],
            self.target_model_inputs["actual_draft_token_num"],
            self.target_model_inputs["draft_tokens"],
            self.target_model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["max_dec_len"],
            self.max_ngram_size,
            self.min_ngram_size,
            self.max_draft_token_num,
        )

|
    def _run_impl(
        self,
        full_hidden_states: paddle.Tensor,
        step_use_cudagraph: bool = False,
        is_dummy_run: bool = False,
        real_bsz: int = 0,
    ):
        """Execute the draft model."""
        self._prepare_inputs(full_hidden_states)
        self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, real_bsz=real_bsz)
        self._update_status()
        if self.hybrid_mode:
            self._extend_draft_token_with_ngram_match()

|
    def is_chunk_prefill_enabled(self):
        """Whether chunked prefill is enabled for the draft model."""
        return True

|
    def padding_cudagraph_inputs(self) -> None:
        """
        Clean the buffers used by the CUDA graph when replaying it with a padded batch.
        In FastDeploy, almost all input tensors have a buffer, so it is enough to keep
        those buffers clean when replaying the CUDA graph with the padded batch.
        """
        # In init_attention_metadata, the decode buffer has already been cleared.

        # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
        if self.forward_meta.step_use_cudagraph:
            self.forward_meta.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
            self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
        return

|
    def _empty_cache(self):
        if current_platform.is_cuda():
            paddle.device.cuda.empty_cache()
        elif current_platform.is_xpu():
            paddle.device.xpu.empty_cache()
        else:
            paddle.device.empty_cache()

|
    def _get_cache_type(self):
        cache_type = None
        if current_platform.is_cuda():
            cache_type = "uint8"
        elif current_platform.is_xpu():
            cache_type = "int8"
        else:
            raise NotImplementedError("KV cache dtype is only defined for CUDA and XPU platforms")
        return cache_type

|
    def reorder_inputs(self, target_model_input_batch):
        """
        Reorder inputs to split prefill and decode.
        """
        reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs, target_model_input_batch)

|
    def _share_external_data(self, cache, cache_name, cache_shape):
        if current_platform.is_xpu():
            return share_external_data(cache, cache_name, cache_shape, False)
        else:
            return share_external_data(cache, cache_name, cache_shape)