【Hackathon 9th No.86】autogen MultiQueryDecoderAttention template_instantiation -part (#4383)

* split MultiQueryDecoderAttention template_instantiation

* update comment

* CI
This commit is contained in:
Zhenghai Zhang
2025-10-16 17:08:19 +08:00
committed by GitHub
parent f72be7a2c8
commit 6adfbe07ad
27 changed files with 3975 additions and 3836 deletions
@@ -11,143 +11,232 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""generate multiquery_attention_c8_kernel template instantiation."""
"""Universal template instantiation generator - fully based on configuration file template instantiation generation."""
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
TEMPLATE_DIR = Path("gpu_ops/append_attn/template_instantiation/autogen")
TEMPLATE_DIR.mkdir(exist_ok=True)
DISPATCH_PARAMS = {
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
"HEAD_DIM": [128],
"BLOCK_SIZE": [64],
"CAUSAL": [0, 1],
"BLOCK_SHAPE_Q": [16, 32, 64, 128],
"ENABLE_PREFILL": [0, 1],
"IsFP8": [0, 1],
"IsDynamicC8": [0, 1],
}
DATA_TYPE_COMBINATIONS = [
("paddle::float16", "paddle::float16", "float16_float16"),
("paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"),
("paddle::float16", "int8_t", "float16_int8"),
("paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"),
("paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"),
("paddle::bfloat16", "int8_t", "bfloat16_int8"),
]
MAX_INSTANCES_PER_FILE = 60
from typing import Any, Dict, List, Optional, Tuple
def get_num_warp_q(block_shape_q):
if block_shape_q <= 32:
return 1
else:
return 4
@dataclass
class TemplateConfig:
"""Template configuration class."""
name: str # Function name
function_name: str # Actual function name
impl_file: str # Implementation file path
template_params: List[str] # Template parameter list (in order)
dispatch_params: Dict[str, List[Any]] # Dispatch parameters
data_types: Optional[List[Tuple[str, str, str]]] = None # Data type combinations (input_type, output_type, suffix)
max_instances_per_file: int = 60 # Maximum instances per file
file_prefix: str = "" # File prefix
function_signature: str = "" # Function signature template
def generate_file_header():
return """// Generated by autogen_template_instantiation.py - Do not edit.
class UniversalTemplateInstantiator:
"""Universal template instantiator - fully based on configuration file."""
def __init__(self, config_file: str):
"""Initialize the instantiator."""
self.config_file = config_file
self.configs = self._load_configs()
def _load_configs(self) -> Dict[str, TemplateConfig]:
"""Load configuration file."""
with open(self.config_file, "r", encoding="utf-8") as f:
config_data = json.load(f)
configs = {}
for name, config_dict in config_data.items():
config = TemplateConfig(**config_dict)
self._validate_config(config)
configs[name] = config
return configs
def _validate_config(self, config: TemplateConfig):
"""Validate configuration completeness."""
has_t = "T" in config.template_params
has_out_t = "OutT" in config.template_params
if (has_t or has_out_t) and not config.data_types:
raise ValueError(
f"Configuration '{config.name}' has T or OutT in template_params but no data_types configured"
)
special_params = {"T", "OutT", "NUM_WARP_Q"}
for param_name in config.template_params:
if param_name not in special_params and param_name not in config.dispatch_params:
raise ValueError(f"Template parameter '{param_name}' in '{config.name}' not found in dispatch_params")
if "NUM_WARP_Q" in config.template_params and "BLOCK_SHAPE_Q" not in config.dispatch_params:
raise ValueError(
f"Template parameter 'NUM_WARP_Q' in '{config.name}' requires 'BLOCK_SHAPE_Q' in dispatch_params"
)
def _calculate_num_warp_q(self, block_shape_q: int) -> int:
"""Calculate number of warps."""
if block_shape_q <= 32:
return 1
else:
return 4
def _build_template_args(self, config: TemplateConfig, t_in: str, t_out: str, params: Dict[str, Any]) -> str:
"""Build template arguments."""
template_args_parts = []
for param_name in config.template_params:
if param_name == "T":
if t_in:
template_args_parts.append(t_in)
else:
raise ValueError("Template parameter 'T' requires input type, but data_types is empty or invalid")
elif param_name == "OutT":
if t_out:
template_args_parts.append(t_out)
else:
raise ValueError(
"Template parameter 'OutT' requires output type, but data_types is empty or invalid"
)
elif param_name == "NUM_WARP_Q":
if "BLOCK_SHAPE_Q" in params:
num_warp_q = self._calculate_num_warp_q(params["BLOCK_SHAPE_Q"])
template_args_parts.append(str(num_warp_q))
else:
raise ValueError("Template parameter 'NUM_WARP_Q' requires 'BLOCK_SHAPE_Q' in dispatch_params")
elif param_name in params:
template_args_parts.append(str(params[param_name]))
else:
raise ValueError(f"Template parameter '{param_name}' not found in dispatch_params")
return f"<{', '.join(template_args_parts)}>"
def _generate_function_signature(self, config: TemplateConfig, template_args: str) -> str:
"""Generate function signature."""
if config.function_signature:
return config.function_signature.format(function_name=config.function_name, template_args=template_args)
else:
raise ValueError(f"Function signature not found for {config.name}")
def _generate_file_header(self, config: TemplateConfig) -> str:
"""Generate file header."""
return f"""// Generated by autogen_template_instantiation.py - Do not edit.
#pragma once
#include "../../multiquery_attention_c8_impl.cuh"
#include "../../{config.impl_file}"
"""
def _generate_template_instantiation(
self, config: TemplateConfig, t_in: str, t_out: str, params: Dict[str, Any]
) -> str:
"""Generate template instantiation."""
template_args = self._build_template_args(config, t_in, t_out, params)
return self._generate_function_signature(config, template_args)
def generate_template_instantiation(t_in, t_out, params):
num_warp_q = get_num_warp_q(params["BLOCK_SHAPE_Q"])
template_args = f"<{t_in}, {params['GROUP_SIZE']}, {params['HEAD_DIM']}, {params['BLOCK_SIZE']}, {params['CAUSAL']}, {params['BLOCK_SHAPE_Q']}, {num_warp_q}, {t_out}, {params['ENABLE_PREFILL']}, {params['IsFP8']}, {params['IsDynamicC8']}>"
def generate_combinations_for_type(self, config: TemplateConfig, t_in: str, t_out: str) -> List[Dict[str, Any]]:
"""Generate parameter combinations for specific type."""
combinations = []
return f"""
template void MultiQueryAppendC8Attention{template_args}(
const AppendAttnMetaData &meta_data,
const paddle::Tensor &qkv,
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor> &attn_mask,
const paddle::Tensor &cache_k_scale,
const paddle::Tensor &cache_v_scale,
const paddle::optional<paddle::Tensor> &shift_bias,
const paddle::optional<paddle::Tensor> &smooth_weight,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
const int num_blocks_x_cpu,
const int max_seq_len,
const int max_dec_len,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out);
def _generate_recursive(
params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str]
):
if not param_names:
combinations.append(current_params.copy())
return
"""
param_name = param_names[0]
for value in params_dict[param_name]:
current_params[param_name] = value
_generate_recursive(params_dict, current_params, param_names[1:])
_generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys()))
return combinations
def generate_combinations_for_type(t_in, t_out):
combinations = []
for group_size in DISPATCH_PARAMS["GROUP_SIZE"]:
for head_dim in DISPATCH_PARAMS["HEAD_DIM"]:
for block_size in DISPATCH_PARAMS["BLOCK_SIZE"]:
for causal in DISPATCH_PARAMS["CAUSAL"]:
for block_shape_q in DISPATCH_PARAMS["BLOCK_SHAPE_Q"]:
for enable_prefill in DISPATCH_PARAMS["ENABLE_PREFILL"]:
for is_fp8 in DISPATCH_PARAMS["IsFP8"]:
for is_dynamic_c8 in DISPATCH_PARAMS["IsDynamicC8"]:
params = {
"GROUP_SIZE": group_size,
"HEAD_DIM": head_dim,
"BLOCK_SIZE": block_size,
"CAUSAL": causal,
"BLOCK_SHAPE_Q": block_shape_q,
"ENABLE_PREFILL": enable_prefill,
"IsFP8": is_fp8,
"IsDynamicC8": is_dynamic_c8,
}
combinations.append(params)
def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]:
"""Split combinations into multiple files."""
chunks = []
for i in range(0, len(combinations), max_per_file):
chunk = combinations[i : i + max_per_file]
chunks.append(chunk)
return chunks
return combinations
def generate_file_content(
self,
config: TemplateConfig,
t_in: str,
t_out: str,
t_out_name: str,
file_index: int,
combinations: List[Dict[str, Any]],
) -> str:
"""Generate file content."""
content = self._generate_file_header(config)
for params in combinations:
content += self._generate_template_instantiation(config, t_in, t_out, params)
def split_combinations(combinations, max_per_file):
chunks = []
for i in range(0, len(combinations), max_per_file):
chunk = combinations[i : i + max_per_file]
chunks.append(chunk)
return chunks
return content
def generate_for_function_type(self, function_name: str, output_dir: str):
"""Generate template instantiation files for specific function type."""
if function_name not in self.configs:
raise ValueError(f"Function type '{function_name}' not found in config")
def generate_file_content(t_in, t_out, t_out_name, file_index, combinations):
content = generate_file_header()
for params in combinations:
content += generate_template_instantiation(t_in, t_out, params)
config = self.configs[function_name]
output_path = Path(output_dir)
output_path.mkdir(exist_ok=True)
return content
if not config.data_types:
data_types = [("", "", "")]
else:
data_types = config.data_types
for t_in, t_out, t_out_name in data_types:
combinations = self.generate_combinations_for_type(config, t_in, t_out)
if combinations:
chunks = self.split_combinations(combinations, config.max_instances_per_file)
for i, chunk in enumerate(chunks):
filename = f"{config.file_prefix}{t_out_name}_part_{i:02d}.cu"
filepath = output_path / filename
content = self.generate_file_content(config, t_in, t_out, t_out_name, i, chunk)
with open(filepath, "w", encoding="utf-8") as f:
f.write(content)
def generate_all(self, output_dir: str):
"""Generate all configured function types."""
for function_name in self.configs.keys():
print(f"Generating template instantiations for {function_name}...")
self.generate_for_function_type(function_name, output_dir)
print(f"Completed generating {function_name} template instantiations.")
def main():
for t_in, t_out, t_out_name in DATA_TYPE_COMBINATIONS:
combinations = generate_combinations_for_type(t_in, t_out)
if combinations:
chunks = split_combinations(combinations, MAX_INSTANCES_PER_FILE)
for i, chunk in enumerate(chunks):
filename = f"multiquery_attention_c8_{t_out_name}_part_{i:02d}.cu"
filepath = TEMPLATE_DIR / filename
content = generate_file_content(t_in, t_out, t_out_name, i, chunk)
with open(filepath, "w", encoding="utf-8") as f:
f.write(content)
"""Main function."""
parser = argparse.ArgumentParser(description="Universal template instantiation generator")
parser.add_argument(
"--config",
"-c",
type=str,
default="gpu_ops/append_attn/template_config.json",
help="Configuration file path (JSON format)",
)
parser.add_argument(
"--output",
"-o",
type=str,
default="gpu_ops/append_attn/template_instantiation/autogen",
help="Output directory",
)
args = parser.parse_args()
try:
instantiator = UniversalTemplateInstantiator(args.config)
instantiator.generate_all(args.output)
except Exception as e:
print(f"Error: {e}")
if __name__ == "__main__":