[XPU ]fix text_image_gather_scatter in cudagraph mode(#6049)

This commit is contained in:
RuohengMa
2026-01-23 19:48:43 +08:00
committed by GitHub
parent 20074d301f
commit 976203cf60
@@ -60,9 +60,11 @@ static __device__ inline void text_image_gather(
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
} else if (token_type == 1) {
text_image_input = image_input;
text_image_index = image_index;
} else {
continue;
}
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;
@@ -132,9 +134,11 @@ static __device__ inline void text_image_scatter(
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
} else if (token_type == 1) {
text_image_input = image_input;
text_image_index = image_index;
} else {
continue;
}
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;