[Metax] adapt to the latest develop (#6282)

This commit is contained in:
xiaozude
2026-01-30 15:21:20 +08:00
committed by GitHub
parent 18ebce9dec
commit 030647521a
14 changed files with 754 additions and 370 deletions
+19 -2
View File
@@ -12,19 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include "helper.h"
#include "paddle/extension.h"
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
// Approximate hardware tanh via the PTX `tanh.approx.f32` instruction.
// Faster but lower-precision than tanhf(); only compiled for non-Metax
// builds (see the surrounding PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU guard).
// NOTE(review): tanh.approx requires a recent PTX ISA / SM target —
// confirm the build's minimum architecture supports it.
__forceinline__ __device__ float tanh_ptx(float x) {
float y;
// Single PTX instruction; "=f"/"f" constrain operands to .f32 registers.
asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
}
#endif
// GELU, tanh approximation:
//   x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// where 0.7978845608028654f == sqrt(2/pi). On Metax devices the PTX
// tanh.approx path is unavailable, so the libm tanhf() is used instead.
__device__ __forceinline__ float gelu_tanh_func(const float& val) {
const float inner =
(0.7978845608028654f *
(val + 0.044715f * val * val * val));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
const float cdf = 0.5f * (1.0f + tanhf(inner));
#else
const float cdf = 0.5f * (1.0f + tanh_ptx(inner));
#endif
return val * cdf;
}
@@ -79,9 +88,16 @@ std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input) {
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
uint32_t vec_size = 16 / sizeof(scalar_t);
dim3 grid(num_tokens);
dim3 block(std::max(d / vec_size, 1024U));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
gelu_tanh_kernel<scalar_t><<<grid, block, 0, stream>>>(
output.data<scalar_t>(), input.data<scalar_t>(), d);
#else
cudaLaunchConfig_t config;
config.gridDim = num_tokens;
config.blockDim = std::min(d / vec_size, 1024U);
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
@@ -95,6 +111,7 @@ std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input) {
output.data<scalar_t>(),
input.data<scalar_t>(),
d);
#endif
});
return {output};
@@ -124,7 +124,10 @@ void SpeculateLimitThinkingContentLengthV1(
const int tokens_per_step = next_tokens.shape()[1];
const int eos_token_id_len = eos_token_ids.shape()[0];
speculate_limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
speculate_limit_thinking_content_length_kernel_v1<<<1,
1024,
0,
next_tokens.stream()>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
@@ -132,7 +132,10 @@ void SpeculateLimitThinkingContentLengthV2(
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
speculate_limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
speculate_limit_thinking_content_length_kernel_v2<<<1,
1024,
0,
next_tokens.stream()>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
+22
View File
@@ -1,3 +1,17 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
@@ -35,6 +49,10 @@ void cuda_host_free(uintptr_t ptr) {
check_cuda_error(cudaFreeHost(reinterpret_cast<void*>(ptr)));
}
paddle::Tensor GetStop(paddle::Tensor& not_need_stop);
void SetStop(paddle::Tensor& not_need_stop, bool flag);
PYBIND11_MODULE(fastdeploy_ops, m) {
/**
* alloc_cache_pinned.cc
@@ -49,4 +67,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def(
"cuda_host_free", &cuda_host_free, "Free pinned memory", py::arg("ptr"));
py::register_exception<CudaError>(m, "CudaError");
m.def("get_stop", &GetStop, "get_stop function");
m.def("set_stop", &SetStop, "set_stop function");
}
+14
View File
@@ -1,3 +1,17 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
+2
View File
@@ -643,6 +643,8 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
"gpu_ops/unset_data_ipc.cu",
"gpu_ops/swap_cache_batch.cu",
"gpu_ops/gelu_tanh.cu",
"gpu_ops/set_stop.cu",
"metax_ops/moe_dispatch.cu",
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",
@@ -51,10 +51,6 @@ class ErnieRotaryEmbedding:
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
elif paddle.is_compiled_with_custom_device("metax_gpu"):
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
@@ -182,7 +182,7 @@ def apply_speculative_penalty_multi_scores(
"""
apply_speculative_penalty_multi_scores
"""
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
speculate_get_token_penalty_multi_scores,
)
@@ -606,7 +606,7 @@ class SpeculativeSampler(nn.Layer):
def __init__(self, fd_config: FDConfig):
""" """
super().__init__()
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
self.forward = self.forward_cuda
elif current_platform.is_xpu():
self.forward = self.forward_xpu
@@ -972,7 +972,7 @@ class MTPSampler(nn.Layer):
def __init__(self, fd_config: FDConfig):
""" """
super().__init__()
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
self.forward = self.forward_cuda
elif current_platform.is_xpu():
self.forward = self.forward_xpu
@@ -56,11 +56,23 @@ elif current_platform.is_maca():
limit_thinking_content_length_v1,
limit_thinking_content_length_v2,
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_get_output_padding_offset,
speculate_get_padding_offset,
speculate_get_seq_lens_output,
speculate_limit_thinking_content_length_v1,
speculate_limit_thinking_content_length_v2,
speculate_save_output,
speculate_save_output_topk,
speculate_set_stop_value_multi_seqs,
speculate_set_value_by_flags_and_idx,
speculate_step_paddle,
speculate_step_reschedule,
speculate_step_system_cache,
speculate_update,
step_paddle,
step_reschedule,
step_system_cache,
update_inputs,
update_inputs_v1,
@@ -515,7 +527,7 @@ def post_process_specualate(
share_inputs["preempted_idx"],
model_output.mp_rank,
save_each_rank,
envs.ENABLE_V1_KVCACHE_SCHEDULER,
bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
)
else:
speculate_save_output_topk(
+3 -3
View File
@@ -108,7 +108,7 @@ class MTPProposer(Proposer):
if current_platform.is_xpu():
self._propose = self._propose_xpu
elif current_platform.is_cuda():
elif current_platform.is_cuda() or current_platform.is_maca():
self._propose = self._propose_cuda
else:
raise RuntimeError("Unsupported platform.")
@@ -350,7 +350,7 @@ class MTPProposer(Proposer):
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
if current_platform.is_xpu():
if current_platform.is_xpu() or current_platform.is_maca():
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_cpu"]
).cpu()
@@ -1308,7 +1308,7 @@ class MTPProposer(Proposer):
elif current_platform.is_xpu():
paddle.device.xpu.empty_cache()
else:
raise NotImplementedError
paddle.device.empty_cache()
def _get_cache_type(self):
cache_type = None
File diff suppressed because it is too large Load Diff
+29 -4
View File
@@ -17,7 +17,7 @@
import gc
import os
import time
from typing import List, Optional
from typing import Any, Dict, List, Optional
import paddle
from paddle import nn
@@ -25,6 +25,7 @@ from paddle import nn
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.usage.usage_lib import report_usage_stats
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.metax_model_runner import MetaxModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
@@ -61,6 +62,9 @@ class MetaxWorker(WorkerBase):
gc.collect()
paddle.device.empty_cache()
if self.local_rank == 0:
report_usage_stats(self.fd_config)
set_random_seed(self.fd_config.model_config.seed)
# Construct model runner
self.model_runner: MetaxModelRunner = MetaxModelRunner(
@@ -91,7 +95,6 @@ class MetaxWorker(WorkerBase):
by adjusting the `gpu_memory_utilization` parameter.
"""
# temporary fix kvcache size to test
fd_kvache_mem = os.getenv("FD_METAX_KVCACHE_MEM")
if fd_kvache_mem is not None:
return int(float(fd_kvache_mem) * 1024**3)
@@ -188,6 +191,10 @@ class MetaxWorker(WorkerBase):
# accurate cache size
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
"""update weights in place"""
return self.model_runner.update_weights(version, rsync_config)
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
@@ -199,6 +206,7 @@ class MetaxWorker(WorkerBase):
def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
"""Process new requests and then start the decode loop
TODO(gongshaotian):The scheduler should schedule the handling of prefill,
and workers and modelrunners should not perceive it.
"""
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
@@ -208,11 +216,28 @@ class MetaxWorker(WorkerBase):
def graph_optimize_and_warm_up_model(self) -> None:
"""
Perform the warm-up and the graph optimization
Perform the warm-up and the graph optimization.
Execution modes:
| Mode | Prefill + Mixed | Decode |
|-----------------------------------|--------------------------|--------------------------|
| Dynamic (graph_opt_level=0) | Dynamic | Dynamic + CUDAGraph |
| Static Full Graph (full=True) | Dynamic | Static + CUDAGraph |
| Static Split Graph (full=False) | Static + CUDAGraph | Dynamic + CUDAGraph |
"""
if self.fd_config.graph_opt_config.graph_opt_level >= 1 and not self.model_runner.use_cudagraph:
self.model_runner.sot_warmup()
# Trigger cuda graph capture
if self.fd_config.graph_opt_config.graph_opt_level >= 1:
self.model_runner.vision_encoder_compile()
# Static split graph mode: capture CUDAGraph for prefill/mixed phase
if (
self.fd_config.graph_opt_config.graph_opt_level >= 1
and not self.fd_config.graph_opt_config.full_cuda_graph
):
self.model_runner.capture_model_prefill_and_mixed()
# Capture CUDAGraph for decode phase (all modes)
self.model_runner.capture_model()
def check_health(self) -> bool:
+1 -2
View File
@@ -10,8 +10,7 @@ tqdm
pynvml
uvicorn>=0.38.0
fastapi
# if paddleformers version > 0.3.2, metax triton will be replaced by the newest triton.
paddleformers==0.3.2
paddleformers==0.4.1
redis
etcd3
httpx