mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 17:49:42 +08:00
@@ -1,4 +1,5 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/torch_utils.hpp
|
||||
// adapted from:
|
||||
// https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/torch_utils.hpp
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
@@ -16,7 +17,9 @@ namespace cute {
|
||||
namespace detail {
|
||||
|
||||
template <class T, class F, class G, int... I>
|
||||
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
|
||||
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t,
|
||||
F&& f,
|
||||
G&& g,
|
||||
seq<I...>) {
|
||||
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
|
||||
}
|
||||
@@ -32,7 +35,9 @@ template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
|
||||
if constexpr (cute::is_tuple<T>::value) {
|
||||
return detail::tapply_with_idx(
|
||||
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
|
||||
t,
|
||||
f,
|
||||
[](auto const&... a) { return cute::make_tuple(a...); },
|
||||
tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
@@ -53,8 +58,8 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
|
||||
// shape of the passed in tensor and the strides are of type `Stride` and
|
||||
// contain the strides of the passed in tensor, checking that any static strides
|
||||
// in `Stride{}` match the strides of the passed in tensor.
|
||||
// If `tensor.shape().size() < rank(Stride{})`, the shape is padded with 1s and the extra
|
||||
// strides are set to be 0 or 1.
|
||||
// If `tensor.shape().size() < rank(Stride{})`, the shape is padded with 1s and
|
||||
// the extra strides are set to be 0 or 1.
|
||||
template <typename Stride>
|
||||
static inline auto make_cute_layout(paddle::Tensor const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
@@ -65,8 +70,16 @@ static inline auto make_cute_layout(paddle::Tensor const& tensor,
|
||||
|
||||
if (idx < tensor.shape().size()) {
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
PD_CHECK(StrideEle::value == tensor.strides()[idx], "Expected ",
|
||||
name, ".strides()[", idx, "] to be ", StrideEle::value, ", but got ", tensor.strides()[idx], ". ");
|
||||
PD_CHECK(StrideEle::value == tensor.strides()[idx],
|
||||
"Expected ",
|
||||
name,
|
||||
".strides()[",
|
||||
idx,
|
||||
"] to be ",
|
||||
StrideEle::value,
|
||||
", but got ",
|
||||
tensor.strides()[idx],
|
||||
". ");
|
||||
return StrideEle{};
|
||||
} else {
|
||||
if (tensor.shape()[idx] == 1) {
|
||||
|
||||
Reference in New Issue
Block a user