Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
@@ -1,4 +1,5 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/torch_utils.hpp
// adapted from:
// https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/torch_utils.hpp
#pragma once
#include "helper.h"
@@ -16,7 +17,9 @@ namespace cute {
namespace detail {
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t,
F&& f,
G&& g,
seq<I...>) {
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
}
@@ -32,7 +35,9 @@ template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
if constexpr (cute::is_tuple<T>::value) {
return detail::tapply_with_idx(
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
t,
f,
[](auto const&... a) { return cute::make_tuple(a...); },
tuple_seq<T>{});
} else {
return f(t);
@@ -53,8 +58,8 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
// shape of the passed in tensor and the strides are of type `Stride` and
// contain the strides of the passed in tensor, checking that any static strides
// in `Stride{}` match the strides of the passed in tensor.
// If `tensor.shape().size() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
// If `tensor.shape().size() < rank(Stride{})`, the shape is padded with 1s and
// the extra strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(paddle::Tensor const& tensor,
std::string_view name = "tensor") {
@@ -65,8 +70,16 @@ static inline auto make_cute_layout(paddle::Tensor const& tensor,
if (idx < tensor.shape().size()) {
if constexpr (cute::is_static_v<StrideEle>) {
PD_CHECK(StrideEle::value == tensor.strides()[idx], "Expected ",
name, ".strides()[", idx, "] to be ", StrideEle::value, ", but got ", tensor.strides()[idx], ". ");
PD_CHECK(StrideEle::value == tensor.strides()[idx],
"Expected ",
name,
".strides()[",
idx,
"] to be ",
StrideEle::value,
", but got ",
tensor.strides()[idx],
". ");
return StrideEle{};
} else {
if (tensor.shape()[idx] == 1) {