mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
【Hackathon 9th No.86】autogen MultiQueryDecoderAttention template_instantiation -part (#4383)
* split MultiQueryDecoderAttention template_instantiation * update comment * CI
This commit is contained in:
@@ -11,143 +11,232 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""generate multiquery_attention_c8_kernel template instantiation."""
|
||||
"""Universal template instantiation generator - fully based on configuration file template instantiation generation."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
TEMPLATE_DIR = Path("gpu_ops/append_attn/template_instantiation/autogen")
|
||||
TEMPLATE_DIR.mkdir(exist_ok=True)
|
||||
|
||||
DISPATCH_PARAMS = {
|
||||
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
|
||||
"HEAD_DIM": [128],
|
||||
"BLOCK_SIZE": [64],
|
||||
"CAUSAL": [0, 1],
|
||||
"BLOCK_SHAPE_Q": [16, 32, 64, 128],
|
||||
"ENABLE_PREFILL": [0, 1],
|
||||
"IsFP8": [0, 1],
|
||||
"IsDynamicC8": [0, 1],
|
||||
}
|
||||
|
||||
DATA_TYPE_COMBINATIONS = [
|
||||
("paddle::float16", "paddle::float16", "float16_float16"),
|
||||
("paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"),
|
||||
("paddle::float16", "int8_t", "float16_int8"),
|
||||
("paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"),
|
||||
("paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"),
|
||||
("paddle::bfloat16", "int8_t", "bfloat16_int8"),
|
||||
]
|
||||
|
||||
MAX_INSTANCES_PER_FILE = 60
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
def get_num_warp_q(block_shape_q):
|
||||
if block_shape_q <= 32:
|
||||
return 1
|
||||
else:
|
||||
return 4
|
||||
@dataclass
|
||||
class TemplateConfig:
|
||||
"""Template configuration class."""
|
||||
|
||||
name: str # Function name
|
||||
function_name: str # Actual function name
|
||||
impl_file: str # Implementation file path
|
||||
template_params: List[str] # Template parameter list (in order)
|
||||
dispatch_params: Dict[str, List[Any]] # Dispatch parameters
|
||||
data_types: Optional[List[Tuple[str, str, str]]] = None # Data type combinations (input_type, output_type, suffix)
|
||||
max_instances_per_file: int = 60 # Maximum instances per file
|
||||
file_prefix: str = "" # File prefix
|
||||
function_signature: str = "" # Function signature template
|
||||
|
||||
|
||||
def generate_file_header():
|
||||
return """// Generated by autogen_template_instantiation.py - Do not edit.
|
||||
class UniversalTemplateInstantiator:
|
||||
"""Universal template instantiator - fully based on configuration file."""
|
||||
|
||||
def __init__(self, config_file: str):
|
||||
"""Initialize the instantiator."""
|
||||
self.config_file = config_file
|
||||
self.configs = self._load_configs()
|
||||
|
||||
def _load_configs(self) -> Dict[str, TemplateConfig]:
|
||||
"""Load configuration file."""
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
configs = {}
|
||||
for name, config_dict in config_data.items():
|
||||
config = TemplateConfig(**config_dict)
|
||||
self._validate_config(config)
|
||||
configs[name] = config
|
||||
return configs
|
||||
|
||||
def _validate_config(self, config: TemplateConfig):
|
||||
"""Validate configuration completeness."""
|
||||
has_t = "T" in config.template_params
|
||||
has_out_t = "OutT" in config.template_params
|
||||
|
||||
if (has_t or has_out_t) and not config.data_types:
|
||||
raise ValueError(
|
||||
f"Configuration '{config.name}' has T or OutT in template_params but no data_types configured"
|
||||
)
|
||||
|
||||
special_params = {"T", "OutT", "NUM_WARP_Q"}
|
||||
for param_name in config.template_params:
|
||||
if param_name not in special_params and param_name not in config.dispatch_params:
|
||||
raise ValueError(f"Template parameter '{param_name}' in '{config.name}' not found in dispatch_params")
|
||||
|
||||
if "NUM_WARP_Q" in config.template_params and "BLOCK_SHAPE_Q" not in config.dispatch_params:
|
||||
raise ValueError(
|
||||
f"Template parameter 'NUM_WARP_Q' in '{config.name}' requires 'BLOCK_SHAPE_Q' in dispatch_params"
|
||||
)
|
||||
|
||||
def _calculate_num_warp_q(self, block_shape_q: int) -> int:
|
||||
"""Calculate number of warps."""
|
||||
if block_shape_q <= 32:
|
||||
return 1
|
||||
else:
|
||||
return 4
|
||||
|
||||
def _build_template_args(self, config: TemplateConfig, t_in: str, t_out: str, params: Dict[str, Any]) -> str:
|
||||
"""Build template arguments."""
|
||||
template_args_parts = []
|
||||
|
||||
for param_name in config.template_params:
|
||||
if param_name == "T":
|
||||
if t_in:
|
||||
template_args_parts.append(t_in)
|
||||
else:
|
||||
raise ValueError("Template parameter 'T' requires input type, but data_types is empty or invalid")
|
||||
elif param_name == "OutT":
|
||||
if t_out:
|
||||
template_args_parts.append(t_out)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Template parameter 'OutT' requires output type, but data_types is empty or invalid"
|
||||
)
|
||||
elif param_name == "NUM_WARP_Q":
|
||||
if "BLOCK_SHAPE_Q" in params:
|
||||
num_warp_q = self._calculate_num_warp_q(params["BLOCK_SHAPE_Q"])
|
||||
template_args_parts.append(str(num_warp_q))
|
||||
else:
|
||||
raise ValueError("Template parameter 'NUM_WARP_Q' requires 'BLOCK_SHAPE_Q' in dispatch_params")
|
||||
elif param_name in params:
|
||||
template_args_parts.append(str(params[param_name]))
|
||||
else:
|
||||
raise ValueError(f"Template parameter '{param_name}' not found in dispatch_params")
|
||||
|
||||
return f"<{', '.join(template_args_parts)}>"
|
||||
|
||||
def _generate_function_signature(self, config: TemplateConfig, template_args: str) -> str:
|
||||
"""Generate function signature."""
|
||||
if config.function_signature:
|
||||
return config.function_signature.format(function_name=config.function_name, template_args=template_args)
|
||||
else:
|
||||
raise ValueError(f"Function signature not found for {config.name}")
|
||||
|
||||
def _generate_file_header(self, config: TemplateConfig) -> str:
|
||||
"""Generate file header."""
|
||||
return f"""// Generated by autogen_template_instantiation.py - Do not edit.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../multiquery_attention_c8_impl.cuh"
|
||||
#include "../../{config.impl_file}"
|
||||
"""
|
||||
|
||||
def _generate_template_instantiation(
|
||||
self, config: TemplateConfig, t_in: str, t_out: str, params: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Generate template instantiation."""
|
||||
template_args = self._build_template_args(config, t_in, t_out, params)
|
||||
return self._generate_function_signature(config, template_args)
|
||||
|
||||
def generate_template_instantiation(t_in, t_out, params):
|
||||
num_warp_q = get_num_warp_q(params["BLOCK_SHAPE_Q"])
|
||||
template_args = f"<{t_in}, {params['GROUP_SIZE']}, {params['HEAD_DIM']}, {params['BLOCK_SIZE']}, {params['CAUSAL']}, {params['BLOCK_SHAPE_Q']}, {num_warp_q}, {t_out}, {params['ENABLE_PREFILL']}, {params['IsFP8']}, {params['IsDynamicC8']}>"
|
||||
def generate_combinations_for_type(self, config: TemplateConfig, t_in: str, t_out: str) -> List[Dict[str, Any]]:
|
||||
"""Generate parameter combinations for specific type."""
|
||||
combinations = []
|
||||
|
||||
return f"""
|
||||
template void MultiQueryAppendC8Attention{template_args}(
|
||||
const AppendAttnMetaData &meta_data,
|
||||
const paddle::Tensor &qkv,
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
||||
const paddle::Tensor &cache_k_scale,
|
||||
const paddle::Tensor &cache_v_scale,
|
||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
const int num_blocks_x_cpu,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool is_decoder,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
def _generate_recursive(
|
||||
params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str]
|
||||
):
|
||||
if not param_names:
|
||||
combinations.append(current_params.copy())
|
||||
return
|
||||
|
||||
"""
|
||||
param_name = param_names[0]
|
||||
for value in params_dict[param_name]:
|
||||
current_params[param_name] = value
|
||||
_generate_recursive(params_dict, current_params, param_names[1:])
|
||||
|
||||
_generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys()))
|
||||
return combinations
|
||||
|
||||
def generate_combinations_for_type(t_in, t_out):
|
||||
combinations = []
|
||||
for group_size in DISPATCH_PARAMS["GROUP_SIZE"]:
|
||||
for head_dim in DISPATCH_PARAMS["HEAD_DIM"]:
|
||||
for block_size in DISPATCH_PARAMS["BLOCK_SIZE"]:
|
||||
for causal in DISPATCH_PARAMS["CAUSAL"]:
|
||||
for block_shape_q in DISPATCH_PARAMS["BLOCK_SHAPE_Q"]:
|
||||
for enable_prefill in DISPATCH_PARAMS["ENABLE_PREFILL"]:
|
||||
for is_fp8 in DISPATCH_PARAMS["IsFP8"]:
|
||||
for is_dynamic_c8 in DISPATCH_PARAMS["IsDynamicC8"]:
|
||||
params = {
|
||||
"GROUP_SIZE": group_size,
|
||||
"HEAD_DIM": head_dim,
|
||||
"BLOCK_SIZE": block_size,
|
||||
"CAUSAL": causal,
|
||||
"BLOCK_SHAPE_Q": block_shape_q,
|
||||
"ENABLE_PREFILL": enable_prefill,
|
||||
"IsFP8": is_fp8,
|
||||
"IsDynamicC8": is_dynamic_c8,
|
||||
}
|
||||
combinations.append(params)
|
||||
def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]:
|
||||
"""Split combinations into multiple files."""
|
||||
chunks = []
|
||||
for i in range(0, len(combinations), max_per_file):
|
||||
chunk = combinations[i : i + max_per_file]
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
|
||||
return combinations
|
||||
def generate_file_content(
|
||||
self,
|
||||
config: TemplateConfig,
|
||||
t_in: str,
|
||||
t_out: str,
|
||||
t_out_name: str,
|
||||
file_index: int,
|
||||
combinations: List[Dict[str, Any]],
|
||||
) -> str:
|
||||
"""Generate file content."""
|
||||
content = self._generate_file_header(config)
|
||||
|
||||
for params in combinations:
|
||||
content += self._generate_template_instantiation(config, t_in, t_out, params)
|
||||
|
||||
def split_combinations(combinations, max_per_file):
|
||||
chunks = []
|
||||
for i in range(0, len(combinations), max_per_file):
|
||||
chunk = combinations[i : i + max_per_file]
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
return content
|
||||
|
||||
def generate_for_function_type(self, function_name: str, output_dir: str):
|
||||
"""Generate template instantiation files for specific function type."""
|
||||
if function_name not in self.configs:
|
||||
raise ValueError(f"Function type '{function_name}' not found in config")
|
||||
|
||||
def generate_file_content(t_in, t_out, t_out_name, file_index, combinations):
|
||||
content = generate_file_header()
|
||||
for params in combinations:
|
||||
content += generate_template_instantiation(t_in, t_out, params)
|
||||
config = self.configs[function_name]
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(exist_ok=True)
|
||||
|
||||
return content
|
||||
if not config.data_types:
|
||||
data_types = [("", "", "")]
|
||||
else:
|
||||
data_types = config.data_types
|
||||
|
||||
for t_in, t_out, t_out_name in data_types:
|
||||
combinations = self.generate_combinations_for_type(config, t_in, t_out)
|
||||
if combinations:
|
||||
chunks = self.split_combinations(combinations, config.max_instances_per_file)
|
||||
for i, chunk in enumerate(chunks):
|
||||
filename = f"{config.file_prefix}{t_out_name}_part_{i:02d}.cu"
|
||||
filepath = output_path / filename
|
||||
content = self.generate_file_content(config, t_in, t_out, t_out_name, i, chunk)
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
def generate_all(self, output_dir: str):
|
||||
"""Generate all configured function types."""
|
||||
for function_name in self.configs.keys():
|
||||
print(f"Generating template instantiations for {function_name}...")
|
||||
self.generate_for_function_type(function_name, output_dir)
|
||||
print(f"Completed generating {function_name} template instantiations.")
|
||||
|
||||
|
||||
def main():
|
||||
for t_in, t_out, t_out_name in DATA_TYPE_COMBINATIONS:
|
||||
combinations = generate_combinations_for_type(t_in, t_out)
|
||||
if combinations:
|
||||
chunks = split_combinations(combinations, MAX_INSTANCES_PER_FILE)
|
||||
for i, chunk in enumerate(chunks):
|
||||
filename = f"multiquery_attention_c8_{t_out_name}_part_{i:02d}.cu"
|
||||
filepath = TEMPLATE_DIR / filename
|
||||
content = generate_file_content(t_in, t_out, t_out_name, i, chunk)
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
"""Main function."""
|
||||
parser = argparse.ArgumentParser(description="Universal template instantiation generator")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
"-c",
|
||||
type=str,
|
||||
default="gpu_ops/append_attn/template_config.json",
|
||||
help="Configuration file path (JSON format)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
type=str,
|
||||
default="gpu_ops/append_attn/template_instantiation/autogen",
|
||||
help="Output directory",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
instantiator = UniversalTemplateInstantiator(args.config)
|
||||
instantiator.generate_all(args.output)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user