mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[FDConfig] Support CLI args for quantization params and add cudagraph validation (#7281)
* refactor quant cli param
This commit is contained in:
@@ -54,17 +54,56 @@ def _compute_hadamard_block_size(moe_intermediate_size: int, tp_size: int) -> in
|
||||
return block_size
|
||||
|
||||
|
||||
def _is_full_quantization_config(quantization_dict):
|
||||
"""
|
||||
Determine whether the parsed quantization dict is a simple method name or a full quantization_config.
|
||||
Simple method name: {"quantization": "wint4"} (only one key "quantization")
|
||||
Full config: {"quantization": "mix_quant", "dense_quant_type": "wint8", ...} (multiple keys)
|
||||
Or torch format: {"quant_method": "fp8", "weight_block_size": [128, 128]} (has "quant_method" key)
|
||||
"""
|
||||
if "quant_method" in quantization_dict:
|
||||
return True
|
||||
if len(quantization_dict) > 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
|
||||
if args.quantization is not None and isinstance(args.quantization, str):
|
||||
args.quantization = parse_quantization(args.quantization)
|
||||
|
||||
# Determine whether CLI --quantization is a simple method name or a full JSON quantization_config
|
||||
cli_quantization = args.quantization
|
||||
cli_is_full_config = (
|
||||
cli_quantization is not None
|
||||
and isinstance(cli_quantization, dict)
|
||||
and _is_full_quantization_config(cli_quantization)
|
||||
)
|
||||
|
||||
model_quantization_config = model_config.quantization_config
|
||||
quantization_config = model_quantization_config
|
||||
|
||||
# If CLI provides a full quantization_config JSON, handle priority with config.json
|
||||
if cli_is_full_config:
|
||||
if model_quantization_config is not None:
|
||||
if model_quantization_config != cli_quantization:
|
||||
logger.warning(
|
||||
"The quantization_config from --quantization argument "
|
||||
"differs from the one in model's config.json. "
|
||||
"Using config.json's quantization_config as it has higher priority. "
|
||||
f"config.json: {model_quantization_config}, "
|
||||
f"--quantization: {cli_quantization}"
|
||||
)
|
||||
else:
|
||||
# config.json has no quantization_config, use CLI's full config
|
||||
quantization_config = cli_quantization
|
||||
|
||||
# 1.model_config.is_quantized
|
||||
# TODO(bukejiyu): model_config.is_quantized is v0-only and needs to be removed in the future
|
||||
if model_config.model_format == "torch":
|
||||
quantization_config = model_config.quantization_config
|
||||
if quantization_config is not None:
|
||||
model_config.is_quantized = True
|
||||
else:
|
||||
quantization_config = model_config.quantization_config
|
||||
if not model_config.is_quantized:
|
||||
if quantization_config is not None:
|
||||
if "is_quantized" in quantization_config:
|
||||
@@ -84,12 +123,11 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
|
||||
|
||||
quant_config_name = None
|
||||
|
||||
if quantization_config is not None:
|
||||
if model_quantization_config is not None:
|
||||
quant_config_name = _get_offline_quant_config_name(
|
||||
quantization_config, model_config.model_format == "torch", is_v1_loader
|
||||
model_quantization_config, model_config.model_format == "torch", is_v1_loader
|
||||
)
|
||||
|
||||
elif args.quantization is not None:
|
||||
elif cli_quantization is not None and not cli_is_full_config:
|
||||
quantization_config = {}
|
||||
try:
|
||||
quantization_config.update(args.quantization)
|
||||
@@ -117,8 +155,11 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
|
||||
quantization_config["hadamard_block_size"] = 512
|
||||
quantization_config["quantization"] = "mix_quant"
|
||||
quant_config_name = "mix_quant"
|
||||
elif cli_quantization is not None and cli_is_full_config:
|
||||
quant_config_name = quantization_config["quantization"]
|
||||
else:
|
||||
quant_config_name = None
|
||||
|
||||
if quant_config_name is None:
|
||||
quant_config = None
|
||||
else:
|
||||
@@ -128,6 +169,7 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
|
||||
quantization_config["is_quantized"] = True
|
||||
quant_cls = get_quantization_config(quant_config_name)
|
||||
quant_config = quant_cls.from_config(quantization_config)
|
||||
|
||||
return quant_config
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user