[XPU] fix bug and teporary fix for rope 3d (#7465)

This commit is contained in:
RuohengMa
2026-04-20 09:51:27 +08:00
committed by GitHub
parent b2aca6c550
commit cf5bc5e510
2 changed files with 5 additions and 3 deletions
@@ -239,7 +239,8 @@ void split_kvcache_encoder(api::Context* xpu_ctx,
bool use_neox_rotary_style) {
int ret;
int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads;
if (FLAGS_encoder_splice) {
// TODO: spliced split kvcache should support rope3d
if (FLAGS_encoder_splice && !rope_3d) {
if (rope_3d) {
PD_THROW("split_kvcache_encoder does not support rope_3d == true!");
}
@@ -545,7 +546,8 @@ void split_kvcache_decoder(api::Context* xpu_ctx,
bool use_neox_rotary_style) {
int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads;
int ret;
if (FLAGS_decoder_splice) {
// TODO: spliced split kvcache should support rope3d
if (FLAGS_decoder_splice && !rope_3d) {
// not yet supported
if (rope_3d) {
PD_THROW("split_kvcache_decoder does not support rope_3d == true!");
@@ -157,7 +157,7 @@ std::vector<paddle::Tensor> GetInferParam(
}
// for vsl_rotary_embedding_gptj of cudagraph mode
int prev_val = 0;
for (int i = 0; i < bsz; i++) {
for (int i = 0; i < bsz + 1; i++) {
if (decoder_seq_lod_vec[i] > prev_val) {
prev_val = decoder_seq_lod_vec[i];
} else if (decoder_seq_lod_vec[i] < prev_val) {