[FDConfig] Support CLI args for quantization params and add cudagraph validation (#7281)

* refactor quant cli param
Author: GoldPancake
Date: 2026-04-10 14:13:42 +08:00
Committed by: GitHub
Parent: 7614175e13
Commit: c1fb3112f8
5 changed files with 116 additions and 45 deletions
@@ -54,17 +54,56 @@ def _compute_hadamard_block_size(moe_intermediate_size: int, tp_size: int) -> int:
    return block_size


def _is_full_quantization_config(quantization_dict):
    """
    Determine whether the parsed quantization dict is a simple method name or a full quantization_config.

    Simple method name: {"quantization": "wint4"} (only the single key "quantization").
    Full config: {"quantization": "mix_quant", "dense_quant_type": "wint8", ...} (multiple keys),
    or torch format: {"quant_method": "fp8", "weight_block_size": [128, 128]} (has a "quant_method" key).
    """
    if "quant_method" in quantization_dict:
        return True
    if len(quantization_dict) > 1:
        return True
    return False
def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
    if args.quantization is not None and isinstance(args.quantization, str):
        args.quantization = parse_quantization(args.quantization)

    # Determine whether the CLI --quantization value is a simple method name
    # or a full JSON quantization_config.
    cli_quantization = args.quantization
    cli_is_full_config = (
        cli_quantization is not None
        and isinstance(cli_quantization, dict)
        and _is_full_quantization_config(cli_quantization)
    )

    model_quantization_config = model_config.quantization_config
    quantization_config = model_quantization_config

    # If the CLI provides a full quantization_config JSON, resolve its priority
    # against config.json.
    if cli_is_full_config:
        if model_quantization_config is not None:
            if model_quantization_config != cli_quantization:
                logger.warning(
                    "The quantization_config from the --quantization argument "
                    "differs from the one in the model's config.json. "
                    "Using config.json's quantization_config, as it has higher priority. "
                    f"config.json: {model_quantization_config}, "
                    f"--quantization: {cli_quantization}"
                )
        else:
            # config.json has no quantization_config, so use the CLI's full config.
            quantization_config = cli_quantization

    # 1. model_config.is_quantized
    # TODO(bukejiyu): model_config.is_quantized is v0-only; remove it in the future.
    if model_config.model_format == "torch":
        quantization_config = model_config.quantization_config
        if quantization_config is not None:
            model_config.is_quantized = True
    else:
        quantization_config = model_config.quantization_config
        if not model_config.is_quantized:
            if quantization_config is not None:
                if "is_quantized" in quantization_config:
@@ -84,12 +123,11 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
    quant_config_name = None
-   if quantization_config is not None:
+   if model_quantization_config is not None:
        quant_config_name = _get_offline_quant_config_name(
-           quantization_config, model_config.model_format == "torch", is_v1_loader
+           model_quantization_config, model_config.model_format == "torch", is_v1_loader
        )
-   elif args.quantization is not None:
+   elif cli_quantization is not None and not cli_is_full_config:
        quantization_config = {}
        try:
            quantization_config.update(args.quantization)
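
A hedged sketch of the shape this elif expects; the return value of parse_quantization for a plain method name is an assumption (a single-key dict), not shown in this diff:

    args.quantization = parse_quantization("wint4")  # assumed -> {"quantization": "wint4"}
    # cli_is_full_config is False here (single key, no "quant_method"), so this
    # branch rebuilds quantization_config from the parsed CLI dict:
    quantization_config = {}
    quantization_config.update(args.quantization)  # {"quantization": "wint4"}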
@@ -117,8 +155,11 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
            quantization_config["hadamard_block_size"] = 512
        quantization_config["quantization"] = "mix_quant"
        quant_config_name = "mix_quant"
+   elif cli_quantization is not None and cli_is_full_config:
+       quant_config_name = quantization_config["quantization"]
    else:
        quant_config_name = None

    if quant_config_name is None:
        quant_config = None
    else:
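
A minimal sketch of the branch added above, assuming a hypothetical full config on the CLI and no quantization_config in config.json:

    cli_quantization = {"quantization": "mix_quant", "dense_quant_type": "wint8", "moe_quant_type": "wint4"}
    # cli_is_full_config is True, so quantization_config was set to cli_quantization
    # earlier in the function, and the method name is read straight out of it:
    quant_config_name = cli_quantization["quantization"]  # "mix_quant"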
@@ -128,6 +169,7 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
            quantization_config["is_quantized"] = True
        quant_cls = get_quantization_config(quant_config_name)
        quant_config = quant_cls.from_config(quantization_config)
    return quant_config
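
Taken together, a hedged sketch of the priority rule this PR introduces; SimpleNamespace stands in for the real args and model_config objects, which carry more attributes than shown:

    from types import SimpleNamespace

    args = SimpleNamespace(quantization={"quantization": "mix_quant", "dense_quant_type": "wint8"})
    model_config = SimpleNamespace(quantization_config=None)

    # config.json wins when present; otherwise the CLI full config is used.
    effective = model_config.quantization_config or args.quantization
    print(effective)  # {'quantization': 'mix_quant', 'dense_quant_type': 'wint8'}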