[Feature] Optim PaddleOCR-VL (#4873)

* [Feature] Optim PaddleOCR-VL

* fix bug
This commit is contained in:
ming1753
2025-11-07 14:56:44 +08:00
committed by GitHub
parent bbe0820555
commit cba185f1fe
12 changed files with 535 additions and 112 deletions
+15
View File
@@ -1059,6 +1059,15 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& decode_states, const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback); const paddle::Tensor& mask_rollback);
std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
const paddle::Tensor& qkv,
const paddle::Tensor& cos_emb,
const paddle::Tensor& sin_emb,
const int num_heads,
const int head_dim);
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
PYBIND11_MODULE(fastdeploy_ops, m) { PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", m.def("get_expert_token_num",
&GetExpertTokenNum, &GetExpertTokenNum,
@@ -1648,4 +1657,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("update_attn_mask_offsets", m.def("update_attn_mask_offsets",
&UpdateAttnMaskOffsets, &UpdateAttnMaskOffsets,
"update attention mask"); "update attention mask");
m.def("fused_neox_rope_embedding",
&FusedNeoxRopeEmbedding,
"fused_neox_rope_embedding function");
m.def("gelu_tanh", &GeluTanh, "gelu_tanh function");
} }
@@ -0,0 +1,140 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename T, int VecSize = 1>
__global__ void FusedNeoxRopeEmbeddingKernel(const T *__restrict__ qkv,
const float *__restrict__ cos_emb,
const float *__restrict__ sin_emb,
T *__restrict__ q,
T *__restrict__ k,
T *__restrict__ v,
const int64_t elem_cnt,
const int num_head,
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
LoadT right_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int half_lastdim = last_dim / 2;
const int hidden_size = num_head * half_lastdim;
const int full_hidden_size = num_head * last_dim;
const int offset = 3 * hidden_size;
for (int64_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
const int qkv_bias = bias % hidden_size;
const int hi = qkv_bias / half_lastdim;
const int h_bias = qkv_bias % half_lastdim;
const int base_idx_left = token_idx * 3 * full_hidden_size +
qkv_id * full_hidden_size + hi * last_dim +
h_bias;
const int base_idx_right = base_idx_left + half_lastdim;
const int emb_idx = token_idx * last_dim + h_bias;
const int base_split_idx_left =
token_idx * full_hidden_size + hi * last_dim + h_bias;
const int base_split_idx_right = base_split_idx_left + half_lastdim;
// q,k,v output
T *out_p = nullptr;
if (qkv_id == 0) {
out_p = q;
} else if (qkv_id == 1) {
out_p = k;
} else {
out_p = v;
}
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
// do rope
if (qkv_id < 2) {
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
float input_left = static_cast<float>(left_vec[i]);
float input_right = static_cast<float>(right_vec[i]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
int cur_idx_1 = base_split_idx_left + i;
int cur_idx_2 = base_split_idx_right + i;
}
}
Store<T, VecSize>(left_vec, &out_p[base_split_idx_left]);
Store<T, VecSize>(right_vec, &out_p[base_split_idx_right]);
}
}
std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
const paddle::Tensor &qkv,
const paddle::Tensor &cos_emb,
const paddle::Tensor &sin_emb,
const int num_heads,
const int head_dim) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
const auto &qkv_dims = qkv.dims();
const int token_num = qkv_dims.size() == 2 ? qkv_dims[0] : qkv_dims[1];
auto stream = qkv.stream();
paddle::Tensor q = GetEmptyTensor(
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
paddle::Tensor k = GetEmptyTensor(
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
paddle::Tensor v = GetEmptyTensor(
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
int64_t elem_nums = token_num * num_heads * head_dim * 3 / 2;
constexpr int PackSize = 4;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
FusedNeoxRopeEmbeddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const DataType_ *>(qkv.data<data_t>()),
cos_emb.data<float>(),
sin_emb.data<float>(),
reinterpret_cast<DataType_ *>(q.data<data_t>()),
reinterpret_cast<DataType_ *>(k.data<data_t>()),
reinterpret_cast<DataType_ *>(v.data<data_t>()),
elem_nums,
num_heads,
head_dim);
return {q, k, v};
}
PD_BUILD_STATIC_OP(fused_neox_rope_embedding)
.Inputs({"qkv", "cos_emb", "sin_emb"})
.Outputs({"q", "k", "v"})
.Attrs({"num_heads: int", "head_dim: int"})
.SetKernelFn(PD_KERNEL(FusedNeoxRopeEmbedding));
+106
View File
@@ -0,0 +1,106 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
__forceinline__ __device__ float tanh_ptx(float x) {
float y;
asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
}
__device__ __forceinline__ float gelu_tanh_func(const float& val) {
const float cdf =
0.5f * (1.0f + tanh_ptx((0.7978845608028654f *
(val + 0.044715f * val * val * val))));
return val * cdf;
}
template <typename T>
__global__ void gelu_tanh_kernel(T* __restrict__ out,
const T* __restrict__ input,
const int d) {
constexpr uint32_t kVecSize = 16 / sizeof(T);
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t offset = token_idx * d;
using vec_t = AlignedVector<T, kVecSize>;
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
(__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
#pragma unroll 1
for (uint32_t idx = thread_idx; idx < d / kVecSize; idx += stride) {
vec_t x_vec;
Load(input + offset + idx * kVecSize, &x_vec);
#pragma unroll
for (uint32_t i = 0; i < kVecSize; ++i) {
x_vec[i] = static_cast<T>(gelu_tanh_func(static_cast<float>(x_vec[i])));
}
Store(x_vec, out + token_idx * d + idx * kVecSize);
}
const int64_t remaining_offset = d - d % (stride * kVecSize);
// process the remaining elements
#pragma unroll 1
for (int64_t idx = thread_idx; idx < d % (stride * kVecSize); idx += stride) {
float x = static_cast<float>(input[offset + remaining_offset + idx]);
out[token_idx * d + remaining_offset + idx] =
static_cast<T>(gelu_tanh_func(x));
}
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
(__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input) {
int d = input.dims()[1];
int64_t num_tokens = input.dims()[0];
cudaStream_t stream = input.stream();
paddle::Tensor output =
GetEmptyTensor(input.dims(), input.dtype(), input.place());
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
uint32_t vec_size = 16 / sizeof(scalar_t);
cudaLaunchConfig_t config;
config.gridDim = num_tokens;
config.blockDim = std::min(d / vec_size, 1024U);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config,
gelu_tanh_kernel<scalar_t>,
output.data<scalar_t>(),
input.data<scalar_t>(),
d);
});
return {output};
}
PD_BUILD_STATIC_OP(gelu_tanh)
.Inputs({"input"})
.Outputs({"output"})
.SetKernelFn(PD_KERNEL(GeluTanh));
+2
View File
@@ -306,6 +306,8 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/limit_thinking_content_length_v1.cu", "gpu_ops/limit_thinking_content_length_v1.cu",
"gpu_ops/limit_thinking_content_length_v2.cu", "gpu_ops/limit_thinking_content_length_v2.cu",
"gpu_ops/update_attn_mask_offsets.cu", "gpu_ops/update_attn_mask_offsets.cu",
"gpu_ops/fused_neox_rope_embedding.cu",
"gpu_ops/gelu_tanh.cu",
] ]
# pd_disaggregation # pd_disaggregation
+17 -17
View File
@@ -5,8 +5,8 @@
## 1. Environment Preparation ## 1. Environment Preparation
### 1.1 Support Status ### 1.1 Support Status
Recommended Hardware Configuration: Recommended Hardware Configuration:
- GPU Memory: 12GB or more - GPU Memory: 8GB or more
- Shared Memory: 2GB or more - Shared Memory: 4GB or more
### 1.2 Install Fastdeploy ### 1.2 Install Fastdeploy
@@ -18,38 +18,38 @@ Installation process reference documentation [FastDeploy GPU Install](../get_sta
```shell ```shell
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model PaddlePaddle/PaddleOCR-VL \ --model PaddlePaddle/PaddleOCR-VL \
--port 8180 \ --port 8185 \
--metrics-port 8181 \ --metrics-port 8186 \
--engine-worker-queue-port 8182 \ --engine-worker-queue-port 8187 \
--max-model-len 16384 \ --max-model-len 16384 \
--max-num-batched-tokens 16384 \ --max-num-batched-tokens 16384 \
--gpu-memory-utilization 0.9 \ --gpu-memory-utilization 0.8 \
--max-num-seqs 128 --max-num-seqs 256
``` ```
**Example 2:** Deploying a 16K Context Service on a Single RTX 4090 GPU **Example 2:** Deploying a 16K Context Service on a Single RTX 4090 GPU
```shell ```shell
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model PaddlePaddle/PaddleOCR-VL \ --model PaddlePaddle/PaddleOCR-VL \
--port 8180 \ --port 8185 \
--metrics-port 8181 \ --metrics-port 8186 \
--engine-worker-queue-port 8182 \ --engine-worker-queue-port 8187 \
--max-model-len 16384 \ --max-model-len 16384 \
--max-num-batched-tokens 16384 \ --max-num-batched-tokens 16384 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.7 \
--max-num-seqs 196 --max-num-seqs 256
``` ```
**Example 3:** Deploying a 16K Context Service on a Single A100 GPU **Example 3:** Deploying a 16K Context Service on a Single A100 GPU
```shell ```shell
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model PaddlePaddle/PaddleOCR-VL \ --model PaddlePaddle/PaddleOCR-VL \
--port 8180 \ --port 8185 \
--metrics-port 8181 \ --metrics-port 8186 \
--engine-worker-queue-port 8182 \ --engine-worker-queue-port 8187 \
--max-model-len 16384 \ --max-model-len 16384 \
--max-num-batched-tokens 16384 \ --max-num-batched-tokens 16384 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.7 \
--max-num-seqs 256 --max-num-seqs 256
``` ```
@@ -71,7 +71,7 @@ An example is a set of configurations that can run stably while also delivering
> **Available GPU memory ratio during initialization** > **Available GPU memory ratio during initialization**
- **Parameters** `--gpu-memory-utilization` - **Parameters** `--gpu-memory-utilization`
- **Description** Controls the available GPU memory for FastDeploy service initialization. The default value is 0.9, meaning 10% of the memory is reserved for backup. - **Description** Controls the available GPU memory for FastDeploy service initialization. The default value is 0.9, meaning 10% of the memory is reserved for backup.
- **Recommendation** It is recommended to use 0.8. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value. - **Recommendation** It is recommended to use 0.7. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value.
#### 2.2.2 Chunked Prefill #### 2.2.2 Chunked Prefill
- **Parameters** `--max-num-batched-tokens` - **Parameters** `--max-num-batched-tokens`
+16 -16
View File
@@ -5,8 +5,8 @@
## 一、环境准备 ## 一、环境准备
### 1.1 支持情况 ### 1.1 支持情况
推荐硬件配置: 推荐硬件配置:
- 显存:12GB显存及以上 - 显存:8GB显存及以上
- 共享内存:2G及以上 - 共享内存:4G及以上
### 1.2 安装fastdeploy ### 1.2 安装fastdeploy
@@ -18,12 +18,12 @@
```shell ```shell
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model PaddlePaddle/PaddleOCR-VL \ --model PaddlePaddle/PaddleOCR-VL \
--port 8180 \ --port 8185 \
--metrics-port 8181 \ --metrics-port 8186 \
--engine-worker-queue-port 8182 \ --engine-worker-queue-port 8187 \
--max-model-len 16384 \ --max-model-len 16384 \
--max-num-batched-tokens 16384 \ --max-num-batched-tokens 16384 \
--gpu-memory-utilization 0.9 \ --gpu-memory-utilization 0.8 \
--max-num-seqs 128 --max-num-seqs 128
``` ```
@@ -31,25 +31,25 @@ python -m fastdeploy.entrypoints.openai.api_server \
```shell ```shell
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model PaddlePaddle/PaddleOCR-VL \ --model PaddlePaddle/PaddleOCR-VL \
--port 8180 \ --port 8185 \
--metrics-port 8181 \ --metrics-port 8186 \
--engine-worker-queue-port 8182 \ --engine-worker-queue-port 8187 \
--max-model-len 16384 \ --max-model-len 16384 \
--max-num-batched-tokens 16384 \ --max-num-batched-tokens 16384 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.7 \
--max-num-seqs 196 --max-num-seqs 256
``` ```
**示例3** A100上单卡部署16K上下文的服务 **示例3** A100上单卡部署16K上下文的服务
```shell ```shell
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model PaddlePaddle/PaddleOCR-VL \ --model PaddlePaddle/PaddleOCR-VL \
--port 8180 \ --port 8185 \
--metrics-port 8181 \ --metrics-port 8186 \
--engine-worker-queue-port 8182 \ --engine-worker-queue-port 8187 \
--max-model-len 16384 \ --max-model-len 16384 \
--max-num-batched-tokens 16384 \ --max-num-batched-tokens 16384 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.7 \
--max-num-seqs 256 --max-num-seqs 256
``` ```
@@ -72,7 +72,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
> **初始化时可用的显存比例** > **初始化时可用的显存比例**
- **参数:** `--gpu-memory-utilization` - **参数:** `--gpu-memory-utilization`
- **用处:** 用于控制 FastDeploy 初始化服务的可用显存,默认0.9,即预留10%的显存备用。 - **用处:** 用于控制 FastDeploy 初始化服务的可用显存,默认0.9,即预留10%的显存备用。
- **推荐:** 推荐使用0.8。如果服务压测时提示显存不足,可以尝试调低该值。 - **推荐:** 推荐使用0.7。如果服务压测时提示显存不足,可以尝试调低该值。
#### 2.2.2 Chunked Prefill #### 2.2.2 Chunked Prefill
- **参数:** `--max-num-batched-tokens` - **参数:** `--max-num-batched-tokens`
+2 -2
View File
@@ -197,7 +197,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-d '{ -d '{
"messages": [ "messages": [
{"role": "user", "content": [ {"role": "user", "content": [
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/PPOCRVL/dataset/ocr_v5_eval/handwrite_ch_rec_val/中文手写古籍_000054_crop_32.jpg"}}, {"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"}},
{"type": "text", "text": "OCR:"} {"type": "text", "text": "OCR:"}
]} ]}
], ],
@@ -216,7 +216,7 @@ response = client.chat.completions.create(
model="default", model="default",
messages=[ messages=[
{"role": "user", "content": [ {"role": "user", "content": [
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/PPOCRVL/dataset/ocr_v5_eval/handwrite_ch_rec_val/中文手写古籍_000054_crop_32.jpg"}}, {"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"}},
{"type": "text", "text": "OCR:"} {"type": "text", "text": "OCR:"}
] ]
}, },
@@ -22,7 +22,6 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
from paddleformers.transformers import PretrainedModel from paddleformers.transformers import PretrainedModel
from fastdeploy import envs
from fastdeploy.config import FDConfig from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import ( from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -136,12 +135,8 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
) )
# Persistent buffers for CUDA graphs. # Persistent buffers for CUDA graphs.
if envs.FD_ENABLE_MAX_PREFILL: self._decoder_input_embeddings = paddle.zeros(
max_length = fd_config.scheduler_config.max_num_seqs * fd_config.model_config.max_model_len [fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size],
else:
max_length = fd_config.model_config.max_model_len
self._input_embeddings = paddle.zeros(
[max_length, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype, dtype=fd_config.model_config.dtype,
) )
@@ -247,12 +242,19 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
input_embeddings = self.get_input_embeddings( input_embeddings = self.get_input_embeddings(
ids_remove_padding=ids_remove_padding, image_features=image_features ids_remove_padding=ids_remove_padding, image_features=image_features
) )
self._input_embeddings.copy_(input_embeddings, False)
hidden_states = self.model( if forward_meta.step_use_cudagraph:
input_embeddings=self._input_embeddings, self._decoder_input_embeddings.copy_(input_embeddings, False)
forward_meta=forward_meta,
) hidden_states = self.model(
input_embeddings=self._decoder_input_embeddings,
forward_meta=forward_meta,
)
else:
hidden_states = self.model(
input_embeddings=input_embeddings,
forward_meta=forward_meta,
)
return hidden_states return hidden_states
@@ -21,39 +21,13 @@ import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddleformers.transformers.activations import ACT2FN
from paddleformers.transformers.model_utils import PretrainedModel from paddleformers.transformers.model_utils import PretrainedModel
from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import slice_fn from fastdeploy.model_executor.utils import slice_fn
from .config import PaddleOCRVisionConfig from .config import PaddleOCRVisionConfig
from .siglip_ops import get_activation_fn, neox_rope_embedding
def rotate_half(x):
Dh = x.shape[-1]
x1 = x[..., : Dh // 2]
x2 = x[..., Dh // 2 :]
return paddle.concat([-x2, x1], axis=-1)
def _ensure_cos_sin_dim(cos, sin, dim_needed):
last = cos.shape[-1]
if last == dim_needed:
return cos, sin
elif last * 2 == dim_needed:
cos = paddle.concat([cos, cos], axis=-1)
sin = paddle.concat([sin, sin], axis=-1)
return cos, sin
else:
raise ValueError(f"Unexpected cos/sin last-dim: {last}, expected {dim_needed} or {dim_needed//2}")
def apply_rotary_pos_emb_vision(x, cos, sin):
orig_dtype = x.dtype
x = x.astype("float32")
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed.astype(orig_dtype)
class SiglipAttention(nn.Layer): class SiglipAttention(nn.Layer):
@@ -147,29 +121,12 @@ class SiglipAttention(nn.Layer):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
cu_seqlens: Optional[List[paddle.Tensor]] = None, cu_seqlens: Optional[List[paddle.Tensor]] = None,
max_seqlen: Optional[paddle.Tensor] = None, max_seqlen: Optional[paddle.Tensor] = None,
rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, # (cos, sin) cos_emb: Optional[paddle.Tensor] = None, # (cos, sin)
sin_emb: Optional[paddle.Tensor] = None, # (cos, sin)
): ):
B, seq_length, D = hidden_states.shape B, seq_length, D = hidden_states.shape
qkv = self.qkv_proj(hidden_states)
qkv = ( q, k, v = neox_rope_embedding(qkv, cos_emb, sin_emb, self.num_heads, self.head_dim)
self.qkv_proj(hidden_states)
.reshape(
[
seq_length,
3,
self.num_heads,
-1,
]
)
.transpose(perm=[1, 0, 2, 3])
)
q, k, v = qkv.unbind(axis=0)
cos, sin = rope_emb
# --------
q = apply_rotary_pos_emb_vision(q, cos, sin)
k = apply_rotary_pos_emb_vision(k, cos, sin)
attn_output = self.flash_attn_func( attn_output = self.flash_attn_func(
q, q,
k, k,
@@ -181,11 +138,9 @@ class SiglipAttention(nn.Layer):
causal=False, causal=False,
**self.flash_attn_kwargs, **self.flash_attn_kwargs,
)[0] )[0]
# --------
attn_output = attn_output.reshape((seq_length, -1)) attn_output = attn_output.reshape((seq_length, -1))
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output return attn_output
@@ -327,11 +282,7 @@ class SiglipMLP(nn.Layer):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
if config.hidden_act == "gelu_pytorch_tanh": self.activation_fn = get_activation_fn(config.hidden_act)
config.hidden_act = "gelu_new"
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc1.weight.weight_loader = self.weight_loader self.fc1.weight.weight_loader = self.weight_loader
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -353,7 +304,7 @@ class SiglipMLP(nn.Layer):
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
hidden_states = self.fc1(hidden_states) hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states) hidden_states = self.activation_fn(hidden_states[0])
hidden_states = self.fc2(hidden_states) hidden_states = self.fc2(hidden_states)
return hidden_states return hidden_states
@@ -375,7 +326,8 @@ class SiglipEncoderLayer(paddle.nn.Layer):
output_attentions=False, output_attentions=False,
cu_seqlens=None, cu_seqlens=None,
max_seqlen=None, max_seqlen=None,
rope_emb=None, cos_emb=None,
sin_emb=None,
): ):
residual = hidden_states residual = hidden_states
@@ -388,7 +340,8 @@ class SiglipEncoderLayer(paddle.nn.Layer):
output_attentions=output_attentions, output_attentions=output_attentions,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
rope_emb=rope_emb, cos_emb=cos_emb,
sin_emb=sin_emb,
) )
hs_post_attn = residual + x hs_post_attn = residual + x
@@ -545,13 +498,13 @@ class SiglipEncoder(nn.Layer):
rope_emb = rope_emb_max_grid[pids].flatten(1) rope_emb = rope_emb_max_grid[pids].flatten(1)
rope_emb = rope_emb.tile((1, 2)) rope_emb = rope_emb.tile((1, 2))
cos = rope_emb.cos().astype("float32") cos_emb = rope_emb.cos().astype("float32")
sin = rope_emb.sin().astype("float32") sin_emb = rope_emb.sin().astype("float32")
cos = cos.unsqueeze(-2) cos_emb = cos_emb.unsqueeze(-2)
sin = sin.unsqueeze(-2) sin_emb = sin_emb.unsqueeze(-2)
rope_emb = (cos, sin)
else: else:
rope_emb = None cos_emb = None
sin_emb = None
window_indices, cu_seqlens_within_windows = None, None window_indices, cu_seqlens_within_windows = None, None
@@ -588,7 +541,8 @@ class SiglipEncoder(nn.Layer):
output_attentions=output_attentions, output_attentions=output_attentions,
cu_seqlens=attn_cu_seqlens, cu_seqlens=attn_cu_seqlens,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
rope_emb=rope_emb, cos_emb=cos_emb,
sin_emb=sin_emb,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
@@ -0,0 +1,74 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import List
import paddle
from paddleformers.transformers.activations import ACT2FN
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding, gelu_tanh
def rotate_half(x):
Dh = x.shape[-1]
x1 = x[..., : Dh // 2]
x2 = x[..., Dh // 2 :]
return paddle.concat([-x2, x1], axis=-1)
def apply_rotary_pos_emb_vision(x, cos, sin):
orig_dtype = x.dtype
x = x.astype("float32")
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed.astype(orig_dtype)
def native_neox_rope_embedding(qkv, cos, sin, num_heads):
B, seq_length, D = qkv.shape
qkv = qkv.reshape(
[
seq_length,
3,
num_heads,
-1,
]
).transpose(perm=[1, 0, 2, 3])
q, k, v = qkv.unbind(axis=0)
q = apply_rotary_pos_emb_vision(q, cos, sin)
k = apply_rotary_pos_emb_vision(k, cos, sin)
return q, k, v
def neox_rope_embedding(
qkv: paddle.Tensor, cos_emb: paddle.Tensor, sin_emb: paddle.Tensor, num_heads: int, head_dim: int
) -> List[paddle.Tensor]:
if current_platform.is_cuda():
return fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
else:
return native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)
def get_activation_fn(hidden_act: str):
if hidden_act == "gelu_pytorch_tanh":
if current_platform.is_cuda():
return gelu_tanh
else:
return ACT2FN["gelu_new"]
else:
return ACT2FN[hidden_act]
@@ -0,0 +1,88 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding
def rotate_half(x):
Dh = x.shape[-1]
x1 = x[..., : Dh // 2]
x2 = x[..., Dh // 2 :]
return paddle.concat([-x2, x1], axis=-1)
def apply_rotary_pos_emb_vision(x, cos, sin):
orig_dtype = x.dtype
x = x.astype("float32")
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed.astype(orig_dtype)
class TestFusedNeoxRopeEmbedding(unittest.TestCase):
def setUp(self):
paddle.set_device("gpu")
np.random.seed(42)
def native_neox_rope_embedding(self, qkv, cos, sin, num_heads):
seq_length = qkv.shape[0]
qkv = qkv.reshape(
[
seq_length,
3,
num_heads,
-1,
]
).transpose(perm=[1, 0, 2, 3])
q, k, v = qkv.unbind(axis=0)
q = apply_rotary_pos_emb_vision(q, cos, sin)
k = apply_rotary_pos_emb_vision(k, cos, sin)
return q, k, v
def test_fused_neox_rope_embedding(self):
token_num = 1024
hidden_size = 2048
head_dim = 128
num_heads = hidden_size // head_dim
qkv = paddle.randn([token_num, 3 * hidden_size]).astype("bfloat16")
cos_emb = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1)
sin_emb = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1)
q, k, v = fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
q_base, k_base, v_base = self.native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)
np.testing.assert_allclose(
q.cast("float32").numpy(),
q_base.cast("float32").numpy(),
rtol=1e-02,
atol=1e-02,
)
np.testing.assert_allclose(
k.cast("float32").numpy(),
k_base.cast("float32").numpy(),
rtol=1e-02,
atol=1e-02,
)
np.testing.assert_allclose(
v.cast("float32").numpy(),
v_base.cast("float32").numpy(),
rtol=1e-02,
atol=1e-02,
)
if __name__ == "__main__":
unittest.main()
+42
View File
@@ -0,0 +1,42 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddleformers.transformers.activations import ACT2FN
from fastdeploy.model_executor.ops.gpu import gelu_tanh
class TestGeluTanh(unittest.TestCase):
def setUp(self):
paddle.set_device("gpu")
np.random.seed(42)
def test_gelu_tanh(self):
x = paddle.randn(2048, 4096)
y0 = ACT2FN["gelu_new"](x)
y1 = gelu_tanh(x)
np.testing.assert_allclose(
y0.cast("float32").numpy(),
y1.cast("float32").numpy(),
rtol=1e-04,
atol=1e-04,
)
if __name__ == "__main__":
unittest.main()