This commit is contained in:
co63oc
2025-09-08 15:22:41 +08:00
committed by GitHub
parent 2033450391
commit aadd6a94d8
7 changed files with 26 additions and 26 deletions
+2 -2
View File
@@ -14,7 +14,7 @@
#include "paddle/extension.h"
void set_value_by_flag_and_id(const bool *stop_flags,
void set_value_by_flags_and_idx(const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
@@ -50,7 +50,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
int length = pre_ids_all_shape[1];
int length_input_ids = input_ids.shape()[1];
set_value_by_flag_and_id(stop_flags.data<bool>(),
set_value_by_flags_and_idx(stop_flags.data<bool>(),
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
input_ids.data<int64_t>(),
seq_lens_encoder.data<int>(),