Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
@@ -17,26 +17,26 @@
template <typename T>
std::optional<T> ConvertToStdOptional(const paddle::optional<T>& paddle_opt) {
return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
}
std::vector<std::string> supported_schedules(
paddle::DataType a_type, int64_t b_type_id,
paddle::DataType a_type,
int64_t b_type_id,
std::optional<paddle::DataType> maybe_group_scales_type,
std::optional<paddle::DataType> maybe_group_zeros_type,
std::optional<paddle::DataType> maybe_channel_scales_type,
std::optional<paddle::DataType> maybe_token_scales_type,
std::optional<paddle::DataType> maybe_out_type) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
auto schedules = machete::supported_schedules_dispatch({
.a_type = a_type,
.b_type = b_type,
.maybe_group_scales_type = maybe_group_scales_type,
.maybe_group_zeros_type = maybe_group_zeros_type,
.maybe_channel_scales_type = maybe_channel_scales_type,
.maybe_token_scales_type = maybe_token_scales_type,
.maybe_out_type = maybe_out_type
});
auto schedules = machete::supported_schedules_dispatch(
{.a_type = a_type,
.b_type = b_type,
.maybe_group_scales_type = maybe_group_scales_type,
.maybe_group_zeros_type = maybe_group_zeros_type,
.maybe_channel_scales_type = maybe_channel_scales_type,
.maybe_token_scales_type = maybe_token_scales_type,
.maybe_out_type = maybe_out_type});
return schedules;
}
@@ -56,17 +56,20 @@ std::vector<std::string> MacheteSupportedSchedules(
} else {
PADDLE_ENFORCE(false, "a_type_str not supported!");
}
std::optional<paddle::DataType> maybe_group_scales_type = std::optional<paddle::DataType>(a_type);
std::optional<paddle::DataType> maybe_out_type = std::optional<paddle::DataType>(a_type);
std::optional<paddle::DataType> maybe_group_scales_type =
std::optional<paddle::DataType>(a_type);
std::optional<paddle::DataType> maybe_out_type =
std::optional<paddle::DataType>(a_type);
std::optional<paddle::DataType> maybe_group_zeros_type = std::nullopt;
std::optional<paddle::DataType> maybe_channel_scales_type = std::nullopt;
std::optional<paddle::DataType> maybe_token_scales_type = std::nullopt;
auto schedules = supported_schedules(a_type, b_type_id,
maybe_group_scales_type,
maybe_group_zeros_type,
maybe_channel_scales_type,
maybe_token_scales_type,
maybe_out_type);
auto schedules = supported_schedules(a_type,
b_type_id,
maybe_group_scales_type,
maybe_group_zeros_type,
maybe_channel_scales_type,
maybe_token_scales_type,
maybe_out_type);
return schedules;
}