Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
@@ -17,46 +17,45 @@
#include "sample_kernels/sampling.cuh"
// Min-p sampling renormalization: launches the
// sampling::MinPSamplingFromProb CUDA kernel on `probs`'s stream and returns
// the renormalized probability tensor.
//
// @param probs  [batch_size, vocab_size] FLOAT32 probability tensor (GPU).
// @param min_p  per-batch min-p threshold tensor, read as float.
// @return       one-element vector holding the renormalized FLOAT32 tensor
//               with the same dims/place as `probs`.
//
// NOTE(review): the exact thresholding semantics live in
// sample_kernels/sampling.cuh — confirm there; this wrapper only validates
// the launch status.
std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
                                                  const paddle::Tensor &min_p) {
  std::vector<int64_t> probs_shape = probs.shape();
  unsigned int batch_size = probs_shape[0];
  unsigned int vocab_size = probs_shape[1];
  auto cu_stream = probs.stream();

  // Output buffer allocated on the same device as the input.
  auto renorm_probs =
      GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());

  cudaError_t status;

  // const_cast is required because the kernel takes non-const pointers even
  // though it only reads probs/min_p.
  status = sampling::MinPSamplingFromProb<float, int>(
      const_cast<float *>(probs.data<float>()),
      const_cast<float *>(min_p.data<float>()),
      renorm_probs.data<float>(),
      batch_size,
      vocab_size,
      true,  // deterministic
      cu_stream);

  PD_CHECK(status == cudaSuccess,
           "SamplingFromProbs failed with error code " +
               std::string(cudaGetErrorString(status)));

  return {renorm_probs};
}
// InferShape for min_p_sampling: the renormalized output has the same shape
// as `probs`; `min_p_shape` does not affect the output shape.
std::vector<std::vector<int64_t>> MinPSamplingFromProbsInferShape(
    const std::vector<int64_t> &probs_shape,
    const paddle::optional<std::vector<int64_t>> &min_p_shape) {
  return {probs_shape};
}
// InferDtype for min_p_sampling: the output keeps the dtype of `probs`;
// `min_p_dtype` is ignored.
std::vector<paddle::DataType> MinPSamplingFromProbsInferDtype(
    const paddle::DataType &probs_dtype,
    const paddle::optional<paddle::DataType> &min_p_dtype) {
  return {probs_dtype};
}
PD_BUILD_STATIC_OP(min_p_sampling)
.Inputs({"probs", "min_p"})
.Outputs({"renorm_probs"})
@@ -16,10 +16,11 @@
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
const paddle::Tensor &top_p,
const paddle::optional<paddle::Tensor> &top_k,
int64_t seed) {
std::vector<paddle::Tensor> TopPSamplingReject(
const paddle::Tensor &probs,
const paddle::Tensor &top_p,
const paddle::optional<paddle::Tensor> &top_k,
int64_t seed) {
std::vector<int64_t> probs_shape = probs.shape();
unsigned int batch_size = probs_shape[0];
unsigned int vocab_size = probs_shape[1];
@@ -30,9 +31,11 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
// need_batch_random
if (seed == -1) {
#ifdef PADDLE_WITH_COREX
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(probs.place()));
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(probs.place()));
#else
phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(probs.place()));
phi::GPUContext *dev_ctx = static_cast<phi::GPUContext *>(
phi::DeviceContextPool::Instance().Get(probs.place()));
#endif
auto gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(32 * batch_size);
@@ -47,35 +50,48 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
if (top_k) {
status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
batch_size, top_p.data<float>(), top_k.get().data<int64_t>(), vocab_size,
true, philox_seed, philox_offset, cu_stream);
}
else {
const_cast<float *>(probs.data<float>()),
samples.data<int64_t>(),
batch_size,
top_p.data<float>(),
top_k.get().data<int64_t>(),
vocab_size,
true,
philox_seed,
philox_offset,
cu_stream);
} else {
status = sampling::TopPSamplingFromProb<float, int64_t>(
const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
batch_size, top_p.data<float>(), vocab_size,
true, philox_seed, philox_offset, cu_stream);
const_cast<float *>(probs.data<float>()),
samples.data<int64_t>(),
batch_size,
top_p.data<float>(),
vocab_size,
true,
philox_seed,
philox_offset,
cu_stream);
}
PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));
PD_CHECK(status == cudaSuccess,
"SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));
return {samples};
}
// InferShape for the top-p rejection-sampling op: one sampled token id per
// batch row, i.e. output shape [batch_size, 1].
std::vector<std::vector<int64_t>> TopPSamplingRejectInferShape(
    const std::vector<int64_t> &probs_shape,
    const std::vector<int64_t> &top_p_shape,
    const paddle::optional<std::vector<int64_t>> &top_k_shape) {
  int64_t bs = probs_shape[0];
  return {{bs, 1}};
}
// InferDtype for the top-p rejection-sampling op: sampled token ids are
// always INT64, regardless of the input dtypes.
std::vector<paddle::DataType> TopPSamplingRejectInferDtype(
    const paddle::DataType &probs_dtype,
    const paddle::DataType &top_p_dtype,
    const paddle::optional<paddle::DataType> &top_k_dtype) {
  return {paddle::DataType::INT64};
}
@@ -28,28 +28,29 @@ std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
cudaError_t status;
status = sampling::TopKRenormProb<float>(
const_cast<float *>(probs.data<float>()),
renorm_probs.data<float>(),
const_cast<int64_t *>(top_k.data<int64_t>()),
batch_size, vocab_size, cu_stream);
const_cast<float *>(probs.data<float>()),
renorm_probs.data<float>(),
const_cast<int64_t *>(top_k.data<int64_t>()),
batch_size,
vocab_size,
cu_stream);
PD_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " +
std::string(cudaGetErrorString(status)));
PD_CHECK(status == cudaSuccess,
"TopKRenormProb failed with error code " +
std::string(cudaGetErrorString(status)));
return {renorm_probs};
}
// InferShape for top_k_renorm_probs: renormalization is element-wise over the
// probability matrix, so the output shape equals `probs_shape`.
std::vector<std::vector<int64_t>> TopKRenormInferShape(
    const std::vector<int64_t> &probs_shape,
    const std::vector<int64_t> &top_k_shape) {
  return {probs_shape};
}
// InferDtype for top_k_renorm_probs: output keeps the dtype of `probs`;
// `top_k_shape` (the top_k tensor's dtype) does not influence it.
std::vector<paddle::DataType> TopKRenormInferDtype(
    const paddle::DataType &probs_dtype, const paddle::DataType &top_k_shape) {
  return {probs_dtype};
}