Files
FastDeploy/fastdeploy/config.py
T
Bingoo 6b891da02b [Optimization] enable trtllm_all_reduce fusion kernel in glm model (#6660)
* enable trtllm_all_reduce fusion kernel in glm model

* fix conflict

* format update

* fix a bug

* modify test

* modify test

* support empty tensor and modify test

* fix test_linear config issues

* modify test name

* add edge test case

* modify format

* fix conflict

* modify default max token num in trtllm_allreduce_fusion

* add max token num branch for trtllm_allreduce_fusion

* fix format

* fix rmsnorm config issue

* modify 2025 to 2026

* using compat guard

* Lazily import flashinfer.comm and fix test config issue

* fix test issues

* add flashinfer cache dir clean machine

* fix some issues
2026-04-16 14:10:19 +08:00

2470 lines
103 KiB
Python

"""
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import json
import os
from dataclasses import field
from enum import Enum
from typing import Any, Dict, Literal, Optional, Union
import paddle
import paddle.distributed as dist
import yaml
from packaging.version import parse as parse_version
from paddleformers.transformers.configuration_utils import PretrainedConfig
from typing_extensions import assert_never
import fastdeploy
from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.transformer_utils.config import get_pooling_config
from fastdeploy.utils import (
ceil_div,
check_unified_ckpt,
get_host_ip,
get_logger,
parse_ports,
)
# Module logger: get_logger is given the name "config" and the file name "config.log".
logger = get_logger("config", "config.log")
# Task option values a user may pass.
TaskOption = Literal["auto", "generate", "embedding", "embed"]
# Runner types after "auto" has been resolved (see ModelConfig._get_runner_type).
RunnerType = Literal["generate", "pooling"]
# Runner values a user may pass; "auto" is resolved to a RunnerType.
RunnerOption = Literal["auto", "generate", "pooling"]
# Convert values a user may pass; "auto" is resolved to a ConvertType.
ConvertOption = Literal["auto", "none", "embed"]
# Resolved convert types.
# NOTE(review): _RUNNER_CONVERTS and _SUFFIX_TO_DEFAULTS below also use "classify" and
# "reward", which are not members of this Literal — confirm whether it should be widened.
ConvertType = Literal["none", "embed"]
# Task names placed in the supported-task lists that ModelConfig builds.
_ResolvedTask = Literal["generate", "encode", "embed"]
# Model implementation backend options
ModelImpl = Literal["auto", "fastdeploy", "paddleformers"]
# For each runner type, the convert types that allow a model without native support
# for that runner to still be used with it; consulted in ModelConfig._post_init.
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
    "generate": [],
    "pooling": ["embed", "reward"],
}
# Sentinel token id; its consumers are outside this file — presumably it marks output
# belonging to a preempted request. TODO confirm.
PREEMPTED_TOKEN_ID = -9
# Some model suffixes are based on auto classes from Transformers:
# https://huggingface.co/docs/transformers/en/model_doc/auto
# Each entry maps an architecture-name suffix to its default (runner_type, convert_type);
# try_match_architecture_defaults matches them with str.endswith.
# NOTE: Items higher on this list take priority over lower ones (the first match wins).
_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
    ("ForCausalLM", ("generate", "none")),
    ("ForConditionalGeneration", ("generate", "none")),
    ("ChatModel", ("generate", "none")),
    ("LMHeadModel", ("generate", "none")),
    ("ForTextEncoding", ("pooling", "embed")),
    ("EmbeddingModel", ("pooling", "embed")),
    ("ForSequenceClassification", ("pooling", "classify")),
    ("ForAudioClassification", ("pooling", "classify")),
    ("ForImageClassification", ("pooling", "classify")),
    ("ForVideoClassification", ("pooling", "classify")),
    ("ClassificationModel", ("pooling", "classify")),
    ("ForRewardModeling", ("pooling", "reward")),
    ("RewardModel", ("pooling", "reward")),
    # Let other `*Model`s take priority
    ("Model", ("pooling", "embed")),
]
def iter_architecture_defaults():
    """Yield every ``(suffix, (runner_type, convert_type))`` entry of the suffix table, in priority order."""
    for suffix_entry in _SUFFIX_TO_DEFAULTS:
        yield suffix_entry
def try_match_architecture_defaults(
    architecture: str,
    *,
    runner_type: Optional[RunnerType] = None,
    convert_type: Optional[ConvertType] = None,
):
    """Look up the first suffix-table entry that fits ``architecture``.

    An entry fits when its default runner equals ``runner_type`` (if given), its
    default convert type equals ``convert_type`` (if given), and ``architecture``
    ends with the entry's suffix.

    Returns:
        ``(suffix, (runner_type, convert_type))`` of the first fitting entry in
        table order, or None if no entry fits.
    """
    for suffix, defaults in iter_architecture_defaults():
        candidate_runner, candidate_convert = defaults
        if runner_type is not None and runner_type != candidate_runner:
            continue
        if convert_type is not None and convert_type != candidate_convert:
            continue
        if not architecture.endswith(suffix):
            continue
        return suffix, (candidate_runner, candidate_convert)
    return None
class MoEPhase:
    """
    Holder for the MoE generation phase.

    The constructor stores ``phase`` unchecked; later assignments through the
    ``phase`` property must be "prefill" or "decode".
    """

    def __init__(self, phase="prefill"):
        # Written straight to the backing field, so the initial value is not validated.
        self._phase = phase

    @property
    def phase(self):
        """Current phase string."""
        return self._phase

    @phase.setter
    def phase(self, value):
        if value in ("prefill", "decode"):
            self._phase = value
            return
        raise ValueError(f"The moe_phase is invalid, only support prefill and decode, but got {value}")
class ErnieArchitectures:
    """Known ERNIE architecture names and membership helpers over them."""

    # ERNIE architecture names; register_ernie_model_arch may add more at runtime.
    ARCHITECTURES = {
        "Ernie4_5ForCausalLM",  # 0.3B-PT
        "Ernie4_5_ForCausalLM",
        "Ernie4_5_MoeForCausalLM",
        "Ernie4_5_VLMoeForConditionalGeneration",
        "Ernie4_5_VLMoeForProcessRewardModel",
    }

    # ERNIE5 architecture names (not extended by register_ernie_model_arch).
    ERNIE5_MODELS = {
        "Ernie5ForCausalLM",
        "Ernie5MoeForCausalLM",
        "Ernie5MoEForRewardModel",
    }

    @classmethod
    def register_ernie_model_arch(cls, model_class):
        """Add ``model_class.name()`` to ARCHITECTURES if the name starts with "Ernie"."""
        arch_name = model_class.name()
        if arch_name.startswith("Ernie"):
            cls.ARCHITECTURES.add(arch_name)

    @classmethod
    def contains_ernie_arch(cls, architectures):
        """Return True if any known ERNIE architecture name is contained in ``architectures``."""
        for known_arch in cls.ARCHITECTURES:
            if known_arch in architectures:
                return True
        return False

    @classmethod
    def is_ernie_arch(cls, architecture):
        """Return whether the single name ``architecture`` is a known ERNIE architecture."""
        known_archs = cls.ARCHITECTURES
        return architecture in known_archs

    @classmethod
    def is_ernie5_arch(cls, architectures):
        """Return True if any known ERNIE5 architecture name is contained in ``architectures``."""
        for known_arch in cls.ERNIE5_MODELS:
            if known_arch in architectures:
                return True
        return False
# Fallback attribute values for ModelConfig. After user args and the model's pretrained
# config have been applied, each key here is set only if the instance has no attribute
# of that name yet (None means "no fallback value, but the attribute exists").
PRETRAINED_INIT_CONFIGURATION = {
    "top_p": 1.0,
    "temperature": 1.0,
    "rope_theta": 10000.0,
    "penalty_score": 1.0,
    "frequency_score": 0.0,
    "presence_score": 0.0,
    "min_length": 1,
    "num_key_value_heads": -1,
    "start_layer_index": 0,
    "moe_num_shared_experts": 0,
    "moe_layer_start_index": 0,
    "num_max_dispatch_tokens_per_rank": 128,
    "moe_use_aux_free": False,
    "vocab_size": -1,
    "hidden_dropout_prob": 0.0,
    "initializer_range": 0.02,
    "max_position_embeddings": 512,
    "quantization_config": None,
    "tie_word_embeddings": False,
    "rms_norm_eps": 1e-5,
    "moe_num_experts": None,
    "moe_layer_end_index": None,
}
class ModelConfig:
    """
    The configuration class to store the configuration of a `LLM`.

    Attribute values are layered in this order:
      1. the defaults assigned at the top of ``__init__``;
      2. keys of ``args`` that match one of those attributes;
      3. keys of the model's nested ``text_config`` (only if not already set);
      4. all top-level keys of the model's pretrained config (overwriting);
      5. ``PRETRAINED_INIT_CONFIGURATION`` for anything still missing.
    ``_post_init`` then resolves runner/convert types, queries the model registry,
    and reads ``config.json`` from the model directory.
    """

    def __init__(
        self,
        args,
    ):
        """
        Args:
            args: Mapping of user options. Only keys naming an attribute defined below
                are applied, and the string "None" is treated as "not provided".

        Raises:
            AssertionError: if no ``model`` path is given.
            ValueError: if ``max_logprobs`` is below -1 or larger than
                ``ori_vocab_size``, or if ``_post_init`` rejects the configuration.
        """
        self.model = ""
        self.is_quantized = False
        self.is_moe_quantized = False
        self.max_model_len = 0
        self.dtype = "bfloat16"
        self.enable_logprob = False
        self.max_logprobs = 20
        self.logprobs_mode = "raw_logprobs"
        self.redundant_experts_num = 0
        self.seed = 0
        self.quantization = None
        self.pad_token_id: int = -1
        self.eos_tokens_lens: int = 2
        self.lm_head_fp32: bool = False
        self.moe_gate_fp32: bool = False
        self.model_format = "auto"
        self.runner = "auto"
        self.convert = "auto"
        # Placeholder; _post_init() always replaces it with _init_pooler_config()'s result.
        # (dataclasses.field() only has meaning on dataclass class attributes.)
        self.pooler_config: Optional["PoolerConfig"] = None
        self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
        self.revision = None
        self.prefix_layer_name = "layers"
        self.kv_cache_quant_scale_path = ""
        self.enable_entropy = False
        self.model_impl: ModelImpl = "auto"
        self.version: str = "init"  # will override by the version.yaml in model dir
        self.partial_rotary_factor: float = 1.0
        self.num_nextn_predict_layers = 0
        self.mm_max_tokens_per_item = None

        for key, value in args.items():
            if hasattr(self, key) and value != "None":
                setattr(self, key, value)

        assert self.model != ""
        pretrained_config, _ = PretrainedConfig.get_config_dict(self.model)
        self.pretrained_config = PretrainedConfig.from_dict(pretrained_config)

        # Some exported configs (e.g. Qwen3-VL) embed the text model's configuration under a `text_config` key.
        if "text_config" in pretrained_config and isinstance(pretrained_config["text_config"], dict):
            text_cfg = pretrained_config.pop("text_config")
            for key, value in text_cfg.items():
                if not hasattr(self, key):
                    setattr(self, key, value)

        # Set attributes from the (top-level) pretrained config; these overwrite earlier values.
        for key, value in pretrained_config.items():
            setattr(self, key, value)

        # Fill in defaults for anything still missing.
        for key, value in PRETRAINED_INIT_CONFIGURATION.items():
            if not hasattr(self, key):
                setattr(self, key, value)

        if not hasattr(self, "head_dim"):
            self.head_dim = self.hidden_size // self.num_attention_heads

        if hasattr(self, "vision_config"):
            self.vision_config = PretrainedConfig.from_dict(self.vision_config)

        # Align external multimodal rope_3d configuration: a top-level `mrope_section`
        # is copied into `rope_scaling` (creating it if absent).
        if hasattr(self, "mrope_section"):
            if (
                hasattr(self, "rope_scaling")
                and isinstance(self.rope_scaling, dict)
                and "mrope_section" not in self.rope_scaling
            ):
                self.rope_scaling["mrope_section"] = self.mrope_section
            elif not hasattr(self, "rope_scaling"):
                setattr(self, "rope_scaling", {"mrope_section": self.mrope_section})
        if (
            hasattr(self, "rope_scaling")
            and isinstance(self.rope_scaling, dict)
            and "mrope_section" in self.rope_scaling
        ):
            setattr(self, "rope_3d", True)
            setattr(self, "freq_allocation", self.rope_scaling["mrope_section"][0])

        self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)
        self.think_start_id = args.get("think_start_id", -1)
        self.think_end_id = args.get("think_end_id", -1)
        self.im_patch_id = args.get("image_patch_id", -1)
        self.line_break_id = args.get("line_break_id", -1)
        self.think_truncate_prompt_ids = args.get("think_truncate_prompt_ids", [-1])

        num_max_logprobs = args.get("max_logprobs", None)
        if num_max_logprobs is not None and num_max_logprobs < -1:
            raise ValueError(" The possible values for max_logprobs can't be less than -1 ")
        if self.ori_vocab_size is not None and num_max_logprobs is not None:
            if num_max_logprobs > self.ori_vocab_size:
                raise ValueError(
                    f" The possible values for max_logprobs can't be greater than the vocabulary size {self.ori_vocab_size}"
                )

        self._post_init()

    def _post_init(self):
        """Resolve runner/convert types, validate them against the registry, and finish setup.

        Raises:
            ValueError: if the resolved runner is not supported by the model and the
                convert type does not make it usable.
        """
        self.is_unified_ckpt = check_unified_ckpt(self.model)
        self.runner_type = self._get_runner_type(self.architectures, self.runner)
        self.convert_type = self._get_convert_type(self.architectures, self.runner_type, self.convert)
        registry = self.registry
        is_generative_model = registry.is_text_generation_model(self.architectures, self)
        is_pooling_model = registry.is_pooling_model(self.architectures, self)
        is_multimodal_model = registry.is_multimodal_model(self.architectures, self)
        self.is_reasoning_model = registry.is_reasoning_model(self.architectures, self)

        self.enable_mm = is_multimodal_model

        self.kv_cache_quant_scale_path = os.path.join(self.model, "kv_cache_scale.json")

        if self.runner_type == "pooling":
            # Process-wide side effect: pooling runs switch the output path via this env var.
            os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"

        if self.runner_type == "generate" and not is_generative_model:
            if is_multimodal_model:
                pass
            elif self.model_impl in ("auto", "paddleformers"):
                # Skip check for auto/paddleformers - may fallback to paddleformers which supports any model
                pass
            else:
                generate_converts = _RUNNER_CONVERTS["generate"]
                if self.convert_type not in generate_converts:
                    raise ValueError("This model does not support `--runner generate`.")

        if self.runner_type == "pooling" and not is_pooling_model:
            pooling_converts = _RUNNER_CONVERTS["pooling"]
            if self.convert_type not in pooling_converts:
                convert_option = "<" + "|".join(pooling_converts) + ">"
                raise ValueError(
                    "This model does not support `--runner pooling`. "
                    f"You can pass `--convert {convert_option}` to adapt "
                    "it into a pooling model."
                )

        self.supported_tasks = self._get_supported_tasks(self.architectures, self.runner_type, self.convert_type)
        model_info, arch = registry.inspect_model_cls(self.architectures, self)
        self._model_info = model_info
        self._architecture = arch
        # From here on, architectures holds only the architecture chosen by the registry.
        self.architectures = [arch]

        self.pooler_config = self._init_pooler_config()
        self.override_name_from_config()
        self.read_from_env()
        self.read_model_config()

    @property
    def registry(self):
        """Return a ModelRegistry instance (imported lazily on each access)."""
        from fastdeploy.model_executor.models.model_base import ModelRegistry

        return ModelRegistry()

    def override_name_from_config(self):
        """
        Override attribute names from the exported model's configuration.
        """
        if not self.is_unified_ckpt and hasattr(self, "infer_model_mp_num"):
            self.tensor_parallel_size = self.infer_model_mp_num
            del self.infer_model_mp_num

        if hasattr(self, "num_hidden_layers") and self.runner != "pooling":
            if hasattr(self, "remove_tail_layer"):
                # `True` drops one layer; an int drops that many (checked after `is True`
                # because bool is a subclass of int).
                if self.remove_tail_layer is True:
                    self.num_hidden_layers -= 1
                elif isinstance(self.remove_tail_layer, int):
                    self.num_hidden_layers -= self.remove_tail_layer

        if not hasattr(self, "mla_use_absorb"):
            self.mla_use_absorb = False

        if hasattr(self, "num_experts") and getattr(self, "moe_num_experts") is None:
            self.moe_num_experts = self.num_experts
        if hasattr(self, "n_routed_experts") and getattr(self, "moe_num_experts") is None:
            self.moe_num_experts = self.n_routed_experts
        if hasattr(self, "n_shared_experts") and getattr(self, "moe_num_shared_experts") is None:
            # Because the ERNIE 4.5 config.json contains two sets of keys, adaptation is required.
            self.moe_num_shared_experts = self.n_shared_experts

    def read_from_env(self):
        """
        Read configuration information from environment variables and update the object's attributes.
        If an attribute is not present or is an empty string in the environment variables, use the default value.
        """
        self.max_stop_seqs_num = envs.FD_MAX_STOP_SEQS_NUM
        self.stop_seqs_max_len = envs.FD_STOP_SEQS_MAX_LEN

        def reset_config_value(key, value):
            # Only fills attributes not already set; the attribute name is key.lower().
            if not hasattr(self, key.lower()):
                if os.getenv(key, None):
                    # NOTE(review): eval() executes the raw environment value as Python code;
                    # ast.literal_eval would be safer if only numeric literals are expected.
                    value = eval(os.getenv(key))
                    logger.info(f"Get parameter `{key}` = {value} from environment.")
                else:
                    logger.info(f"Parameter `{key}` will use default value {value}.")
                setattr(self, key.lower(), value)

        reset_config_value("COMPRESSION_RATIO", 1.0)
        reset_config_value("ROPE_THETA", 10000)

    def read_model_config(self):
        """Load ``<model>/config.json`` (if present) into ``model_config`` and detect ``model_format``.

        A nested ``text_config`` dict is flattened in without overriding top-level keys.

        Raises:
            ValueError: if both ``torch_dtype`` and ``dtype`` are present, or if the
                format cannot be determined.
        """
        config_path = os.path.join(self.model, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                raw_cfg = json.load(f)
            if "text_config" in raw_cfg and isinstance(raw_cfg["text_config"], dict):
                text_cfg = raw_cfg.pop("text_config")
                for k, v in text_cfg.items():
                    if k not in raw_cfg:
                        raw_cfg[k] = v
            self.model_config = raw_cfg
            if "torch_dtype" in self.model_config and "dtype" in self.model_config:
                raise ValueError(
                    "Only one of 'torch_dtype' or 'dtype' should be present in config.json. "
                    "Found both, which indicates an ambiguous model format. "
                    "Please ensure your config.json contains only one dtype field."
                )
            elif "torch_dtype" in self.model_config:
                self.model_format = "torch"
                logger.info("The model format is Hugging Face Torch")
            elif "dtype" in self.model_config:
                # https://github.com/huggingface/transformers/releases/tag/v4.56.0 Transformers 4.56.0 version deprecated torch_dtype
                # Configs written by 4.56.0 itself already use `dtype`, so the bound is inclusive.
                if "transformers_version" in self.model_config and parse_version(
                    self.model_config["transformers_version"]
                ) >= parse_version("4.56.0"):
                    self.model_format = "torch"
                    logger.info("The model format is Hugging Face Torch")
                else:
                    self.model_format = "paddle"
                    logger.info("The model format is Paddle")
            elif (
                "quantization_config" in self.model_config
                and "quant_method" in self.model_config["quantization_config"]
                and "mxfp4" == self.model_config["quantization_config"]["quant_method"]
            ):
                self.model_format = "torch"
                logger.info("The model format is Hugging Face")
            else:
                raise ValueError(
                    "Unknown model format. Please ensure your config.json contains "
                    "either 'torch_dtype' (for Hugging Face models) or 'dtype' (for Paddle models) field. "
                    f"Config file path: {config_path}"
                )

    def read_model_version(self):
        """
        Read the version information from a YAML file located at 'version.yaml' within the model directory.
        If the file exists, it extracts the 'version' field using yaml.safe_load.
        Raises an assertion error if the file is not found at the specified path.
        """
        version_path = os.path.join(self.model, "version.yaml")
        assert os.path.exists(version_path), f"version.yaml not exist at {version_path}"
        with open(version_path, "r", encoding="utf-8") as f:
            self.version = yaml.safe_load(f)["version"]

    def _get_default_runner_type(
        self,
        architectures: list[str],
    ) -> RunnerType:
        """Infer the runner type: pooling config on disk, then registry, then name suffix; default "generate"."""
        registry = self.registry
        if get_pooling_config(self.model, self.revision):
            return "pooling"
        for arch in architectures:
            if arch in registry.get_supported_archs():
                if registry.is_pooling_model(architectures, self):
                    return "pooling"
                if registry.is_text_generation_model(architectures, self):
                    return "generate"
            match = try_match_architecture_defaults(arch)
            if match:
                _, (runner_type, _) = match
                return runner_type
        return "generate"

    def _get_default_convert_type(
        self,
        architectures: list[str],
        runner_type: RunnerType,
    ) -> ConvertType:
        """Infer the convert type for ``runner_type`` from the registry, then from the name suffix."""
        registry = self.registry
        for arch in architectures:
            if arch in registry.get_supported_archs():
                if runner_type == "generate" and registry.is_text_generation_model(architectures, self):
                    return "none"
                if runner_type == "pooling" and registry.is_pooling_model(architectures, self):
                    return "none"
            match = try_match_architecture_defaults(arch, runner_type=runner_type)
            if match:
                _, (_, convert_type) = match
                return convert_type
        # This is to handle Sentence Transformers models that use *ForCausalLM
        # and also multi-modal pooling models which are not defined as
        # Sentence Transformers models
        if runner_type == "pooling":
            return "embed"
        return "none"

    def _get_runner_type(
        self,
        architectures: list[str],
        runner: RunnerOption,
    ) -> RunnerType:
        """Return ``runner`` unchanged unless it is "auto", in which case infer it."""
        if runner != "auto":
            return runner
        runner_type = self._get_default_runner_type(architectures)
        if runner_type != "generate":
            logger.info(
                "Resolved `--runner auto` to `--runner %s`. " "Pass the value explicitly to silence this message.",
                runner_type,
            )
        return runner_type

    def _get_convert_type(
        self,
        architectures: list[str],
        runner_type: RunnerType,
        convert: ConvertOption,
    ) -> ConvertType:
        """Return ``convert`` unchanged unless it is "auto", in which case infer it."""
        if convert != "auto":
            return convert
        convert_type = self._get_default_convert_type(architectures, runner_type)
        if convert_type != "none":
            logger.info(
                "Resolved `--convert auto` to `--convert %s`. " "Pass the value explicitly to silence this message.",
                convert_type,
            )
        return convert_type

    def _get_supported_generation_tasks(
        self,
        architectures: list[str],
        convert_type: ConvertType,
    ) -> list[_ResolvedTask]:
        """Return ["generate"] if the model can generate text (natively or via conversion), else []."""
        registry = self.registry
        supported_tasks: list[_ResolvedTask] = []
        if registry.is_text_generation_model(architectures, self) or convert_type in _RUNNER_CONVERTS["generate"]:
            supported_tasks.append("generate")
        # TODO:Temporarily does not support transcription.
        return supported_tasks

    def _get_default_pooling_task(
        self,
        architectures: list[str],
    ) -> Literal["embed"]:
        """Return the pooling task implied by the first matching architecture suffix, or "embed"."""
        # Temporarily does not support classification and reward.
        for arch in architectures:
            match = try_match_architecture_defaults(arch, runner_type="pooling")
            if match:
                _, (_, convert_type) = match
                assert convert_type != "none"
                return convert_type
        return "embed"

    def _get_supported_pooling_tasks(
        self,
        architectures: list[str],
        convert_type: ConvertType,
    ) -> list[_ResolvedTask]:
        """Return ["encode", <task>] for pooling-capable models, else []."""
        registry = self.registry
        supported_tasks: list[_ResolvedTask] = []
        if registry.is_pooling_model(architectures, self) or convert_type in _RUNNER_CONVERTS["pooling"]:
            supported_tasks.append("encode")
            extra_task = self._get_default_pooling_task(architectures) if convert_type == "none" else convert_type
            supported_tasks.append(extra_task)
        return supported_tasks

    def _get_supported_tasks(
        self,
        architectures: list[str],
        runner_type: RunnerType,
        convert_type: ConvertType,
    ) -> list[_ResolvedTask]:
        """Dispatch to the generation or pooling task resolver for ``runner_type``."""
        if runner_type == "generate":
            return self._get_supported_generation_tasks(architectures, convert_type)
        if runner_type == "pooling":
            return self._get_supported_pooling_tasks(architectures, convert_type)
        assert_never(runner_type)

    def _init_pooler_config(self) -> Optional["PoolerConfig"]:
        """Build the PoolerConfig for pooling runners; return None for other runners.

        User overrides win; unset fields are filled from the on-disk pooling config,
        and finally the pooling type falls back to the model's default.
        """
        if self.runner_type == "pooling":
            if isinstance(self.override_pooler_config, dict):
                self.override_pooler_config = PoolerConfig(**self.override_pooler_config)
            pooler_config = self.override_pooler_config or PoolerConfig()
            base_config = get_pooling_config(self.model, self.revision)
            if base_config is not None:
                for k, v in base_config.items():
                    if getattr(pooler_config, k) is None:
                        setattr(pooler_config, k, v)
            default_pooling_type = self._model_info.default_pooling_type
            if pooler_config.pooling_type is None:
                pooler_config.pooling_type = default_pooling_type
            return pooler_config
        return None

    def _get_download_model(self, model_name, model_type="default"):
        # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
        pass

    def print(self):
        """
        Print all configuration information.
        """
        logger.info("Model Configuration Information :")
        for k, v in self.__dict__.items():
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info("=============================================================")
class ParallelConfig:
    """Configuration for the distributed execution.

    Defaults are assigned first, then overridden by matching keys of ``args``;
    derived fields (EP size, PD-disaggregation mode, sequence-parallel MoE) are
    computed afterwards from those values and from environment variables.
    """

    def __init__(
        self,
        args,
    ):
        """
        Args:
            args: Mapping of user options; only keys naming an attribute defined
                below are applied.

        Raises:
            RuntimeError: if shutdown_comm_group_if_worker_idle ends up True while
                envs.FD_ENABLE_V1_UPDATE_WEIGHTS is enabled.
        """
        self.sequence_parallel = False  # Whether to enable sequence parallelism.
        self.use_ep = False  # Whether to enable Expert Parallelism
        self.msg_queue_id = 1  # message queue id
        self.tensor_parallel_rank = 0  # TP rank ID
        self.tensor_parallel_size = 1  # TP degree
        self.expert_parallel_rank = 0  # EP rank ID
        self.expert_parallel_size = 1  # EP degree
        self.data_parallel_rank = 0  # DP rank ID
        self.data_parallel_size = 1  # DP degree
        self.enable_expert_parallel = False
        self.enable_chunked_moe = False
        self.chunked_moe_size = 256
        self.local_data_parallel_id = 0
        # Engine worker queue port (normalized by parse_ports after args are applied)
        self.engine_worker_queue_port: Optional[Union[int, str, list]] = None
        self.local_engine_worker_queue_port: Optional[int] = None
        # cuda visible devices
        self.device_ids: str = "0"
        # First token id
        self.first_token_id: int = 1
        # Process ID of engine
        self.engine_pid: Optional[int] = None
        # Do profile or not
        self.do_profile: bool = False
        # Use internode_ll_two_stage or not
        self.use_internode_ll_two_stage: bool = False
        # disable sequence parallel moe
        self.disable_sequence_parallel_moe: bool = False
        # shutdown comm group if worker idle; None means "derive from use_ep" below
        self.shutdown_comm_group_if_worker_idle: Optional[bool] = None
        # ep_prefill_use_worst_num_tokens
        self.ep_prefill_use_worst_num_tokens: bool = False
        self.pod_ip: Optional[str] = None
        # NOTE(review): presumably, when True, the custom all-reduce kernel is skipped and
        # NCCL (dist.all_reduce) is used instead — not read in this class; confirm with consumers.
        self.disable_custom_all_reduce: bool = False
        # Whether to use the FlashInfer all-reduce fusion kernel (not read in this class).
        self.enable_flashinfer_allreduce_fusion: bool = False

        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

        self.engine_worker_queue_port = parse_ports(self.engine_worker_queue_port)

        # With expert parallelism enabled, the EP degree spans all DP x TP ranks; otherwise it is 1.
        if self.enable_expert_parallel:
            self.expert_parallel_size = self.data_parallel_size * self.tensor_parallel_size
        else:
            self.expert_parallel_size = 1
        self.use_ep = self.expert_parallel_size > 1

        # Unless set explicitly, shut the comm group down on idle only when EP is not in use.
        if self.shutdown_comm_group_if_worker_idle is None:
            self.shutdown_comm_group_if_worker_idle = not self.use_ep
        if self.shutdown_comm_group_if_worker_idle and envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
            raise RuntimeError("shutdown_comm_group_if_worker_idle cannot be True when FD_ENABLE_V1_UPDATE_WEIGHTS=1")

        # pd_disaggregation: the per-chunk flag takes precedence over the per-query flag.
        use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
        use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
        if use_pd_disaggregation_per_chunk:
            self.pd_disaggregation_mode = "per_chunk"
        elif use_pd_disaggregation:
            self.pd_disaggregation_mode = "per_query"
        else:
            self.pd_disaggregation_mode = "None"

        # Prefill node one step stop (PD disaggregation specific)
        # When enabled, prefill node stops after one decoding step
        self.prefill_one_step_stop: bool = os.getenv("PREFILL_NODE_ONE_STEP_STOP", "0") == "1"

        # disable_sequence_parallel_moe: qkv_linear + attn + out_linear + allreduce
        # use_sequence_parallel_moe: allgather + qkv_linear + attn + all2all + out_linear
        # Only possible when both EP and TP degrees exceed 1.
        self.use_sequence_parallel_moe = (
            (not self.disable_sequence_parallel_moe)
            and self.expert_parallel_size > 1
            and self.tensor_parallel_size > 1
        )
        logger.info(f"use_sequence_parallel_moe: {self.use_sequence_parallel_moe}")

    def set_communicate_group(self):
        """Create the TP process group for this DP rank and, if EP is enabled, the EP group.

        Group ids are pinned via dist.collective._set_custom_gid (offset by
        envs.FD_TP_GROUP_GID_OFFSET) and reset to None after each group is created.
        """
        # different tp group id
        # prevent different tp_groups using the same group_id
        tp_gid_offset = envs.FD_TP_GROUP_GID_OFFSET
        dist.collective._set_custom_gid(self.data_parallel_rank + tp_gid_offset)
        # TP group = the contiguous block of tensor_parallel_size ranks owned by this DP rank.
        self.tp_group = dist.new_group(
            range(
                self.data_parallel_rank * self.tensor_parallel_size,
                (self.data_parallel_rank + 1) * self.tensor_parallel_size,
            )
        )
        dist.collective._set_custom_gid(None)
        # same ep group id
        if self.enable_expert_parallel:
            dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
            self.ep_group = dist.new_group(range(self.expert_parallel_size))
            dist.collective._set_custom_gid(None)
        logger.info(
            f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
        )

    def print(self):
        """
        print all config
        """
        logger.info("Parallel Configuration Information :")
        for k, v in self.__dict__.items():
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info("=============================================================")
class SpeculativeConfig:
"""
Configuration for speculative decoding.
"""
# Class-level default values for all config options
_DEFAULTS = {
"method": None,
"mtp_strategy": "default",
"num_speculative_tokens": 1,
"num_model_steps": 1,
"max_candidate_len": 5,
"verify_window": 2,
"max_ngram_size": 5,
"min_ngram_size": 2,
# Suffix Decoding
"suffix_decoding_max_tree_depth": 64,
"suffix_decoding_max_cached_requests": -1,
"suffix_decoding_max_spec_factor": 1.0,
"suffix_decoding_min_token_prob": 0.1,
"model": None,
"quantization": None,
"num_gpu_block_expand_ratio": 1.0,
"model_type": "main",
"sharing_model": None,
"benchmark_mode": False,
"enf_gen_phase_tag": False,
"enable_draft_logprob": False,
"verify_strategy": "topp",
"accept_policy": "normal",
}
# Environment variable to config mapping for backward compatibility
# Format: env_var: (config_key, value_when_set)
_ENV_OVERRIDES = {
"SPECULATE_VERIFY_USE_TOPK": ("verify_strategy", "greedy"),
"SPECULATE_VERIFY_USE_TARGET_SAMPLING": ("verify_strategy", "target_match"),
}
def __init__(
self,
args,
):
# Valid value lists (not defaults, but valid options)
self.method_list = ["ngram", "mtp", "naive", "suffix"]
self.mtp_strategy_list = ["default", "with_ngram"]
# Initialize from defaults
self._init_from_defaults()
# Apply user-provided arguments (highest priority)
self._apply_user_args(args)
# Read model config (overrides defaults but not user args)
self.read_model_config()
self._apply_model_config()
# Apply environment variable overrides (backward compatibility)
self._apply_env_overrides(args)
# Initialize computed fields
self.num_extra_cache_layer = 0
# Convert and validate all parameters
self._convert_and_validate()
def _init_from_defaults(self):
"""Initialize all config options from class defaults."""
for key, value in self._DEFAULTS.items():
setattr(self, key, value)
def _apply_user_args(self, args: Dict[str, Any]):
"""Apply user-provided arguments."""
if args is None:
return
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
def _apply_model_config(self):
"""Apply configuration from model config file."""
if not self.enabled_speculative_decoding():
return
if self.model is None:
return
# Model config can override certain defaults
# Currently no automatic overrides, but can be extended here
pass
def _apply_env_overrides(self, user_args: Dict[str, Any]):
"""
Apply environment variable overrides for backward compatibility.
Only applies if user hasn't explicitly set the corresponding config.
"""
for env_var, (config_key, env_value) in self._ENV_OVERRIDES.items():
if os.environ.get(env_var, "0") == "1":
# Only apply if user didn't explicitly set this config
if user_args is None or config_key not in user_args:
setattr(self, config_key, env_value)
def _convert_and_validate(self):
"""
Convert string configs to enums and validate all parameters.
"""
# Convert method from string to SpecMethod enum
if self.method is not None:
from fastdeploy.spec_decode import SpecMethod
self.method = SpecMethod.from_string(self.method)
# Set method-specific computed values
if self.method == SpecMethod.MTP:
self.num_extra_cache_layer = 1
# Run validation (includes dependency validation)
self.check_legality_parameters()
def read_model_config(self):
"""
Read configuration from file.
"""
self.model_config = {}
if not self.enabled_speculative_decoding():
return
self.is_unified_ckpt = check_unified_ckpt(self.model)
if self.model is None:
return
self.config_path = os.path.join(self.model, "config.json")
if os.path.exists(self.config_path):
self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
def enabled_speculative_decoding(self):
"""
Check if speculative decoding is enabled.
"""
if self.method is None:
return False
return True
def to_json_string(self):
"""
Convert speculative_config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
def print(self):
"""
print all config
"""
logger.info("Speculative Decoding Configuration Information :")
for k, v in self.__dict__.items():
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=============================================================")
def check_legality_parameters(
self,
) -> None:
"""Check the legality of parameters passed in from the command line"""
if self.method is not None:
from fastdeploy.spec_decode import SpecMethod
assert self.method in [
m.value for m in SpecMethod
], f"speculative method only support {[m.value for m in SpecMethod]} now, but get {self.method}."
if self.method != SpecMethod.NAIVE:
assert (
self.num_speculative_tokens >= 1 and self.num_speculative_tokens <= 5
), f"num_speculative_tokens only support in range[1, 5], but get {self.num_speculative_tokens}."
assert (
self.num_model_steps >= 1 and self.num_model_steps <= 5
), f"num_model_steps only support in range[1, 5], but get {self.num_model_steps}."
if self.method == SpecMethod.MTP:
if self.num_speculative_tokens < self.num_model_steps:
logger.warning(
f"Get num_model_steps > num_speculative_tokens. Reset num_speculative_tokens to {self.num_model_steps}"
)
self.num_speculative_tokens = self.num_model_steps
assert (
self.mtp_strategy in self.mtp_strategy_list
), f"mtp_strategy_list only support {self.mtp_strategy_list}, but get {self.mtp_strategy}"
# Validate verify strategy and accept policy
# Support case-insensitive input for better user experience
from fastdeploy.spec_decode import VerifyStrategy
if not isinstance(self.verify_strategy, VerifyStrategy):
# Handle both string and int inputs
if isinstance(self.verify_strategy, int):
# If it's already an int (enum value), convert directly
self.verify_strategy = VerifyStrategy(self.verify_strategy)
else:
# Assume it's a string
self.verify_strategy = VerifyStrategy.from_string(self.verify_strategy)
# Support case-insensitive accept_policy
valid_accept_policies = ["normal", "accept_all", "reject_all"]
accept_policy_lower = self.accept_policy.lower()
assert (
accept_policy_lower in valid_accept_policies
), f"accept_policy must be one of {valid_accept_policies} (case-insensitive), but got '{self.accept_policy}'."
self.accept_policy = accept_policy_lower
# Validate parameter dependencies after basic validation
self._validate_dependencies()
def _validate_dependencies(self) -> None:
    """
    Validate parameter dependencies across different speculative methods.

    Called by check_legality_parameters after basic validation. Does nothing
    when speculative decoding is not enabled.

    Each rule for a method is a dict with three entries:
      - "check":    zero-argument callable; returns True when the rule holds.
      - "message":  zero-argument callable that builds the diagnostic text. It
                    is only invoked when the rule fails, so attributes used by
                    other methods' rules are never read here.
      - "auto_fix": zero-argument callable that repairs the config in place,
                    or None when the user must adjust the parameters.

    Raises:
        ValueError: if a violated rule has no auto-fix.
    """
    if not self.enabled_speculative_decoding():
        return

    from fastdeploy.spec_decode import SpecMethod

    constraints = {
        SpecMethod.MTP: [
            {
                "check": lambda: self.num_speculative_tokens >= self.num_model_steps,
                "message": lambda: (
                    f"MTP requires num_speculative_tokens >= num_model_steps, "
                    f"but got {self.num_speculative_tokens} < {self.num_model_steps}"
                ),
                "auto_fix": lambda: setattr(self, "num_speculative_tokens", self.num_model_steps),
            }
        ],
        SpecMethod.NGRAM: [
            {
                "check": lambda: self.max_ngram_size >= self.min_ngram_size,
                "message": lambda: (
                    f"NGRAM requires max_ngram_size >= min_ngram_size, "
                    f"but got {self.max_ngram_size} < {self.min_ngram_size}"
                ),
                "auto_fix": None,  # Cannot auto-fix, user must adjust
            }
        ],
        SpecMethod.NAIVE: [
            {
                "check": lambda: self.num_speculative_tokens == 0,
                "message": lambda: (
                    f"NAIVE mode requires num_speculative_tokens == 0, "
                    f"but got {self.num_speculative_tokens}. Resetting to 0."
                ),
                "auto_fix": lambda: setattr(self, "num_speculative_tokens", 0),
            }
        ],
    }

    for constraint in constraints.get(self.method, ()):
        if constraint["check"]():
            continue
        # Build the message before any auto-fix mutates the values it reports.
        message = constraint["message"]()
        if constraint["auto_fix"] is None:
            raise ValueError(message)
        logger.warning(message + " Applying auto-fix.")
        constraint["auto_fix"]()
def __str__(self) -> str:
    """Render the config as text by delegating to ``to_json_string``."""
    rendered = self.to_json_string()
    return rendered
class DeviceConfig:
    """
    Configuration for device settings.

    Attributes:
        device_type: Target device type; defaults to "cuda".
    """

    def __init__(
        self,
        args,
    ):
        """
        Args:
            args: Mapping of overrides. Only keys naming an existing attribute
                are applied; unknown keys are ignored. ``None`` means no
                overrides, matching the other config classes in this module.
        """
        self.device_type = "cuda"
        if args is not None:
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)
class GraphOptimizationConfig:
    """
    Configuration for compute graph level optimization.
    """

    def __init__(
        self,
        args,
    ):
        # Top-level graph optimization control; selects the backend:
        #   - 0: dynamic graph
        #   - 1: static graph
        #   - 2: static graph + cinn compilation backend
        self.graph_opt_level: int = 0

        # ---- CUDA Graph config ----
        # Sizes used for SOT warmup runs.
        self.sot_warmup_sizes: list[int] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64, 128]
        # Whether to use cudagraph (defaults to False on XPU builds, True otherwise).
        #   - False: cudagraph is not used.
        #   - True: cudagraph is used. It requires that all input buffers have fixed
        #     addresses, and all splitting ops write their outputs to input buffers.
        #     With the static graph backend this is WIP.
        self.use_cudagraph: bool = False if paddle.is_compiled_with_xpu() else True
        # Sizes to capture cudagraph (decode):
        #   - None (default): capture sizes are inferred from llm config.
        #   - list[int]: capture sizes are specified as given.
        self.cudagraph_capture_sizes: Optional[list[int]] = None
        # Set by init_with_cudagrpah_size(); capture sizes are scaled for
        # speculative decoding only while this is still False.
        self.flag_cudagraph_capture_sizes_initlized = False
        # Sizes to capture cudagraph for prefill.
        self.cudagraph_capture_sizes_prefill: list[int] = [1, 2, 4, 8]
        # Number of warmup runs for cudagraph.
        self.cudagraph_num_of_warmups: int = 2
        # Whether to copy input tensors for cudagraph. If the caller can guarantee
        # that the same input buffers are always used, it can set this to False.
        # Otherwise, it should set this to True.
        self.cudagraph_copy_inputs: bool = False
        # In static graph, the list of operations that should not be captured by the
        # CUDA graph; CudaGraphBackend splits them out of the static graph.
        # Example: cudagraph_splitting_ops = ["paddle.unified_attention"]
        # Note: to capture subgraphs in a dynamic graph, manually split the model
        # into multiple layers and apply @support_graph_optimization only to the
        # layers where CUDA graph is required.
        self.cudagraph_splitting_ops: list[str] = []
        # When False, only decode-only batches are captured; when True, only
        # prefill-only batches are captured. Capturing both is not supported yet.
        self.cudagraph_only_prefill: bool = False
        # Whether to use a full cuda graph for the entire forward pass rather than
        # splitting certain operations such as attention into subgraphs; thus it
        # cannot be used together with splitting ops.
        self.full_cuda_graph: bool = True
        # Maximum CUDA Graph capture size.
        self.max_capture_size: int = None
        # Maps real shape -> captured size, precomputed to reduce runtime overhead.
        self.real_shape_to_captured_size: dict[int, int] = None
        # Maps real batch size -> captured size (speculative decoding).
        self.real_bsz_to_captured_size: dict[int, int] = {}
        # Whether to use shared memory pool for multi capture_size.
        self.use_unique_memory_pool: bool = True
        # Whether to use cudagraph for the draft model.
        self.draft_model_use_cudagraph: bool = False
        # Maximum CUDA Graph capture size for static graph mode. Recommend 512 for
        # small models (e.g., ERNIE45T 0.3B) and 128 for massive models (e.g., 300B).
        self.max_capture_shape_prefill: int = 512
        # CINN Config ...
        if args is not None:
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)

        self.check_legality_parameters()

    def init_with_cudagrpah_size(
        self,
        max_capture_size: int = 0,
        max_capture_shape_prefill: int = 0,
        num_speculative_tokens: int = 0,
    ) -> None:
        """
        Initialize cuda graph capture sizes and pre-compute the mapping from
        real shape / batch size to the padded (captured) graph size.

        Args:
            max_capture_size: Upper bound for decode capture sizes. When
                num_speculative_tokens != 0 it is multiplied by
                ``num_speculative_tokens + 1``.
            max_capture_shape_prefill: Upper bound for prefill capture sizes.
            num_speculative_tokens: Number of speculative tokens; 0 disables the
                speculative scaling.
        """
        tokens_per_seq = num_speculative_tokens + 1
        # Regular capture sizes
        if num_speculative_tokens != 0:
            max_capture_size = max_capture_size * tokens_per_seq
        # Sizes are scaled only on the first initialization so that repeated
        # calls do not scale them again.
        if not self.flag_cudagraph_capture_sizes_initlized and num_speculative_tokens != 0:
            self.cudagraph_capture_sizes = [
                size * tokens_per_seq
                for size in self.cudagraph_capture_sizes
                if (size * tokens_per_seq) <= max_capture_size
            ]
        else:
            self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
        self.cudagraph_capture_sizes_prefill = [
            size for size in self.cudagraph_capture_sizes_prefill if size <= max_capture_shape_prefill
        ]

        dedup_sizes = list(set(self.cudagraph_capture_sizes))
        if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
            logger.info(
                ("cudagraph sizes specified by model runner" " %s is overridden by config %s"),
                self.cudagraph_capture_sizes,
                dedup_sizes,
            )
        self.cudagraph_capture_sizes = dedup_sizes
        # Prefill candidates can contain duplicates too (_set_cudagraph_sizes
        # appends max_capture_shape_prefill even when it is already a candidate);
        # dedup them so the same shape is not captured twice.
        self.cudagraph_capture_sizes_prefill = list(set(self.cudagraph_capture_sizes_prefill))

        # Sort to make sure cudagraph capture sizes are in descending order
        self.cudagraph_capture_sizes.sort(reverse=True)
        self.cudagraph_capture_sizes_prefill.sort(reverse=True)
        self.max_capture_size = self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
        self.max_capture_size_prefill = (
            self.cudagraph_capture_sizes_prefill[0] if self.cudagraph_capture_sizes_prefill else 0
        )

        def build_padding_map(sizes_desc, max_size):
            """
            Map every shape in [0, max_size] to the size it is padded to.

            ``sizes_desc`` must be sorted in descending order. A shape equal to
            a captured size (or 0) maps to itself; a shape strictly between two
            adjacent captured sizes is padded up to the larger one.
            """
            padding = {}
            for end, start in zip(sizes_desc, sizes_desc[1:] + [0]):
                for shape in range(start, end):
                    padding[shape] = start if shape == start else end
            padding[max_size] = max_size
            return padding

        # Pre-compute the mapping from shape to padded graph size
        self.real_shape_to_captured_size = build_padding_map(self.cudagraph_capture_sizes, self.max_capture_size)
        self.real_shape_to_captured_size_prefill = build_padding_map(
            self.cudagraph_capture_sizes_prefill, self.max_capture_size_prefill
        )

        if num_speculative_tokens != 0:
            # Each captured size holds (batch size * tokens_per_seq) tokens.
            real_bsz_to_captured_size = {}
            for capture_size in self.cudagraph_capture_sizes:
                dummy_batch_size = int(capture_size / tokens_per_seq)
                real_bsz_to_captured_size[dummy_batch_size] = capture_size

            def expand_bsz_map(real_bsz_to_captured_size):
                """Map every batch size up to the largest key to the captured size of
                the smallest key that is >= it."""
                sorted_items = sorted(real_bsz_to_captured_size.items())
                result = {}
                prev_bsz = 0
                for curr_bsz, cap in sorted_items:
                    for bsz in range(prev_bsz + 1, curr_bsz + 1):
                        result[bsz] = cap
                    prev_bsz = curr_bsz
                return result

            self.real_bsz_to_captured_size = expand_bsz_map(real_bsz_to_captured_size)
        self.flag_cudagraph_capture_sizes_initlized = True

    def _set_cudagraph_sizes(
        self,
        max_capture_size: int = 0,
        max_capture_shape_prefill: int = 0,
    ):
        """
        Calculate a series of candidate capture sizes, and then extract a portion
        of them as the capture list for the CUDA graph based on user input.

        The user-provided maxima are appended to the candidates as-is (they may
        duplicate an existing candidate); init_with_cudagrpah_size() dedups.
        """
        # Shape [1, 2, 4, 8, 16, ... 120, 128]
        draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
        # Shape [128, 144, ... 240, 256]
        draft_capture_sizes += [16 * i for i in range(9, 17)]
        # Shape [256, 288, ... 992, 1024]
        draft_capture_sizes += [32 * i for i in range(9, 33)]

        draft_capture_sizes_prefill = draft_capture_sizes.copy()
        draft_capture_sizes.append(max_capture_size)
        self.cudagraph_capture_sizes = sorted(draft_capture_sizes)

        draft_capture_sizes_prefill.append(max_capture_shape_prefill)
        self.cudagraph_capture_sizes_prefill = sorted(draft_capture_sizes_prefill)

    def filter_capture_size(self, tp_size: int = 1):
        """When TSP is used, capture size must be divisible by tp size."""
        self.cudagraph_capture_sizes = [
            draft_size for draft_size in self.cudagraph_capture_sizes if (draft_size % tp_size == 0)
        ]
        self.cudagraph_capture_sizes_prefill = [
            draft_size for draft_size in self.cudagraph_capture_sizes_prefill if (draft_size % tp_size == 0)
        ]

    def to_json_string(self):
        """
        Convert graph_opt_config to json string.
        """
        return json.dumps({key: value for key, value in self.__dict__.items()})

    def __str__(self) -> str:
        return self.to_json_string()

    def check_legality_parameters(
        self,
    ) -> None:
        """Check the legality of parameters passed in from the command line"""

        if self.graph_opt_level is not None:
            assert self.graph_opt_level in [
                0,
                1,
                2,
            ], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
        if self.use_cudagraph is not None:
            assert (
                type(self.use_cudagraph) is bool
            ), "In graph optimization config, type of use_cudagraph must is bool."
        if self.cudagraph_capture_sizes is not None:
            assert (
                type(self.cudagraph_capture_sizes) is list
            ), "In graph optimization config, type of cudagraph_capture_sizes must is list."
            assert (
                len(self.cudagraph_capture_sizes) > 0
            ), "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."
class PlasAttentionConfig:
    """
    Configuration for PLAS sparse attention.

    Holds the top-k bounds for encoder and decoder sparse attention, the token
    counts below which attention is not sparse, and block/weight settings.
    """

    def __init__(
        self,
        args,
    ):
        # The sparse top-k of encoder attention lies in
        # [plas_encoder_top_k_left, plas_encoder_top_k_right].
        self.plas_encoder_top_k_left: int = None
        self.plas_encoder_top_k_right: int = None
        # The sparse top-k of decoder attention lies in
        # [plas_decoder_top_k_left, plas_decoder_top_k_right].
        self.plas_decoder_top_k_left: int = None
        self.plas_decoder_top_k_right: int = None
        # When the number of encoder tokens is less than this limit, it is not sparse.
        self.plas_use_encoder_seq_limit: int = None
        # When the number of decoder tokens is less than this limit, it is not sparse.
        self.plas_use_decoder_seq_limit: int = None
        self.plas_block_size: int = 128
        self.mlp_weight_name: str = "plas_attention_mlp_weight.safetensors"
        self.plas_max_seq_length: int = 128 * 1024

        if args is not None:
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)

        # Default the seq limits to (left top-k bound) * block size when unset.
        if self.plas_use_encoder_seq_limit is None and self.plas_encoder_top_k_left is not None:
            self.plas_use_encoder_seq_limit = self.plas_encoder_top_k_left * self.plas_block_size
        if self.plas_use_decoder_seq_limit is None and self.plas_decoder_top_k_left is not None:
            self.plas_use_decoder_seq_limit = self.plas_decoder_top_k_left * self.plas_block_size
        self.check_legality_parameters()

    def check_legality_parameters(
        self,
    ) -> None:
        """
        Validate the top-k bounds and seq limits.

        Each bound is checked only when set; the ordering check between a left
        and right bound runs only when both are set (comparing against None
        would raise TypeError).

        Raises:
            AssertionError: if a bound is non-positive, a right bound is below
                its left bound, or a seq limit is below left bound * block size.
        """
        if self.plas_encoder_top_k_left is not None:
            assert self.plas_encoder_top_k_left > 0, "plas_encoder_top_k_left must large than 0"

        if self.plas_encoder_top_k_right is not None:
            assert self.plas_encoder_top_k_right > 0, "plas_encoder_top_k_right must large than 0"
            if self.plas_encoder_top_k_left is not None:
                assert (
                    self.plas_encoder_top_k_right >= self.plas_encoder_top_k_left
                ), "plas_encoder_top_k_right must large than plas_encoder_top_k_left"

        if self.plas_decoder_top_k_left is not None:
            assert self.plas_decoder_top_k_left > 0, "plas_decoder_top_k_left must large than 0"

        if self.plas_decoder_top_k_right is not None:
            assert self.plas_decoder_top_k_right > 0, "plas_decoder_top_k_right must large than 0"
            if self.plas_decoder_top_k_left is not None:
                assert (
                    self.plas_decoder_top_k_right >= self.plas_decoder_top_k_left
                ), "plas_decoder_top_k_right must large than plas_decoder_top_k_left"

        if self.plas_use_encoder_seq_limit is not None and self.plas_encoder_top_k_left is not None:
            assert (
                self.plas_use_encoder_seq_limit >= self.plas_encoder_top_k_left * self.plas_block_size
            ), "plas_use_encoder_seq_limit must be >= plas_encoder_top_k_left * plas_block_size"
        if self.plas_use_decoder_seq_limit is not None and self.plas_decoder_top_k_left is not None:
            assert (
                self.plas_use_decoder_seq_limit >= self.plas_decoder_top_k_left * self.plas_block_size
            ), "plas_use_decoder_seq_limit must be >= plas_decoder_top_k_left * plas_block_size"

    def to_json_string(self):
        """
        Convert plas_attention_config to json string, omitting unset (None) fields.
        """
        return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})

    def __str__(self) -> str:
        return json.dumps({key: value for key, value in self.__dict__.items()})
class EarlyStopConfig:
    """
    Early Stop Configuration class.

    Attributes:
        enable_early_stop: Whether to use early stop.
        strategy: Strategy for early stop; the strategy list is ['repetition'].
        window_size: Maximum length of the verify window for early stop.
        threshold: Trigger early stop when the ratio of probs exceeds the threshold.
    """

    def __init__(
        self,
        args,
    ):
        # Whether to use early stop.
        self.enable_early_stop: bool = False
        # Strategy for early stop; the strategy list is ['repetition'].
        self.strategy: str = "repetition"
        # The maximum length of the verify window for early stop.
        self.window_size: int = 3000
        # The probs threshold for early stop.
        self.threshold: float = 0.99

        if args is not None:
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)
        self.check_legality_parameters()

    def to_json_string(self):
        """
        Convert early_stop_config to json string.
        """
        return json.dumps({key: value for key, value in self.__dict__.items()})

    def __str__(self) -> str:
        return self.to_json_string()

    def check_legality_parameters(
        self,
    ) -> None:
        """Check the legality of parameters passed in from the command line"""
        if self.enable_early_stop is not None:
            assert isinstance(
                self.enable_early_stop, bool
            ), "In early stop config, type of enable_early_stop must is bool."
        if self.window_size is not None:
            assert isinstance(self.window_size, int), "In early stop config, type of window_size must be int."
            assert self.window_size > 0, "window_size must large than 0"
        if self.threshold is not None:
            assert isinstance(self.threshold, float), "In early stop config, type of threshold must be float."
            assert self.threshold >= 0 and self.threshold <= 1, "threshold must between 0 and 1"

    def update_enable_early_stop(self, argument: bool):
        """
        Reconcile the two ways a user can set enable_early_stop:
        '--enable-early-stop' and '--early-stop-config'.

        - If enable_early_stop is None (no value from '--early-stop-config'),
          it takes the value of ``argument``.
        - Otherwise the configured value is kept unchanged, except that
          disabling it in the config while passing '--enable-early-stop'
          is rejected.

        Raises:
            ValueError: if the config sets enable_early_stop to False while
                ``argument`` is True.
        """
        if self.enable_early_stop is None:
            # User only set '--enable-early-stop'
            self.enable_early_stop = argument
        elif self.enable_early_stop is False and argument is True:
            # User both set '--enable-early-stop' and '--early-stop-config' with conflicting values
            raise ValueError(
                "Invalid parameter: Cannot set --enable-early-stop and --early-stop-config '{\"enable_early_stop\":false}' simultaneously."
            )
class DeployModality(str, Enum):
    """Modality mode for the serving engine deployment.

    Determines which input modalities the serving engine should handle:

    - TEXT: Text-only deployment. The engine only processes text inputs,
      skipping multimodal preprocessing (e.g., vision encoder, audio
      encoder). This reduces GPU memory usage and startup time when
      multimodal capabilities are not needed.

    - MIXED: Multimodal deployment (default). The engine handles mixed-modality
      inputs including text, images, audio, and video. All modality-specific
      encoders and preprocessing pipelines are initialized at startup.

    Usage:
        --deploy-modality text    # text-only, lower resource footprint
        --deploy-modality mixed   # full multimodal support (default)
    """

    TEXT = "text"
    MIXED = "mixed"

    @classmethod
    def from_str(cls, value: str) -> "DeployModality":
        """
        Parse a string into a DeployModality enum, with validation.

        Matching ignores case and surrounding whitespace.

        Args:
            value: Modality name, e.g. "text" or " MIXED ".

        Returns:
            The matching DeployModality member.

        Raises:
            ValueError: if the normalized value is not a valid modality.
        """
        value = value.strip().lower()
        try:
            return cls(value)
        except ValueError:
            valid = ", ".join(f"'{m.value}'" for m in cls)
            # `from None`: the enum's own lookup error adds nothing to this message.
            raise ValueError(f"Invalid deploy_modality '{value}'. Must be one of: {valid}") from None
class LoadChoices(str, Enum):
    """
    Selectable weight-loader implementations.

    Members mix in ``str``, so each compares equal to its plain string value.
    """

    # Default loader.
    DEFAULT = "default"
    # The "v1" variant of the default loader.
    DEFAULT_V1 = "default_v1"
    # Dummy loader (presumably does not read real weights — confirm).
    DUMMY = "dummy"
class LoadConfig:
    """
    Configuration for model weight loading, including dynamic weight loading.

    Attributes:
        load_choices: Loader implementation (a LoadChoices value); defaults to "default".
        is_pre_sharded: Loader flag, False by default (presumably marks
            checkpoints that are already sharded — confirm against the loader).
        dynamic_load_weight: Whether to enable dynamic weight loading.
        load_strategy: Specifies the weight loading method:
            - 'normal': the default value
            - 'ipc': Real-time IPC streaming with automatic resharding
            - 'ipc_snapshot': Load from disk snapshot of IPC weights
            - 'meta': Only model meta messages
            - 'rsync': presumably used together with ``rsync_config`` — confirm
            - None: No dynamic loading
        rsync_config: Optional extra settings dict (None by default).
        model_loader_extra_config: Optional extra loader settings dict (None by default).
    """

    def __init__(
        self,
        args,
    ):
        """
        Args:
            args: Mapping of overrides. Only keys naming an existing attribute
                are applied; unknown keys are ignored. ``None`` means no
                overrides, matching the other config classes in this module.
        """
        self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
        self.is_pre_sharded: bool = False
        self.dynamic_load_weight: bool = False
        self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal", "rsync"]] = "normal"
        self.rsync_config: Optional[Dict[str, Any]] = None
        self.model_loader_extra_config: Optional[Dict[str, Any]] = None
        if args is not None:
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)

    def __str__(self) -> str:
        return json.dumps({key: value for key, value in self.__dict__.items()})
class PoolerConfig:
    """
    Controls how pooling models pool their outputs.

    All options are class-level attributes that default to None.
    """

    # Pooling method used by the pooling model.
    pooling_type: Optional[str] = None

    # ---- Options for embedding models ----

    # Whether to normalize the embedding outputs. The documented default is
    # True; None here presumably defers to that default elsewhere — confirm.
    normalize: Optional[bool] = None

    # Reduced embedding dimensionality for models that support Matryoshka
    # representation. Defaults to None.
    dimensions: Optional[int] = None

    # Whether to chunk inputs longer than the model's maximum position
    # embeddings: chunks are processed separately and aggregated with
    # weighted averaging, so embedding models can handle arbitrarily long
    # text without CUDA errors. Documented default: False.
    enable_chunked_processing: Optional[bool] = None

    # Maximum input length allowed for embedding generation. When set, inputs
    # longer than max_embed_len may be accepted for embedding models; an input
    # exceeding it falls back to the original max_model_len validation logic.
    # None means max_model_len is used.
    max_embed_len: Optional[int] = None
class EPLBConfig:
    """
    Settings consumed by the EPLB manager, mostly for redundant experts.

    Defaults are assigned first; every key in ``args`` that names one of these
    attributes then overrides it. Unknown keys are ignored, and ``args=None``
    means "no overrides".
    """

    def __init__(
        self,
        args,
    ):
        overrides = {} if args is None else args

        self.enable_eplb: bool = False
        # How many redundant experts to use.
        self.redundant_experts_num: int = 0
        # Shared-memory size for expert IPs.
        self.redundant_expert_ip_shm_size: int = 1024
        # Directory holding expert metadata.
        self.redundant_expert_meta_dir: str = "/tmp/redundant_expert_meta"
        # Credentials for the expert API.
        self.redundant_expert_api_user: str = ""
        self.redundant_expert_api_password: str = ""
        # EPLB strategy name.
        self.redundant_expert_eplb_strategy: str = ""
        # Interval for dumping expert workload.
        self.redundant_expert_dump_workload_interval: int = 10
        # Shared-memory size (GB) for async model loading.
        self.redundant_expert_async_load_model_shmem_size_gb: int = 0
        # Whether schedule cordon is enabled.
        self.redundant_expert_enable_schedule_cordon: bool = True
        # Whether the model uses safetensors.
        self.model_use_safetensors: bool = True
        # Whether the model uses offline quantization.
        self.model_use_offline_quant: bool = True
        # MoE quantization type.
        self.moe_quant_type: str = "w4a8"

        for name, val in overrides.items():
            if not hasattr(self, name):
                continue
            setattr(self, name, val)

    def to_json_string(self):
        """
        Serialize the config to JSON, leaving out attributes whose value is None.
        """
        payload = {name: val for name, val in vars(self).items() if val is not None}
        return json.dumps(payload)

    def print(self):
        """
        Log every attribute, one per line, between a header and a separator.
        """
        logger.info("EPLB Configuration Information :")
        for name, val in vars(self).items():
            logger.info(f"{name:<20}:{'':<6}{val}")
        logger.info("=============================================================")
class CacheConfig:
    """
    Configuration for the KV cache.

    Attributes:
        block_size (int): Size of a cache block in number of tokens.
        gpu_memory_utilization (float): Fraction of GPU memory to use for model execution.
        cache_dtype (str): Data type for kv cache storage. Default is 'bfloat16'.
            Replaced by 'uint8' for int4/int8/float4/float8 KV-cache quantization.
        num_gpu_blocks_override (Optional[int]): Number of GPU blocks to use.
            Overrides profiled num_gpu_blocks if provided.
        kv_cache_ratio (float): Fraction of total blocks usable for prefill
            (1.0 with ENABLE_V1_KVCACHE_SCHEDULER, else 0.75 by default).
        enc_dec_block_num (int): Number of encoder-decoder blocks.
        prealloc_dec_block_slot_num_threshold (int): Token-slot threshold for allocating the next blocks during decoding.
        enable_prefix_caching (bool): Flag to enable prefix caching.
        enable_output_caching (bool): Flag to enable kv cache output tokens, only works in V1 scheduler.
        swap_space (Optional[float]): CPU swap space in GiB; used to derive
            num_cpu_blocks when num_cpu_blocks is not given.
    """

    def __init__(self, args):
        """
        Initialize the CacheConfig class.

        Defaults are assigned first; then every key in ``args`` that names an
        existing attribute overrides it (unknown keys are ignored). ``args``
        must be a mapping — there is no ``None`` guard here.

        Args:
            block_size (int): Size of a cache block in number of tokens.
            gpu_memory_utilization (float): Fraction of GPU memory to use.
            cache_dtype (str): Data type for cache storage. Default is 'bfloat16'.
            num_gpu_blocks_override (Optional[int]): Override for number of GPU blocks.
            num_cpu_blocks (Optional[int]): Number of CPU blocks.
            kv_cache_ratio (float): Ratio for max block calculation.
            enc_dec_block_num (int): Number of encoder-decoder blocks.
            prealloc_dec_block_slot_num_threshold (int): Number of token slot threshold to allocate next blocks for decoding, used when ENABLE_V1_KVCACHE_SCHEDULER=1.
            enable_prefix_caching (bool): Enable prefix caching.
            max_encoder_cache(int): Maximum number of tokens in the encoder cache.
            max_processor_cache(int): Maximum number of bytes in the processor cache.

        Raises:
            ValueError: from _verify_args, or from get_cache_bytes for an
                unsupported cache dtype.
        """
        self.block_size = 64
        self.gpu_memory_utilization = 0.9
        self.num_gpu_blocks_override = None
        # The V1 scheduler lets prefill use all blocks; otherwise 75% by default
        # (see postprocess/reset, which multiply total blocks by this ratio).
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.kv_cache_ratio = 1.0
        else:
            self.kv_cache_ratio = 0.75
        self.enc_dec_block_num = envs.FD_ENC_DEC_BLOCK_NUM
        self.prealloc_dec_block_slot_num_threshold = 12
        self.cache_dtype = "bfloat16"
        self.model_cfg = None
        self.enable_chunked_prefill = False
        self.rdma_comm_ports = None
        self.local_rdma_comm_ports = None
        self.cache_transfer_protocol = None
        self.pd_comm_port = None
        self.local_pd_comm_port = None
        self.enable_prefix_caching = False
        self.enable_ssd_cache = False
        self.cache_queue_port = None
        self.local_cache_queue_port = None
        self.swap_space = None
        self.max_encoder_cache = None
        self.max_processor_cache = None
        self.enable_output_caching = False
        self.disable_chunked_mm_input = False
        self.kvcache_storage_backend = None
        self.write_policy = None
        self.num_cpu_blocks = None
        self.use_mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"

        # Apply overrides for known attributes only.
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

        # Normalize port settings (parse_ports is defined elsewhere in this module).
        self.cache_queue_port = parse_ports(self.cache_queue_port)
        self.rdma_comm_ports = parse_ports(self.rdma_comm_ports)
        self.pd_comm_port = parse_ports(self.pd_comm_port)

        if self.model_cfg is not None:
            # The model's quantization settings may override cache_dtype;
            # quantization_config is applied last, so it wins over quantization.
            if self.model_cfg.quantization is not None and isinstance(self.model_cfg.quantization, dict):
                self.cache_dtype = self.model_cfg.quantization.get("kv_cache_quant_type", self.cache_dtype)
            if self.model_cfg.quantization_config is not None:
                self.cache_dtype = self.model_cfg.quantization_config.get("kv_cache_quant_type", self.cache_dtype)
            # Quantized KV-cache types are stored with dtype uint8.
            if any(t in self.cache_dtype.lower() for t in ["int4", "int8", "float4", "float8"]):
                self.cache_dtype = "uint8"
            # Prefer KV heads; fall back to attention heads when absent or falsy.
            self.head_num = getattr(self.model_cfg, "num_key_value_heads", None) or getattr(
                self.model_cfg, "num_attention_heads", None
            )
            # No default: raises AttributeError if model_cfg has no head_dim.
            self.head_dim = getattr(self.model_cfg, "head_dim")
            self.byte_size = self.get_cache_bytes(self.cache_dtype)
            # 1 when use_mla_cache, else 2 (presumably separate K and V caches).
            self.kv_factor = 1 if self.use_mla_cache else 2
            self.bytes_per_token_per_layer = int(self.head_num * self.head_dim * self.byte_size * self.kv_factor)
            # Bytes of one cache block across all hidden layers.
            self.bytes_per_block = int(
                self.bytes_per_token_per_layer * self.block_size * self.model_cfg.num_hidden_layers
            )
            # Derive CPU blocks from swap_space (GiB) unless given explicitly.
            # NOTE(review): nesting under `model_cfg is not None` reconstructed from
            # an unindented listing; bytes_per_block only exists in this branch.
            if self.num_cpu_blocks is None:
                if self.swap_space is None:
                    self.num_cpu_blocks = 0
                else:
                    self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block)
        self._verify_args()

    @staticmethod
    def get_cache_bytes(cache_dtype: str):
        """
        Return the number of bytes per element for a cache dtype name.

        Matching is case-insensitive substring matching, checked in order:
        32-bit -> 4, 16-bit (incl. bfloat16) -> 2, 8-bit -> 1, 4-bit -> 0.5.

        Raises:
            ValueError: if no known dtype substring matches.
        """
        if any(t in cache_dtype.lower() for t in ["float32", "fp32"]):
            return 4
        elif any(t in cache_dtype.lower() for t in ["float16", "bf16", "fp16"]):
            return 2
        elif any(t in cache_dtype.lower() for t in ["uint8", "int8", "float8", "fp8"]):
            return 1
        elif any(t in cache_dtype.lower() for t in ["int4", "float4"]):
            return 0.5
        else:
            raise ValueError(f"Unsupported cache dtype: {cache_dtype}")

    def metrics_info(self):
        """Convert cache_config to dict(key: str, value: str) for prometheus metrics info."""
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self):
        """
        Reject gpu_memory_utilization or kv_cache_ratio greater than 1.0.

        NOTE(review): the messages say "less than 1.0", but exactly 1.0 is accepted.

        Raises:
            ValueError: if either value exceeds 1.0.
        """
        if self.gpu_memory_utilization > 1.0:
            raise ValueError("GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.")
        if self.kv_cache_ratio > 1.0:
            raise ValueError("KV cache ratio must be less than 1.0. Got " f"{self.kv_cache_ratio}.")

    def postprocess(self, num_total_tokens, number_of_tasks):
        """
        calculate block num

        With num_gpu_blocks_override set, the total block count is taken from
        it and checked against max_block_num_per_seq + enc_dec_block_num.
        Otherwise a profiling estimate is computed: blocks for an average
        sequence (num_total_tokens // number_of_tasks tokens, plus
        dec_token_num decode tokens), times number_of_tasks.

        Requires ``self.max_block_num_per_seq`` to have been set outside this
        class (it is not initialized in __init__).
        """
        self.dec_token_num = self.enc_dec_block_num * self.block_size
        if self.num_gpu_blocks_override is not None:
            self.total_block_num = self.num_gpu_blocks_override
            if envs.ENABLE_V1_KVCACHE_SCHEDULER:
                self.prefill_kvcache_block_num = self.total_block_num
            else:
                self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
            assert self.prefill_kvcache_block_num >= self.max_block_num_per_seq + self.enc_dec_block_num, (
                f"prefill_kvcache_block_num: {self.prefill_kvcache_block_num} should be larger "
                f"than or equal to {self.max_block_num_per_seq + self.enc_dec_block_num}, please reduce "
                "the max_model_len or increase num_gpu_blocks_override"
            )
        else:
            length = num_total_tokens // number_of_tasks
            # Ceil-divide (tokens + decode tokens) by block_size.
            block_num = (length + self.block_size - 1 + self.dec_token_num) // self.block_size
            self.total_block_num = block_num * number_of_tasks
            self.prefill_kvcache_block_num = self.total_block_num
            logger.info(f"Doing profile, the total_block_num:{self.total_block_num}")

    def reset(self, num_gpu_blocks):
        """
        reset gpu block number

        Sets total_block_num to ``num_gpu_blocks``, recomputes
        prefill_kvcache_block_num (all blocks under the V1 scheduler, otherwise
        total * kv_cache_ratio), and asserts it can hold one max-length sequence
        plus enc_dec_block_num blocks.
        """
        self.total_block_num = num_gpu_blocks
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.prefill_kvcache_block_num = self.total_block_num
        else:
            self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
        logger.info(
            f"Reset block num, the total_block_num:{self.total_block_num},"
            f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"
        )
        assert self.prefill_kvcache_block_num >= self.max_block_num_per_seq + self.enc_dec_block_num, (
            f"current device block num: {self.prefill_kvcache_block_num} "
            f"should be larger than or equal to {self.max_block_num_per_seq + self.enc_dec_block_num}, please reduce "
            "the max_model_len or replace the machine with larger GPU cards"
        )

    def print(self):
        """
        print all config
        """
        logger.info("Cache Configuration Information :")
        for k, v in self.__dict__.items():
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info("=============================================================")
class RouterConfig:
    """
    Router connection settings.

    Attributes:
        router: Router URL, e.g. http://127.0.0.1:8000. A value without an
            http:// or https:// scheme gets "http://" prepended; None is kept.
        api_server_host: Host IP of this model server (from get_host_ip()).
        api_server_port: HTTP port of this model server (``args["port"]``).
        metrics_port: ``args["metrics_port"]``, or api_server_port when that is None.
    """

    def __init__(self, args: dict):
        router_url = args["router"]
        has_scheme = router_url is None or router_url.startswith(("http://", "https://"))
        self.router = router_url if has_scheme else f"http://{router_url}"

        self.api_server_host = get_host_ip()
        self.api_server_port = args["port"]

        requested_metrics_port = args["metrics_port"]
        self.metrics_port = self.api_server_port if requested_metrics_port is None else requested_metrics_port
class CommitConfig:
    """
    Configuration for tracking version information from version.txt

    Attributes:
        fastdeploy_commit: Full FastDeploy git commit hash
        paddle_version: PaddlePaddle version string
        paddle_commit: PaddlePaddle git commit hash
        cuda_version: CUDA version string
        compiler_version: CXX compiler version string
    """

    # (line prefix in version.txt, attribute filled with the text after it)
    _VERSION_FIELDS = (
        ("fastdeploy GIT COMMIT ID:", "fastdeploy_commit"),
        ("Paddle version:", "paddle_version"),
        ("Paddle GIT COMMIT ID:", "paddle_commit"),
        ("CUDA version:", "cuda_version"),
        ("CXX compiler version:", "compiler_version"),
    )

    def __init__(
        self,
    ):
        self.fastdeploy_commit: str = ""
        self.paddle_version: str = ""
        self.paddle_commit: str = ""
        self.cuda_version: str = ""
        self.compiler_version: str = ""

        self._load_from_version_file()

    def _load_from_version_file(self, file_path: str = None):
        """
        Internal method to load version info from file.

        Each recognized line has the form ``<prefix> <value>``. Everything after
        the prefix is kept (stripped), so values that themselves contain ':'
        are not truncated. A missing or unreadable file is logged and ignored.

        Args:
            file_path: Path to version.txt; defaults to the version.txt inside
                the installed fastdeploy package.
        """
        if file_path is None:
            file_path = os.path.join(fastdeploy.__path__[0], "version.txt")
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    for prefix, attr in self._VERSION_FIELDS:
                        if line.startswith(prefix):
                            setattr(self, attr, line[len(prefix) :].strip())
                            break
        except FileNotFoundError:
            logger.info(f"Warning: Version file not found at {file_path}")
        except Exception as e:
            # Best effort: version info is informational only.
            logger.info(f"Warning: Could not read version file - {e!s}")

    def print(self):
        """
        print all config
        """
        logger.info("FastDeploy Commit Information :")
        for k, v in self.__dict__.items():
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info("=============================================================")
class StructuredOutputsConfig:
    """
    Settings for structured (guided) output generation.

    Each key of ``args`` that names an existing attribute overrides it, unless
    its value is the literal string "None", which is ignored.
    """

    def __init__(
        self,
        args,
    ) -> None:
        self.reasoning_parser: Optional[str] = None
        self.guided_decoding_backend: Optional[str] = None
        # Disallow any whitespace during guided decoding.
        self.disable_any_whitespace: bool = True
        self.logits_processors: Optional[list[str]] = None

        overrides = {name: val for name, val in args.items() if hasattr(self, name) and val != "None"}
        for name, val in overrides.items():
            setattr(self, name, val)

    def __str__(self) -> str:
        return json.dumps(dict(vars(self)))
class RoutingReplayConfig:
    """
    Configuration for Routing Replay used in RL training.

    ``args`` may be None. Otherwise each key naming an existing attribute
    overrides it, unless its value is the literal string "None".
    """

    def __init__(self, args) -> None:
        self.enable_routing_replay: bool = False
        # Routing store backend: "local" or "rdma".
        self.routing_store_type: str = "local"
        # Output directory for the local routing store.
        self.local_store_dir: str = "./routing_replay_output"
        # Server address for the RDMA routing store.
        self.rdma_store_server: str = ""
        # Save only the last turn.
        self.only_last_turn: bool = False
        # Put the routing of all layers in one fused operation.
        self.use_fused_put: bool = False

        overrides = {} if args is None else args
        for name, val in overrides.items():
            if not hasattr(self, name) or val == "None":
                continue
            setattr(self, name, val)

    def to_json_string(self):
        """
        Serialize every attribute of the routing replay config to JSON.
        """
        return json.dumps(dict(vars(self)))
class FDConfig:
"""
The configuration class which contains all fastdeploy-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
def __init__(
    self,
    model_config: ModelConfig = None,
    cache_config: CacheConfig = None,
    parallel_config: ParallelConfig = None,
    load_config: LoadConfig = None,
    commit_config: CommitConfig = CommitConfig(),
    scheduler_config: SchedulerConfig = None,
    device_config: DeviceConfig = None,
    quant_config: QuantConfigBase = None,
    graph_opt_config: GraphOptimizationConfig = None,
    plas_attention_config: PlasAttentionConfig = None,
    speculative_config: SpeculativeConfig = None,
    eplb_config: EPLBConfig = None,
    structured_outputs_config: StructuredOutputsConfig = None,
    router_config: RouterConfig = None,
    tokenizer: str = None,
    ips: str = None,
    use_warmup: bool = False,
    limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    max_num_partial_prefills: int = 1,
    max_long_partial_prefills: int = 1,
    long_prefill_token_threshold: int = 0,
    early_stop_config: Optional[Dict[str, Any]] = None,
    tool_parser: str = None,
    test_mode=False,
    routing_replay_config: Optional[RoutingReplayConfig] = None,
    deploy_modality: DeployModality = DeployModality.MIXED,
):
    """
    Assemble the top-level FastDeploy configuration and derive runtime settings.

    The sub-config objects are stored as given. The constructor then:
      1. sizes the CUDA graph capture list,
      2. resolves node topology (``nnode``, ``node_rank``, ``host_ip``) from ``ips``,
      3. derives ``max_prefill_batch``, ``worker_num_per_node`` and
         ``parallel_config.device_ids``,
      4. runs ``read_from_config``, ``postprocess`` and ``init_pd_info``,
      5. runs ``check`` unless ``test_mode`` is set.

    Although most parameters default to ``None``, ``scheduler_config``,
    ``graph_opt_config``, ``parallel_config`` and ``cache_config`` are
    dereferenced unconditionally here or in ``postprocess``, so they are
    effectively required.

    Args:
        ips: comma-separated string (split here) or list of node IPs;
            ``None`` means a single node.
        early_stop_config: stored as ``self.early_stop_config``; presumably an
            EarlyStopConfig instance despite the Dict annotation. TODO confirm
            with callers.
        test_mode: when True, return before ``check()`` validation.

    NOTE(review): ``commit_config`` is not stored by this constructor, and its
    default ``CommitConfig()`` is evaluated once, when the function is defined.
    """
    self.model_config: ModelConfig = model_config  # type: ignore
    self.cache_config: CacheConfig = cache_config  # type: ignore
    self.scheduler_config: SchedulerConfig = scheduler_config  # type: ignore
    self.parallel_config = parallel_config  # type: ignore
    self.speculative_config: SpeculativeConfig = speculative_config
    self.eplb_config: Optional[EPLBConfig] = eplb_config
    self.device_config: DeviceConfig = device_config  # type: ignore
    self.load_config: LoadConfig = load_config
    self.quant_config: Optional[QuantConfigBase] = quant_config
    self.graph_opt_config: Optional[GraphOptimizationConfig] = graph_opt_config
    self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
    self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
    self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
    self.router_config: RouterConfig = router_config
    self.routing_replay_config = routing_replay_config
    self.deploy_modality: DeployModality = deploy_modality
    # Initialize cuda graph capture list
    # Largest decode capture size: 512 when only prefill is captured, otherwise
    # max_num_seqs capped at 512.
    max_capture_shape = self.scheduler_config.max_num_seqs
    if self.graph_opt_config.cudagraph_only_prefill:
        max_capture_shape = 512
    else:
        max_capture_shape = min(512, max_capture_shape)
    max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill
    # Only generate default capture sizes when the user did not supply any.
    if self.graph_opt_config.cudagraph_capture_sizes is None:
        self.graph_opt_config._set_cudagraph_sizes(
            max_capture_size=max_capture_shape,
            max_capture_shape_prefill=max_capture_shape_prefill,
        )
    # Speculative tokens are passed through only for the MTP and SUFFIX
    # methods; every other case passes 0.
    self.graph_opt_config.init_with_cudagrpah_size(
        max_capture_size=max_capture_shape,
        max_capture_shape_prefill=max_capture_shape_prefill,
        num_speculative_tokens=(
            self.speculative_config.num_speculative_tokens
            if (
                self.speculative_config is not None
                and self.speculative_config.method
                in [
                    SpecMethod.MTP,
                    SpecMethod.SUFFIX,
                ]
            )
            else 0
        ),
    )
    self.tokenizer = tokenizer
    self.ips = ips
    self.tool_parser = tool_parser
    # Normalise ips: None -> single node listening on all interfaces;
    # a comma-separated string -> list of IPs.
    if self.ips is None:
        self.master_ip = "0.0.0.0"
    elif isinstance(self.ips, str):
        self.ips = self.ips.split(",")
    self.host_ip = get_host_ip()
    if self.ips is None:
        self.nnode = 1
        self.node_rank = 0
    else:
        self.nnode = len(self.ips)
        # NOTE(review): if host_ip is not in ips, node_rank is never assigned;
        # postprocess reads it when tensor_parallel_size > worker_num_per_node.
        for idx, ip in enumerate(self.ips):
            if ip == self.host_ip:
                self.node_rank = idx
    self.limit_mm_per_prompt = limit_mm_per_prompt
    self.mm_processor_kwargs = mm_processor_kwargs
    self.use_warmup = use_warmup
    self.max_num_partial_prefills = max_num_partial_prefills
    self.max_long_partial_prefills = max_long_partial_prefills
    self.long_prefill_token_threshold = long_prefill_token_threshold
    if envs.FD_FOR_TORCH_MODEL_FORMAT:
        self.model_config.model_format = "torch"
    # TODO
    # Prefill batch size: from MAX_PREFILL_NUM (default 3) unless
    # FD_ENABLE_MAX_PREFILL is set, in which case max_num_seqs is used.
    if not envs.FD_ENABLE_MAX_PREFILL:
        self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
        if (
            int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 0
            and self.model_config is not None
            and self.model_config.enable_mm
            and self.deploy_modality != DeployModality.TEXT
        ):
            self.max_prefill_batch = 1  # TODO: V0 multimodal prefill currently only supports a parallelism of 1; to be optimized
    else:
        self.max_prefill_batch = self.scheduler_config.max_num_seqs
    num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
    # Chips per node: 16 on Iluvatar, 8 elsewhere. Jobs with more ranks must span
    # exactly self.nnode nodes (except under the "meta" load strategy).
    self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
    if num_ranks > self.max_chips_per_node and self.load_config and self.load_config.load_strategy != "meta":
        self.worker_num_per_node = self.max_chips_per_node
        nnode = ceil_div(num_ranks, self.worker_num_per_node)
        assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
    else:
        self.worker_num_per_node = num_ranks
    # Default device ids are 0..worker_num_per_node-1; the platform's
    # *_VISIBLE_DEVICES env var (when set) takes precedence. XPU/HPU values
    # override the CUDA one.
    self.parallel_config.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
    self.parallel_config.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.parallel_config.device_ids)
    if current_platform.is_xpu():
        self.parallel_config.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.parallel_config.device_ids)
    if current_platform.is_intel_hpu():
        self.parallel_config.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.parallel_config.device_ids)
    if (
        self.load_config
        and self.load_config.dynamic_load_weight
        and self.router_config
        and self.router_config.router
    ):
        # For RL scenario, version.yaml is required for models
        # Temporarily enforce use router to be enabled.
        self.model_config.read_model_version()
    self.read_from_config()
    self.postprocess()
    self.init_pd_info()
    if test_mode:
        return
    self.check()
    # self.print() # NOTE: it's better to explicitly call .print() when FDConfig is initialized
@property
def enable_mm_runtime(self) -> bool:
    """
    Whether the multimodal runtime path is active.

    True only when a model config is present, the model is multimodal
    (``enable_mm``) and the deployment modality is not text-only. A falsy
    ``enable_mm`` value is returned unchanged.
    """
    model_cfg = self.model_config
    if model_cfg is None:
        return False
    mm_flag = model_cfg.enable_mm
    if not mm_flag:
        return mm_flag
    return self.deploy_modality != DeployModality.TEXT
@property
def enable_rope_3d_runtime(self) -> bool:
    """
    Whether 3D RoPE should be used at runtime.

    Requires the multimodal runtime; the model must also set either
    ``rope_3d`` or ``use_3d_rope`` (missing attributes count as False).
    """
    mm_active = self.enable_mm_runtime
    if not mm_active:
        return mm_active
    model_cfg = self.model_config
    return getattr(model_cfg, "rope_3d", False) or getattr(model_cfg, "use_3d_rope", False)
def _disable_sequence_parallel_moe_if_needed(self, mode_name):
    """
    Switch off sequence-parallel MoE when CUDA graph is enabled.

    Args:
        mode_name: label of the deployment mode, interpolated into the warning.
    """
    parallel = self.parallel_config
    both_enabled = parallel.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph
    if not both_enabled:
        return
    parallel.use_sequence_parallel_moe = False
    logger.warning(
        f"Sequence parallel MoE does not support {mode_name} mode with cudagraph. "
        "Setting use_sequence_parallel_moe to False."
    )
def postprocess(self):
    """
    Calculate derived parameters and reconcile the sub-configs with each other.

    In order this:
      * unifies model-specific fields (GLM4 MoE first MoE layer),
      * selects whether this process is the master and its master IP,
      * fills default ``max_num_batched_tokens`` and ``long_prefill_token_threshold``,
      * forces 2D RoPE when a multimodal model is deployed text-only,
      * sizes the KV cache and resolves the guided decoding backend,
      * sizes the encoder cache, adjusts CUDA graph / speculative settings,
      * sets ``num_max_dispatch_tokens_per_rank`` and the MoE phase,
      * and finally selects local devices/ports (``postprocess_devices_and_ports``).

    Raises:
        ImportError: the "guidance" backend is requested but ``llguidance`` is missing.
        NotImplementedError: unknown guided decoding backend or splitwise role.
    """
    # Unified field model config
    if self.model_config.architectures[0] == "Glm4MoeForCausalLM":
        # The first moe layer id of GLM4.5 model
        self.model_config.moe_layer_start_index = self.model_config.first_k_dense_replace

    # Master: every rank when TP fits on one node, otherwise only node 0.
    if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node or self.node_rank == 0:
        self.is_master = True
        self.master_ip = "0.0.0.0"
    else:
        self.is_master = False
        self.master_ip = self.ips[0]

    self.paddle_commit_id = paddle.version.commit

    if self.scheduler_config.max_num_batched_tokens is None:
        if int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
            if int(envs.FD_DISABLE_CHUNKED_PREFILL):
                self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len
            else:
                self.scheduler_config.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
        else:
            if self.cache_config.enable_chunked_prefill:
                self.scheduler_config.max_num_batched_tokens = 2048
            else:
                self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len

    if self.long_prefill_token_threshold == 0:
        # Default threshold: 4% of max_model_len.
        self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)

    if (
        self.model_config is not None
        and self.model_config.enable_mm
        and self.deploy_modality == DeployModality.TEXT
    ):
        if getattr(self.model_config, "rope_3d", False) or getattr(self.model_config, "use_3d_rope", False):
            logger.info(
                "Deploy modality is text; forcing the multimodal-capable model onto the 2D RoPE runtime path."
            )
        setattr(self.model_config, "rope_3d", False)
        setattr(self.model_config, "use_3d_rope", False)

    self.cache_config.max_block_num_per_seq = int(self.model_config.max_model_len // self.cache_config.block_size)
    self.cache_config.postprocess(self.get_max_chunk_tokens(), self.scheduler_config.max_num_seqs)
    if self.model_config is not None and self.enable_mm_runtime and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        # Prefix caching is disabled for the multimodal runtime without the V1 KV-cache scheduler.
        self.cache_config.enable_prefix_caching = False

    if (
        self.structured_outputs_config is not None
        and self.structured_outputs_config.guided_decoding_backend != "off"
    ):
        # speculative_config defaults to None, so guard before reading .method
        # (previously raised AttributeError when no speculative config was given).
        if current_platform.is_xpu() or (
            self.speculative_config is not None and self.speculative_config.method is not None
        ):
            logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.")
            self.structured_outputs_config.guided_decoding_backend = "off"
        elif self.structured_outputs_config.guided_decoding_backend in ["auto", "xgrammar"]:
            self.structured_outputs_config.guided_decoding_backend = "xgrammar"
        elif self.structured_outputs_config.guided_decoding_backend == "guidance":
            # Fail fast if the optional llguidance dependency is missing.
            try:
                import llguidance.torch

                llguidance.torch  # referenced so the import is not flagged as unused
            except ImportError as e:
                raise ImportError(
                    "The 'llguidance' package is required for using guidance as the guided decoding backend. "
                    "Please install it via the appropriate method."
                ) from e
        else:
            raise NotImplementedError(
                f"Guided decoding backend '{self.structured_outputs_config.guided_decoding_backend}' is not implemented. [auto, xgrammar, guidance, off]"
            )

    if self.enable_mm_runtime:
        if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0:
            self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens
        elif self.cache_config.max_encoder_cache != 0:
            if self.cache_config.max_encoder_cache < self.scheduler_config.max_num_batched_tokens:
                logger.warning(
                    f"max_encoder_cache{self.cache_config.max_encoder_cache} is less than "
                    f"max_num_batched_tokens{self.scheduler_config.max_num_batched_tokens}, "
                    f"set to max_num_batched_tokens."
                )
                self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens
        # TODO: mm encoder_cache close for now
        self.cache_config.max_encoder_cache = 0
    else:
        self.cache_config.max_encoder_cache = 0

    # Adjustment GraphOptConfig
    if self.scheduler_config is not None and self.scheduler_config.splitwise_role == "prefill":
        self.graph_opt_config.use_cudagraph = self.graph_opt_config.cudagraph_only_prefill
    if self.load_config is not None and self.load_config.dynamic_load_weight is True:
        self.graph_opt_config.graph_opt_level = 0
        logger.info(
            "Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
        )
    if (
        not current_platform.is_cuda()
        and not current_platform.is_maca()
        and not current_platform.is_xpu()
        and not current_platform.is_iluvatar()
    ):
        self.graph_opt_config.use_cudagraph = False
        logger.info(
            "Current Platform can not support CUDAGraph, CUDAGraph currently only support on GPU/XPU/Metax GPU !"
        )

    # adjust speculative config
    if self.speculative_config is not None and self.speculative_config.method == SpecMethod.MTP:
        if self.scheduler_config.splitwise_role == "prefill":
            self.speculative_config.num_speculative_tokens = 1
            self.speculative_config.num_model_steps = 1

    # Auto-compute num_max_dispatch_tokens_per_rank from max_num_seqs and num_speculative_tokens
    if self.speculative_config is not None and self.speculative_config.method is not None:
        num_spec_tokens = self.speculative_config.num_speculative_tokens
        auto_dispatch_tokens = self.scheduler_config.max_num_seqs * (num_spec_tokens + 1)
    else:
        auto_dispatch_tokens = self.scheduler_config.max_num_seqs
    if (
        getattr(self.model_config, "num_max_dispatch_tokens_per_rank", None)
        and self.model_config.num_max_dispatch_tokens_per_rank != auto_dispatch_tokens
    ):
        logger.info(
            f"Auto-setting num_max_dispatch_tokens_per_rank from "
            f"{self.model_config.num_max_dispatch_tokens_per_rank} to {auto_dispatch_tokens} "
            f"(max_num_seqs={self.scheduler_config.max_num_seqs}"
            f"{f', num_speculative_tokens={num_spec_tokens}' if self.speculative_config is not None and self.speculative_config.method is not None else ''})."
        )
    self.model_config.num_max_dispatch_tokens_per_rank = auto_dispatch_tokens

    if self.scheduler_config.splitwise_role == "mixed":
        self._disable_sequence_parallel_moe_if_needed("Mixed")
        self.model_config.moe_phase = MoEPhase(phase="prefill")
    elif self.scheduler_config.splitwise_role == "prefill":
        self.model_config.moe_phase = MoEPhase(phase="prefill")
    elif self.scheduler_config.splitwise_role == "decode":
        self.model_config.moe_phase = MoEPhase(phase="decode")
    else:
        raise NotImplementedError

    if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
        if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size:
            self.parallel_config.use_sequence_parallel_moe = False
            logger.info(
                "Warning: sequence parallel moe do not support max_num_seqs < tensor_parallel_size when cudagraph enabled. We set use_sequence_parallel_moe to False."
            )
        else:
            # It will hang when real batch_size < tp_size
            self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)

    if ErnieArchitectures.is_ernie5_arch(self.model_config.architectures):
        # ernie5 model not support chunked_mm_input
        self.cache_config.disable_chunked_mm_input = True

    self.postprocess_devices_and_ports()
def postprocess_devices_and_ports(self):
    """
    Pick the device ids and ports that belong to the local data-parallel rank.

    The per-job lists on ``parallel_config`` / ``cache_config`` are indexed by
    ``parallel_config.local_data_parallel_id``:
      * device ids and RDMA ports: the tensor-parallel-sized slice for this rank;
      * engine-worker queue port, cache queue port, PD comm port: one entry.
    Optional port lists that are empty/None yield ``None``. Any failure is
    logged (not raised), leaving later assignments unset.
    """
    try:
        parallel = self.parallel_config
        cache = self.cache_config
        all_devices = parallel.device_ids.split(",")
        dp_rank = parallel.local_data_parallel_id
        tp_size = parallel.tensor_parallel_size
        span_start, span_end = dp_rank * tp_size, (dp_rank + 1) * tp_size

        self.local_device_ids = all_devices[span_start:span_end]
        parallel.local_engine_worker_queue_port = parallel.engine_worker_queue_port[dp_rank]
        cache.local_cache_queue_port = cache.cache_queue_port[dp_rank] if cache.cache_queue_port else None
        cache.local_pd_comm_port = cache.pd_comm_port[dp_rank] if cache.pd_comm_port else None
        cache.local_rdma_comm_ports = (
            cache.rdma_comm_ports[span_start:span_end] if cache.rdma_comm_ports else None
        )
    except Exception as e:
        logger.error(f"Failed to extract local devices or ports. Servers may not be able to start properly. {e}")
def check(self):
    """
    Check the legality of the (post-processed) config.

    Raises:
        AssertionError: a numeric bound or cross-field constraint is violated.
        Exception: the xgrammar guided decoding backend is selected but
            ``xgrammar`` cannot be imported.
        ImportError: EPLB is enabled but ``cuda-python`` is not installed.
    """
    assert self.scheduler_config.max_num_seqs <= 512, (
        "The parameter `max_num_seqs` is not allowed to exceed 512, "
        f"but now it's {self.scheduler_config.max_num_seqs}."
    )
    # Messages below state the actual (inclusive) bounds being checked.
    assert self.nnode >= 1, f"nnode: {self.nnode} should be no less than 1"
    assert (
        self.model_config.max_model_len >= 16
    ), f"max_model_len: {self.model_config.max_model_len} should be larger than or equal to 16"
    assert (
        self.scheduler_config.max_num_seqs >= 1
    ), f"max_num_seqs: {self.scheduler_config.max_num_seqs} should be larger than or equal to 1"
    assert self.scheduler_config.max_num_batched_tokens >= self.scheduler_config.max_num_seqs, (
        f"max_num_batched_tokens: {self.scheduler_config.max_num_batched_tokens} "
        f"should be larger than or equal to max_num_seqs: {self.scheduler_config.max_num_seqs}"
    )
    assert (
        self.scheduler_config.max_num_batched_tokens
        <= self.model_config.max_model_len * self.scheduler_config.max_num_seqs
    ), (
        f"max_num_batched_tokens: {self.scheduler_config.max_num_batched_tokens} should be less "
        f"than or equal to max_num_seqs: {self.scheduler_config.max_num_seqs} * max_model_len: {self.model_config.max_model_len}"
    )
    assert (
        self.max_num_partial_prefills >= 1
    ), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"
    assert (
        self.max_long_partial_prefills >= 1
    ), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
    assert self.max_long_partial_prefills <= self.max_num_partial_prefills, (
        f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
        f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
    )
    assert self.scheduler_config.splitwise_role in ["mixed", "prefill", "decode"]

    if not self.cache_config.enable_chunked_prefill:
        # Without chunked prefill (and without the V1 scheduler) a batch must
        # be able to hold max_model_len tokens.
        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            assert self.scheduler_config.max_num_batched_tokens >= self.model_config.max_model_len, (
                f"max_num_batched_tokens: {self.scheduler_config.max_num_batched_tokens} "
                f"should be larger than or equal to max_model_len: {self.model_config.max_model_len}"
            )
    else:
        assert self.scheduler_config.max_num_batched_tokens >= self.cache_config.block_size, (
            f"max_num_batched_tokens: {self.scheduler_config.max_num_batched_tokens} "
            f"should be larger than or equal to block_size: {self.cache_config.block_size}"
        )

    if self.max_num_partial_prefills > 1:
        assert (
            self.cache_config.enable_chunked_prefill is True
        ), "Chunked prefill must be enabled to set max_num_partial_prefills > 1"
        assert self.long_prefill_token_threshold < self.model_config.max_model_len, (
            f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"
            f" max_model_len: {self.model_config.max_model_len}"
        )

    if (
        self.structured_outputs_config is not None
        and self.structured_outputs_config.guided_decoding_backend is not None
    ):
        backend = self.structured_outputs_config.guided_decoding_backend
        assert backend in [
            "xgrammar",
            "XGrammar",
            "auto",
            "off",
            "guidance",
        ], f"Only support [auto, xgrammar, guidance, off] guided decoding backend, but got {backend}."

        if backend != "off":
            # TODO: speculative decoding support guided_decoding
            # speculative_config may be None; previously this raised AttributeError.
            assert (
                self.speculative_config is None or self.speculative_config.method is None
            ), "speculative decoding currently do not support guided_decoding"
            # TODO: xpu support guided_decoding
            assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding"
            # Only the xgrammar backend needs xgrammar; "guidance" relies on
            # llguidance (verified in postprocess).
            if backend != "guidance":
                try:
                    import xgrammar  # noqa
                except Exception as e:
                    raise Exception(
                        f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
                    ) from e

    if self.scheduler_config is not None:
        self.scheduler_config.check()

    # Check graph optimization config
    if self.graph_opt_config.graph_opt_level > 0:
        if self.load_config is not None:
            assert (
                self.load_config.dynamic_load_weight is False
            ), "Static graph cannot be used in RL scene temporarily"

    if int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 1:
        assert (
            int(envs.FD_DISABLED_RECOVER) == 0
        ), "FD_DISABLED_RECOVER is not supported while ENABLE_V1_KVCACHE_SCHEDULER is turned on."

    if self.eplb_config is not None and self.eplb_config.enable_eplb:
        try:
            import cuda  # noqa
        except ImportError as e:
            raise ImportError(
                "cuda-python not installed. Install the version matching your CUDA toolkit:\n"
                "  CUDA 12.x → pip install cuda-python==12.*\n"
            ) from e
def print(self):
    """
    Log the complete configuration.

    ``generation_config`` (when set) is expanded entry by entry from its
    ``to_dict()``. The cache/model/scheduler/parallel/commit sub-configs
    delegate to their own ``print()`` (skipped when None). Every other
    attribute is logged as one ``name : value`` line.
    """
    delegated_sections = (
        "cache_config",
        "model_config",
        "scheduler_config",
        "parallel_config",
        "commit_config",
    )
    row_fmt = "{:<20}:{:<6}{}"
    logger.info("=================== Configuration Information ===============")
    for attr_name, attr_value in self.__dict__.items():
        if attr_name == "generation_config" and attr_value is not None:
            for gen_key, gen_value in attr_value.to_dict().items():
                logger.info(row_fmt.format(gen_key, "", gen_value))
            continue
        if attr_name in delegated_sections:
            if attr_value is not None:
                attr_value.print()
            continue
        logger.info(row_fmt.format(attr_name, "", attr_value))
    logger.info("=============================================================")
def init_pd_info(self):
    """
    Initialize the information used for PD (splitwise) deployment.

    Sets ``self.splitwise_version``: "v0" for the splitwise/dp schedulers,
    "v1" for the local scheduler, otherwise None. Also builds
    ``self.register_info``, the payload used to register this server with the
    router or the splitwise scheduler, and logs it.
    """
    # Two splitwise deployment flavours:
    #   v0 -> splitwise_scheduler or dp_scheduler
    #   v1 -> local_scheduler, optionally fronted by a router
    self.splitwise_version = None
    scheduler_name = self.scheduler_config.name
    if scheduler_name in ("splitwise", "dp"):
        self.splitwise_version = "v0"
    elif scheduler_name == "local":
        self.splitwise_version = "v1"

    router = self.router_config
    if router:
        port, metrics_port = router.api_server_port, router.metrics_port
    else:
        port = metrics_port = None

    protocol_spec = self.cache_config.cache_transfer_protocol
    transfer_protocol = protocol_spec.split(",") if protocol_spec else []

    self.register_info = dict(
        role=self.scheduler_config.splitwise_role,
        host_ip=self.host_ip,
        port=port,
        metrics_port=metrics_port,
        connector_port=self.cache_config.local_pd_comm_port,
        rdma_ports=self.cache_config.local_rdma_comm_ports,
        engine_worker_queue_port=self.parallel_config.local_engine_worker_queue_port,
        device_ids=self.local_device_ids,
        transfer_protocol=transfer_protocol,
        tp_size=self.parallel_config.tensor_parallel_size,
        is_paused=False,
        version=self.model_config.version,
        connected_decodes=[],
    )
    logger.info(f"register_info: {self.register_info}")
def read_from_config(self):
    """
    Reset selected runtime parameters from values carried on the sub-configs.

    For each (config, target attribute, source attribute) triple: when the
    config has the source attribute, its value is copied into the target
    attribute and the override is logged. Triples are applied in the order
    listed.
    """
    overrides = (
        (self.cache_config, "block_size", "infer_model_block_size"),
        (self.model_config, "return_full_hidden_states", "return_full_hidden_states"),
        (self.cache_config, "cache_dtype", "infer_model_dtype"),
    )
    for target_cfg, target_name, source_name in overrides:
        if not hasattr(target_cfg, source_name):
            continue
        new_value = getattr(target_cfg, source_name)
        setattr(target_cfg, target_name, new_value)
        logger.info(f"Reset parameter {target_name} = {new_value} from configuration.")
def get_max_chunk_tokens(self, mm_max_tokens_per_item=None):
    """
    Return the maximum number of tokens a single inference step may process.

    Decode instances not built with XPU use ``max_num_seqs``; all others use
    ``max_num_batched_tokens``. When the multimodal runtime is active and a
    per-item budget is available, the largest image/video/audio budget is
    added, capped at ``max_model_len``.

    Args:
        mm_max_tokens_per_item: optional mapping with "image"/"video"/"audio"
            token budgets; defaults to ``self.model_config.mm_max_tokens_per_item``.
    """
    per_item = (
        self.model_config.mm_max_tokens_per_item if mm_max_tokens_per_item is None else mm_max_tokens_per_item
    )
    sched = self.scheduler_config
    decode_without_xpu = sched.splitwise_role == "decode" and not paddle.is_compiled_with_xpu()
    num_tokens = sched.max_num_seqs if decode_without_xpu else sched.max_num_batched_tokens

    if not (self.enable_mm_runtime and per_item is not None):
        return num_tokens
    largest_item = max(per_item.get(modality, 0) for modality in ("image", "video", "audio"))
    return min(num_tokens + largest_item, self.model_config.max_model_len)
def _check_master(self):
    """Return the ``is_master`` flag computed in ``postprocess``."""
    master_flag = self.is_master
    return master_flag
def _str_to_list(self, attr_name, default_type):
    """
    Convert attribute ``attr_name`` in place into a list of ``default_type`` values.

    An exact ``str`` is split on commas; any other non-None value is iterated
    element-wise. Missing attributes and ``None`` values are left untouched.
    """
    if not hasattr(self, attr_name):
        return
    raw_value = getattr(self, attr_name)
    if raw_value is None:
        return
    elements = raw_value.split(",") if type(raw_value) is str else raw_value
    setattr(self, attr_name, [default_type(element) for element in elements])
def __str__(self) -> str:
    """
    Return a pretty-printed (indent=4) JSON dump of this config's attributes.

    The attributes include nested config objects (model/cache/parallel/...)
    that the json module cannot encode natively; previously this raised
    TypeError. ``default=str`` renders such values through their own
    ``__str__``.
    """
    return json.dumps(self.__dict__, indent=4, default=str)