mirror of
https://github.com/gowvp/gb28181.git
synced 2026-04-22 15:07:10 +08:00
support tflite
This commit is contained in:
+9
-1
@@ -43,8 +43,14 @@ RUN pip3 config set global.index-url https://mirrors.aliyun.com/pypi/simple/ \
|
|||||||
|
|
||||||
# 安装 Python 依赖
|
# 安装 Python 依赖
|
||||||
# --break-system-packages: Debian Trixie 使用 PEP 668 保护系统 Python
|
# --break-system-packages: Debian Trixie 使用 PEP 668 保护系统 Python
|
||||||
|
# TFLite 支持:优先尝试 tflite-runtime,失败则用 ai-edge-litert(Google 新包名)
|
||||||
COPY ./analysis/requirements.txt /tmp/requirements.txt
|
COPY ./analysis/requirements.txt /tmp/requirements.txt
|
||||||
RUN pip3 install --no-cache-dir --break-system-packages -r /tmp/requirements.txt \
|
RUN pip3 install --no-cache-dir --break-system-packages -r /tmp/requirements.txt \
|
||||||
|
&& (pip3 install --no-cache-dir --break-system-packages \
|
||||||
|
-i https://pypi.org/simple/ tflite-runtime 2>/dev/null \
|
||||||
|
|| pip3 install --no-cache-dir --break-system-packages \
|
||||||
|
-i https://pypi.org/simple/ ai-edge-litert 2>/dev/null \
|
||||||
|
|| echo "TFLite runtime not available, will use ONNX only") \
|
||||||
&& rm /tmp/requirements.txt \
|
&& rm /tmp/requirements.txt \
|
||||||
&& rm -rf /root/.cache/pip
|
&& rm -rf /root/.cache/pip
|
||||||
|
|
||||||
@@ -54,7 +60,9 @@ COPY --from=mwader/static-ffmpeg:6.1 /ffmpeg /usr/local/bin/ffmpeg
|
|||||||
COPY --from=mwader/static-ffmpeg:6.1 /ffprobe /usr/local/bin/ffprobe
|
COPY --from=mwader/static-ffmpeg:6.1 /ffprobe /usr/local/bin/ffprobe
|
||||||
|
|
||||||
# 复制应用文件
|
# 复制应用文件
|
||||||
ADD ./build/linux_amd64/bin ./gowvp
|
# 使用 TARGETARCH 自动识别目标架构(amd64 或 arm64)
|
||||||
|
ARG TARGETARCH
|
||||||
|
ADD ./build/linux_${TARGETARCH}/bin ./gowvp
|
||||||
ADD ./www ./www
|
ADD ./www ./www
|
||||||
ADD ./analysis ./analysis
|
ADD ./analysis ./analysis
|
||||||
|
|
||||||
|
|||||||
+2
-1
@@ -18,7 +18,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
COPY --from=zlmediakit/zlmediakit:master /opt/media/bin /opt/media/bin
|
COPY --from=zlmediakit/zlmediakit:master /opt/media/bin /opt/media/bin
|
||||||
COPY --from=mwader/static-ffmpeg:6.1 /ffmpeg /usr/local/bin/ffmpeg
|
COPY --from=mwader/static-ffmpeg:6.1 /ffmpeg /usr/local/bin/ffmpeg
|
||||||
|
|
||||||
ADD ./build/linux_amd64/bin ./gowvp
|
ARG TARGETARCH
|
||||||
|
ADD ./build/linux_${TARGETARCH}/bin ./gowvp
|
||||||
ADD ./www ./www
|
ADD ./www ./www
|
||||||
|
|
||||||
RUN mkdir -p configs
|
RUN mkdir -p configs
|
||||||
|
|||||||
+5
-1
@@ -135,11 +135,13 @@ postgres 和 mysql 的格式即:
|
|||||||
|
|
||||||
> 如何关闭 AI?
|
> 如何关闭 AI?
|
||||||
|
|
||||||
ai 默认是开启状态,1 秒检测 5 帧
|
ai 默认是开启状态,1 秒检测 1 帧
|
||||||
|
|
||||||
可以在 `configs/config.toml` 中修改 `disabledAI = true` 关闭 ai 检测
|
可以在 `configs/config.toml` 中修改 `disabledAI = true` 关闭 ai 检测
|
||||||
|
|
||||||
|
> 国标设备在线,通道离线?
|
||||||
|
|
||||||
|
属于 ipc 的问题,请检查 ipc 后台注册的平台 sip_id 和域是否与 gowvp/owl 一致。
|
||||||
|
|
||||||
|
|
||||||
## 文档
|
## 文档
|
||||||
@@ -302,6 +304,8 @@ services:
|
|||||||
[@joestarzxh](https://github.com/joestarzxh)
|
[@joestarzxh](https://github.com/joestarzxh)
|
||||||
[@oldweipro](https://github.com/oldweipro)
|
[@oldweipro](https://github.com/oldweipro)
|
||||||
[@beixiaocai](https://github.com/beixiaocai)
|
[@beixiaocai](https://github.com/beixiaocai)
|
||||||
|
[@chencanfggz](https://github.com/chencanfggz)
|
||||||
|
[@zhangxuan1340](https://github.com/zhangxuan1340)
|
||||||
|
|
||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
|
||||||
|
## ai/init: 安装 AI 分析模块依赖(本地开发用)
|
||||||
|
ai/init:
|
||||||
|
@echo "安装基础依赖..."
|
||||||
|
@pip install -r analysis/requirements.txt
|
||||||
|
@echo ""
|
||||||
|
@echo "TFLite 支持需要额外安装(二选一):"
|
||||||
|
@echo " Linux: make ai/init/tflite"
|
||||||
|
@echo " macOS/Windows: make ai/init/tensorflow"
|
||||||
|
@echo ""
|
||||||
|
@echo "如果使用 .tflite 模型,请根据系统选择安装"
|
||||||
|
|
||||||
|
## ai/init/tflite: 安装 TFLite 支持(Linux 轻量版)
|
||||||
|
ai/init/tflite:
|
||||||
|
@conda install tflite-runtime
|
||||||
|
|
||||||
|
## ai/init/tensorflow: 安装 TensorFlow(macOS/Windows 完整版)
|
||||||
|
ai/init/tensorflow:
|
||||||
|
@ pip install tensorflow -i https://pypi.org/simple
|
||||||
+361
-34
@@ -1,11 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
import onnxruntime as ort
|
|
||||||
|
|
||||||
|
|
||||||
slog = logging.getLogger("Detector")
|
slog = logging.getLogger("Detector")
|
||||||
|
|
||||||
@@ -94,55 +94,77 @@ COCO_LABELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ObjectDetector:
|
class ModelBackend(ABC):
|
||||||
|
"""
|
||||||
|
模型推理后端抽象接口
|
||||||
|
不同模型格式(ONNX、TFLite)需实现此接口,确保上层调用逻辑统一
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, model_path: str = "yolo11n.onnx"):
|
@abstractmethod
|
||||||
self.model_path = model_path
|
def load(self, model_path: str) -> bool:
|
||||||
self.session: ort.InferenceSession | None = None
|
"""加载模型文件,返回是否成功"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_ready(self) -> bool:
|
||||||
|
"""检查模型是否已加载并可用"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_input_shape(self) -> tuple:
|
||||||
|
"""获取模型输入形状,用于预处理"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
||||||
|
"""执行推理,返回原始输出"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXBackend(ModelBackend):
|
||||||
|
"""
|
||||||
|
ONNX Runtime 推理后端
|
||||||
|
使用 onnxruntime 库加载和执行 ONNX 格式模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.session = None
|
||||||
self.input_name: str = ""
|
self.input_name: str = ""
|
||||||
self.input_shape: tuple = (1, 3, 640, 640)
|
self.input_shape: tuple = (1, 3, 640, 640)
|
||||||
self._is_ready = False
|
self._is_ready = False
|
||||||
self.names: dict[int, str] = {i: name for i, name in enumerate(COCO_LABELS)}
|
|
||||||
|
|
||||||
def load_model(self) -> bool:
|
def load(self, model_path: str) -> bool:
|
||||||
"""加载 ONNX 模型并初始化推理会话"""
|
|
||||||
try:
|
try:
|
||||||
slog.info(f"加载 ONNX 模型: {self.model_path} ...")
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
slog.info(f"加载 ONNX 模型: {model_path} ...")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 配置 ONNX Runtime 会话选项
|
|
||||||
sess_options = ort.SessionOptions()
|
sess_options = ort.SessionOptions()
|
||||||
sess_options.graph_optimization_level = (
|
sess_options.graph_optimization_level = (
|
||||||
ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
)
|
)
|
||||||
# 限制线程数,避免在容器中占用过多 CPU
|
|
||||||
sess_options.intra_op_num_threads = 4
|
sess_options.intra_op_num_threads = 4
|
||||||
sess_options.inter_op_num_threads = 2
|
sess_options.inter_op_num_threads = 2
|
||||||
|
|
||||||
# 优先使用 CPU 执行提供程序
|
|
||||||
providers = ["CPUExecutionProvider"]
|
providers = ["CPUExecutionProvider"]
|
||||||
|
|
||||||
self.session = ort.InferenceSession(
|
self.session = ort.InferenceSession(
|
||||||
self.model_path, sess_options=sess_options, providers=providers
|
model_path, sess_options=sess_options, providers=providers
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取输入信息
|
|
||||||
input_info = self.session.get_inputs()[0]
|
input_info = self.session.get_inputs()[0]
|
||||||
self.input_name = input_info.name
|
self.input_name = input_info.name
|
||||||
self.input_shape = tuple(input_info.shape)
|
self.input_shape = tuple(input_info.shape)
|
||||||
|
|
||||||
# 预热模型
|
|
||||||
dummy_img = np.zeros((640, 640, 3), dtype=np.uint8)
|
|
||||||
self._preprocess(dummy_img)
|
|
||||||
dummy_input = self._preprocess(dummy_img)
|
|
||||||
self.session.run(None, {self.input_name: dummy_input})
|
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
slog.info(
|
slog.info(
|
||||||
f"ONNX 模型加载完成 (耗时: {elapsed:.2f}s, 输入形状: {self.input_shape})"
|
f"ONNX 模型加载完成 (耗时: {elapsed:.2f}s, 输入形状: {self.input_shape})"
|
||||||
)
|
)
|
||||||
self._is_ready = True
|
self._is_ready = True
|
||||||
return True
|
return True
|
||||||
|
except ImportError:
|
||||||
|
slog.error("未安装 onnxruntime,无法加载 ONNX 模型")
|
||||||
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
slog.error(f"加载 ONNX 模型失败: {e}")
|
slog.error(f"加载 ONNX 模型失败: {e}")
|
||||||
return False
|
return False
|
||||||
@@ -150,29 +172,335 @@ class ObjectDetector:
|
|||||||
def is_ready(self) -> bool:
|
def is_ready(self) -> bool:
|
||||||
return self._is_ready and self.session is not None
|
return self._is_ready and self.session is not None
|
||||||
|
|
||||||
|
def get_input_shape(self) -> tuple:
|
||||||
|
return self.input_shape
|
||||||
|
|
||||||
|
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
||||||
|
if not self.session:
|
||||||
|
raise RuntimeError("ONNX 模型未加载")
|
||||||
|
outputs = self.session.run(None, {self.input_name: input_tensor})
|
||||||
|
return np.asarray(outputs[0])
|
||||||
|
|
||||||
|
|
||||||
|
class TFLiteBackend(ModelBackend):
|
||||||
|
"""
|
||||||
|
TensorFlow Lite 推理后端
|
||||||
|
支持两种模型格式:
|
||||||
|
1. YOLO 格式:单输出张量 (1, 84, 8400)
|
||||||
|
2. SSD 格式:多输出张量(boxes, classes, scores, num_detections)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.interpreter: Any = None
|
||||||
|
self.input_details: list[dict[str, Any]] = []
|
||||||
|
self.output_details: list[dict[str, Any]] = []
|
||||||
|
self.input_shape: tuple = (1, 640, 640, 3)
|
||||||
|
self._is_ready = False
|
||||||
|
self._is_nhwc = True
|
||||||
|
self._is_ssd_format = False # 区分 SSD 和 YOLO 格式
|
||||||
|
self._input_quantization: tuple[float, int] = (1.0, 0) # scale, zero_point
|
||||||
|
|
||||||
|
def load(self, model_path: str) -> bool:
    """
    Load a TFLite model file and prepare the interpreter for inference.

    An ``Interpreter`` class is looked up in ``tflite_runtime``, then
    ``ai_edge_litert``, then ``tensorflow`` (``tf.lite.Interpreter``); the
    first one that imports is used.  After allocating tensors the method
    records the input/output details, the input shape, whether the last
    input dimension is 3 (treated as NHWC), the first input quantization
    pair, and whether the outputs look like an SSD model.

    Args:
        model_path: Path of the ``.tflite`` file handed to the interpreter.

    Returns:
        True when the model was loaded, False otherwise.  Failures are
        logged and never raised to the caller.
    """
    # A failed (re)load must not leave the backend reporting a stale
    # "ready" state from a previous successful load.
    self._is_ready = False
    try:
        Interpreter = None
        try:
            from tflite_runtime.interpreter import Interpreter  # type: ignore
        except ImportError:
            pass
        if Interpreter is None:
            try:
                from ai_edge_litert.interpreter import Interpreter  # type: ignore
            except ImportError:
                pass
        if Interpreter is None:
            try:
                import tensorflow as tf

                Interpreter = tf.lite.Interpreter
            except ImportError:
                pass

        if Interpreter is None:
            raise ImportError("未找到 tflite_runtime、ai_edge_litert 或 tensorflow")

        slog.info(f"加载 TFLite 模型: {model_path} ...")
        start_time = time.time()

        self.interpreter = Interpreter(model_path=model_path)
        self.interpreter.allocate_tensors()

        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

        self.input_shape = tuple(self.input_details[0]["shape"])

        # A trailing dimension of 3 is taken to mean channels-last (NHWC).
        if len(self.input_shape) == 4:
            self._is_nhwc = self.input_shape[3] == 3

        # Input quantization parameters (used by uint8-quantized models).
        quant_params = self.input_details[0].get("quantization_parameters", {})
        scales = quant_params.get("scales", np.array([1.0]))
        zero_points = quant_params.get("zero_points", np.array([0]))
        if len(scales) > 0 and len(zero_points) > 0:
            self._input_quantization = (float(scales[0]), int(zero_points[0]))
        else:
            # Do not keep values left over from a previously loaded model.
            self._input_quantization = (1.0, 0)

        # Format detection: three or more outputs, at least one of which has
        # "Detection" in its name, is treated as SSD; anything else as YOLO.
        self._is_ssd_format = len(self.output_details) >= 3 and any(
            "Detection" in d.get("name", "") for d in self.output_details
        )

        elapsed = time.time() - start_time
        format_name = "SSD" if self._is_ssd_format else "YOLO"
        slog.info(
            f"TFLite 模型加载完成 (耗时: {elapsed:.2f}s, 输入: {self.input_shape}, "
            f"格式: {format_name}, 量化: {self._input_quantization})"
        )
        self._is_ready = True
        return True
    except ImportError:
        # Name every runtime that was actually tried, matching the message
        # raised above.
        slog.error(
            "未安装 tflite_runtime、ai_edge_litert 或 tensorflow,无法加载 TFLite 模型"
        )
        return False
    except Exception as e:
        slog.error(f"加载 TFLite 模型失败: {e}")
        return False
|
||||||
|
|
||||||
|
def is_ready(self) -> bool:
|
||||||
|
return self._is_ready and self.interpreter is not None
|
||||||
|
|
||||||
|
def get_input_shape(self) -> tuple:
|
||||||
|
return self.input_shape
|
||||||
|
|
||||||
|
def is_nhwc(self) -> bool:
|
||||||
|
"""返回模型是否使用 NHWC 格式"""
|
||||||
|
return self._is_nhwc
|
||||||
|
|
||||||
|
def is_ssd_format(self) -> bool:
|
||||||
|
"""返回是否为 SSD 格式(多输出张量)"""
|
||||||
|
return self._is_ssd_format
|
||||||
|
|
||||||
|
def get_input_quantization(self) -> tuple[float, int]:
|
||||||
|
"""返回输入量化参数 (scale, zero_point)"""
|
||||||
|
return self._input_quantization
|
||||||
|
|
||||||
|
def get_input_dtype(self) -> np.dtype:
|
||||||
|
"""返回模型期望的输入数据类型"""
|
||||||
|
return self.input_details[0]["dtype"]
|
||||||
|
|
||||||
|
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
||||||
|
"""执行推理,返回第一个输出张量(用于 YOLO 格式)"""
|
||||||
|
if not self.interpreter or len(self.input_details) == 0:
|
||||||
|
raise RuntimeError("TFLite 模型未加载")
|
||||||
|
|
||||||
|
input_dtype = self.input_details[0]["dtype"]
|
||||||
|
if input_tensor.dtype != input_dtype:
|
||||||
|
input_tensor = input_tensor.astype(input_dtype)
|
||||||
|
|
||||||
|
self.interpreter.set_tensor(self.input_details[0]["index"], input_tensor)
|
||||||
|
self.interpreter.invoke()
|
||||||
|
|
||||||
|
output = self.interpreter.get_tensor(self.output_details[0]["index"])
|
||||||
|
return np.asarray(output)
|
||||||
|
|
||||||
|
def infer_ssd(
    self, input_tensor: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    """
    Run one inference pass on an SSD-style model and return its outputs.

    The input is cast to the dtype recorded in ``self.input_details[0]``
    before it is handed to the interpreter.  Output tensors are then
    identified in three stages:

    1. By name: a tensor whose lower-cased name contains "boxes", "class",
       "score" or "num" (checked in that order) is assigned accordingly.
    2. By shape, for tensors without such a keyword: a 3-D tensor whose last
       dimension is 4 is the boxes; 2-D tensors are assigned in order of
       appearance, first to classes and then to scores; a 1-D tensor of
       length 1 is the detection count.
    3. Positionally: if boxes, classes or scores is still missing and there
       are at least four outputs, the first four outputs are taken as
       boxes, classes, scores and count.

    Expected layout according to the original author's notes (not verified
    here): boxes ``(1, num_boxes, 4)`` as normalized
    ``[y_min, x_min, y_max, x_max]``, classes ``(1, num_boxes)``, scores
    ``(1, num_boxes)``.

    Args:
        input_tensor: Pre-processed input batch.

    Returns:
        ``(boxes, classes, scores, num_detections)``.  Any array that could
        not be identified is returned as an empty array, and
        ``num_detections`` is 0 when no count tensor was found.

    Raises:
        RuntimeError: If no interpreter or input details are available.
    """
    if not self.interpreter or len(self.input_details) == 0:
        raise RuntimeError("TFLite 模型未加载")

    input_info = self.input_details[0]
    if input_tensor.dtype != input_info["dtype"]:
        input_tensor = input_tensor.astype(input_info["dtype"])

    self.interpreter.set_tensor(input_info["index"], input_tensor)
    self.interpreter.invoke()

    outputs = [
        np.asarray(self.interpreter.get_tensor(detail["index"]))
        for detail in self.output_details
    ]

    keywords = ("boxes", "class", "score", "num")
    found: dict[str, np.ndarray] = {}
    anonymous: list[tuple[tuple[int, ...], np.ndarray]] = []

    # Stage 1: explicit names win over any shape heuristic.
    for detail, tensor in zip(self.output_details, outputs):
        name = str(detail.get("name", "")).lower()
        key = next((k for k in keywords if k in name), None)
        if key is None:
            anonymous.append((tuple(int(s) for s in detail["shape"]), tensor))
        else:
            found[key] = tensor

    # Stage 2: classes and scores share the same shape, so they can only be
    # told apart by order; an already-assigned slot is never overwritten.
    for shape, tensor in anonymous:
        if len(shape) == 3 and shape[-1] == 4:
            found.setdefault("boxes", tensor)
        elif len(shape) == 2:
            if "class" not in found:
                found["class"] = tensor
            elif "score" not in found:
                found["score"] = tensor
        elif len(shape) == 1 and shape[0] == 1:
            found.setdefault("num", tensor)

    # Stage 3: last resort, assign by output position.
    incomplete = any(k not in found for k in ("boxes", "class", "score"))
    if incomplete and len(outputs) >= 4:
        found["boxes"], found["class"], found["score"], found["num"] = outputs[:4]

    boxes = found.get("boxes")
    if boxes is None:
        boxes = np.array([])
    classes = found.get("class")
    if classes is None:
        classes = np.array([])
    scores = found.get("score")
    if scores is None:
        scores = np.array([])

    num_tensor = found.get("num")
    if num_tensor is not None and num_tensor.size > 0:
        num_detections = int(num_tensor.reshape(-1)[0])
    else:
        num_detections = 0

    return boxes, classes, scores, num_detections
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_type(model_path: str) -> str:
|
||||||
|
"""根据模型文件后缀判断模型类型"""
|
||||||
|
ext = os.path.splitext(model_path)[1].lower()
|
||||||
|
return "tflite" if ext == ".tflite" else "onnx"
|
||||||
|
|
||||||
|
|
||||||
|
def create_backend(model_type: str) -> ModelBackend:
|
||||||
|
"""
|
||||||
|
根据模型类型创建对应的推理后端
|
||||||
|
"""
|
||||||
|
if model_type == "tflite":
|
||||||
|
return TFLiteBackend()
|
||||||
|
else:
|
||||||
|
return ONNXBackend()
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectDetector:
|
||||||
|
"""
|
||||||
|
目标检测器 - 支持多种模型格式(ONNX、TFLite)
|
||||||
|
通过统一的 ModelBackend 接口实现模型无关的检测逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_path: str):
|
||||||
|
self.model_path = model_path
|
||||||
|
self.model_type = get_model_type(model_path)
|
||||||
|
self.backend: ModelBackend | None = None
|
||||||
|
self.input_shape: tuple = (1, 3, 640, 640)
|
||||||
|
self._is_ready = False
|
||||||
|
self.names: dict[int, str] = {i: name for i, name in enumerate(COCO_LABELS)}
|
||||||
|
|
||||||
|
def load_model(self) -> bool:
|
||||||
|
"""加载模型并初始化推理后端"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 创建对应类型的后端
|
||||||
|
self.backend = create_backend(self.model_type)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
if not self.backend.load(self.model_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.input_shape = self.backend.get_input_shape()
|
||||||
|
|
||||||
|
# 预热模型
|
||||||
|
self._warmup()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
slog.info(f"模型预热完成 (总耗时: {elapsed:.2f}s)")
|
||||||
|
self._is_ready = True
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
slog.error(f"加载模型失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _warmup(self) -> None:
|
||||||
|
"""预热模型,减少首次推理延迟"""
|
||||||
|
if not self.backend:
|
||||||
|
return
|
||||||
|
|
||||||
|
dummy_img = np.zeros((640, 640, 3), dtype=np.uint8)
|
||||||
|
dummy_input = self._preprocess(dummy_img)
|
||||||
|
self.backend.infer(dummy_input)
|
||||||
|
slog.info("模型预热完成")
|
||||||
|
|
||||||
|
def is_ready(self) -> bool:
|
||||||
|
return self._is_ready and self.backend is not None and self.backend.is_ready()
|
||||||
|
|
||||||
|
def _get_target_size(self) -> int:
|
||||||
|
"""获取模型期望的输入尺寸"""
|
||||||
|
# NCHW: (1, 3, H, W) -> shape[2]
|
||||||
|
# NHWC: (1, H, W, 3) -> shape[1]
|
||||||
|
if self._is_nhwc_format():
|
||||||
|
return int(self.input_shape[1])
|
||||||
|
return int(self.input_shape[2])
|
||||||
|
|
||||||
|
def _is_nhwc_format(self) -> bool:
|
||||||
|
"""判断当前后端是否使用 NHWC 格式"""
|
||||||
|
if isinstance(self.backend, TFLiteBackend):
|
||||||
|
return self.backend.is_nhwc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _is_ssd_format(self) -> bool:
|
||||||
|
"""判断当前后端是否为 SSD 格式"""
|
||||||
|
if isinstance(self.backend, TFLiteBackend):
|
||||||
|
return self.backend.is_ssd_format()
|
||||||
|
return False
|
||||||
|
|
||||||
def _preprocess(self, image: np.ndarray) -> np.ndarray:
|
def _preprocess(self, image: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
预处理图像:调整大小、归一化、转换格式
|
预处理图像:调整大小、归一化、转换格式
|
||||||
YOLO 期望输入格式: NCHW, float32, 归一化到 [0, 1]
|
根据后端类型自动选择 NCHW 或 NHWC 格式,并处理量化输入
|
||||||
"""
|
"""
|
||||||
# 保持宽高比的 letterbox 缩放
|
target_size = self._get_target_size()
|
||||||
target_size = self.input_shape[2] # 640
|
|
||||||
h, w = image.shape[:2]
|
h, w = image.shape[:2]
|
||||||
|
|
||||||
|
# SSD 模型使用直接缩放(不保持宽高比)
|
||||||
|
# YOLO 模型使用 letterbox 缩放(保持宽高比)
|
||||||
|
if self._is_ssd_format():
|
||||||
|
resized = cv2.resize(
|
||||||
|
image, (target_size, target_size), interpolation=cv2.INTER_LINEAR
|
||||||
|
)
|
||||||
|
rgb = resized[:, :, ::-1] # BGR -> RGB
|
||||||
|
|
||||||
|
# 检查是否需要量化为 uint8
|
||||||
|
if isinstance(self.backend, TFLiteBackend):
|
||||||
|
input_dtype = self.backend.get_input_dtype()
|
||||||
|
if input_dtype == np.uint8:
|
||||||
|
# 直接返回 uint8 格式
|
||||||
|
return np.expand_dims(rgb.astype(np.uint8), axis=0)
|
||||||
|
|
||||||
|
# float32 格式
|
||||||
|
rgb = rgb.astype(np.float32) / 255.0
|
||||||
|
return np.expand_dims(rgb, axis=0)
|
||||||
|
|
||||||
|
# YOLO letterbox 预处理
|
||||||
scale = min(target_size / h, target_size / w)
|
scale = min(target_size / h, target_size / w)
|
||||||
new_h, new_w = int(h * scale), int(w * scale)
|
new_h, new_w = int(h * scale), int(w * scale)
|
||||||
|
|
||||||
# 缩放图像
|
|
||||||
resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
# 创建正方形画布并居中放置图像
|
|
||||||
canvas = np.full((target_size, target_size, 3), 114, dtype=np.uint8)
|
canvas = np.full((target_size, target_size, 3), 114, dtype=np.uint8)
|
||||||
top = (target_size - new_h) // 2
|
top = (target_size - new_h) // 2
|
||||||
left = (target_size - new_w) // 2
|
left = (target_size - new_w) // 2
|
||||||
canvas[top : top + new_h, left : left + new_w] = resized
|
canvas[top : top + new_h, left : left + new_w] = resized
|
||||||
|
|
||||||
# BGR -> RGB, HWC -> CHW, 归一化
|
rgb = canvas[:, :, ::-1].astype(np.float32) / 255.0
|
||||||
blob = canvas[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) / 255.0
|
|
||||||
return np.expand_dims(blob, axis=0)
|
if self._is_nhwc_format():
|
||||||
|
return np.expand_dims(rgb, axis=0)
|
||||||
|
else:
|
||||||
|
blob = rgb.transpose(2, 0, 1)
|
||||||
|
return np.expand_dims(blob, axis=0)
|
||||||
|
|
||||||
def _postprocess(
|
def _postprocess(
|
||||||
self,
|
self,
|
||||||
@@ -214,7 +542,7 @@ class ObjectDetector:
|
|||||||
|
|
||||||
# 缩放坐标到原始图像尺寸
|
# 缩放坐标到原始图像尺寸
|
||||||
orig_h, orig_w = original_shape
|
orig_h, orig_w = original_shape
|
||||||
target_size = self.input_shape[2]
|
target_size = self._get_target_size()
|
||||||
scale = min(target_size / orig_h, target_size / orig_w)
|
scale = min(target_size / orig_h, target_size / orig_w)
|
||||||
pad_h = (target_size - orig_h * scale) / 2
|
pad_h = (target_size - orig_h * scale) / 2
|
||||||
pad_w = (target_size - orig_w * scale) / 2
|
pad_w = (target_size - orig_w * scale) / 2
|
||||||
@@ -341,18 +669,17 @@ class ObjectDetector:
|
|||||||
self, image: np.ndarray, threshold: float, label_filter: list[str] | None = None
|
self, image: np.ndarray, threshold: float, label_filter: list[str] | None = None
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""对单张图像执行检测"""
|
"""对单张图像执行检测"""
|
||||||
if not self.session:
|
if not self.backend or not self.backend.is_ready():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 预处理
|
# 预处理
|
||||||
input_tensor = self._preprocess(image)
|
input_tensor = self._preprocess(image)
|
||||||
|
|
||||||
# 推理
|
# 推理
|
||||||
outputs = self.session.run(None, {self.input_name: input_tensor})
|
output = self.backend.infer(input_tensor)
|
||||||
|
|
||||||
# 后处理
|
# 后处理
|
||||||
original_shape = image.shape[:2]
|
original_shape = image.shape[:2]
|
||||||
output: np.ndarray = np.asarray(outputs[0])
|
|
||||||
return self._postprocess(output, original_shape, threshold, label_filter)
|
return self._postprocess(output, original_shape, threshold, label_filter)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,150 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试脚本 - 用于诊断 TFLite 和 ONNX 模型的检测问题
|
||||||
|
使用方法: python detect_test.py <image_path> [model_path]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# 添加当前目录到路径
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from detect import ObjectDetector, TFLiteBackend, ONNXBackend
|
||||||
|
|
||||||
|
|
||||||
|
def draw_detections(image: np.ndarray, detections: list) -> np.ndarray:
|
||||||
|
"""在图像上绘制检测结果"""
|
||||||
|
img = image.copy()
|
||||||
|
for det in detections:
|
||||||
|
box = det["box"]
|
||||||
|
label = det["label"]
|
||||||
|
conf = det["confidence"]
|
||||||
|
|
||||||
|
x1, y1 = box["x_min"], box["y_min"]
|
||||||
|
x2, y2 = box["x_max"], box["y_max"]
|
||||||
|
|
||||||
|
# 绘制边框
|
||||||
|
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||||
|
|
||||||
|
# 绘制标签
|
||||||
|
text = f"{label}: {conf:.2%}"
|
||||||
|
(tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
|
||||||
|
cv2.rectangle(img, (x1, y1 - th - 10), (x1 + tw, y1), (0, 255, 0), -1)
|
||||||
|
cv2.putText(
|
||||||
|
img, text, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1
|
||||||
|
)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def test_model(model_path: str, image_path: str):
    """
    Run one model against one image and print diagnostics for each stage.

    The function reads the image, loads the model through ``ObjectDetector``,
    prints backend details, runs pre-processing and raw inference by hand to
    show tensor shapes and value ranges, then runs the regular ``detect``
    call and saves an annotated copy of the image as
    ``result_<model name>.jpg`` in the directory of this script.

    Args:
        model_path: Path of the model file to test.
        image_path: Path of the image to run detection on.
    """
    print(f"\n{'='*60}")
    print(f"测试模型: {model_path}")
    print(f"测试图片: {image_path}")
    print(f"{'='*60}")

    # Read the image.
    image = cv2.imread(image_path)
    if image is None:
        print(f"错误: 无法读取图像 {image_path}")
        return

    print(f"图像尺寸: {image.shape}")

    # Create the detector.
    detector = ObjectDetector(model_path)
    if not detector.load_model():
        print("错误: 模型加载失败")
        return

    print(f"模型类型: {detector.model_type}")
    print(f"输入形状: {detector.input_shape}")

    # Backend-specific details; only the TFLite backend exposes these.
    if isinstance(detector.backend, TFLiteBackend):
        print(f"TFLite 格式: {'NHWC' if detector.backend.is_nhwc() else 'NCHW'}")

        # Debug: dump the interpreter's input/output descriptions.
        print(f"输入详情: {detector.backend.input_details}")
        print(f"输出详情: {detector.backend.output_details}")

    print("\n执行检测...")

    # Debug: run pre-processing by hand and inspect the tensor.
    input_tensor = detector._preprocess(image)
    print(f"预处理后张量形状: {input_tensor.shape}")
    print(f"预处理后张量范围: [{input_tensor.min():.4f}, {input_tensor.max():.4f}]")
    print(f"预处理后张量类型: {input_tensor.dtype}")

    # Raw inference.
    if detector.backend is None:
        print("错误: 后端未初始化")
        return
    output = detector.backend.infer(input_tensor)
    print(f"原始输出形状: {output.shape}")
    print(f"原始输出范围: [{output.min():.4f}, {output.max():.4f}]")

    # Report which of the two 3-D layouts the model produced.
    if len(output.shape) == 3:
        print(
            f"输出维度分析: batch={output.shape[0]}, dim1={output.shape[1]}, dim2={output.shape[2]}"
        )
        if output.shape[1] == 84 and output.shape[2] == 8400:
            print("输出格式: (1, 84, 8400) - YOLO 标准格式")
        elif output.shape[1] == 8400 and output.shape[2] == 84:
            print("输出格式: (1, 8400, 84) - 需要调整后处理逻辑")

    # Regular detection pipeline.
    detections, inference_time = detector.detect(image, threshold=0.25)

    print(f"\n检测结果 (阈值=0.25):")
    print(f"推理耗时: {inference_time:.2f} ms")
    print(f"检测到 {len(detections)} 个目标:")

    for i, det in enumerate(detections):
        print(f"  {i+1}. {det['label']}: {det['confidence']:.2%} at {det['box']}")
        # A confidence above 1.0 points at a post-processing problem.
        if det["confidence"] > 1.0:
            print(f"     ⚠️ 警告: 置信度超过100%! 原始值: {det['confidence']}")

    # Save the annotated image next to this script instead of a path that
    # only exists on one developer's machine.
    model_name = os.path.splitext(os.path.basename(model_path))[0]
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_path = os.path.join(script_dir, f"result_{model_name}.jpg")
    result_img = draw_detections(image, detections)
    # cv2.imwrite reports failure through its return value, not an exception.
    if cv2.imwrite(output_path, result_img):
        print(f"\n结果图像已保存: {output_path}")
    else:
        print(f"\n错误: 无法保存结果图像 {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """
    Command-line entry point: ``python detect_test.py [image_path] [model_path]``.

    With a model path, only that model is tested.  Without one, every default
    model file that exists is tested (ONNX first, then TFLite).  Default
    locations are resolved relative to this script rather than to an
    absolute path on one machine.
    """
    # NOTE(review): assumes this script lives one level below the repository
    # root (next to owl.onnx, with configs/ as a sibling directory) — confirm.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.dirname(script_dir)

    # Default image path, overridable by the first argument.
    image_path = os.path.join(repo_root, "out.png")
    if len(sys.argv) > 1:
        image_path = sys.argv[1]

    # Default model locations.
    tflite_model = os.path.join(repo_root, "configs", "owl.tflite")
    onnx_model = os.path.join(script_dir, "owl.onnx")

    if len(sys.argv) > 2:
        # Test only the requested model.
        test_model(sys.argv[2], image_path)
        return

    # Test every default model that is present.
    candidates = [path for path in (onnx_model, tflite_model) if os.path.exists(path)]
    if not candidates:
        print(f"错误: 未找到模型文件 ({onnx_model}, {tflite_model})")
        return
    for path in candidates:
        test_model(path, image_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
+35
-3
@@ -25,6 +25,14 @@ from detect import MotionDetector, ObjectDetector
|
|||||||
from frame_capture import FrameCapture
|
from frame_capture import FrameCapture
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
# 模型文件搜索候选路径(按优先级排序)
|
||||||
|
MODEL_SEARCH_PATHS = [
|
||||||
|
("../configs/owl.tflite", "tflite"),
|
||||||
|
("../configs/owl.onnx", "onnx"),
|
||||||
|
("./owl.tflite", "tflite"),
|
||||||
|
("./owl.onnx", "onnx"),
|
||||||
|
]
|
||||||
|
|
||||||
# 导入生成的 proto 代码
|
# 导入生成的 proto 代码
|
||||||
# 这些模块必须存在才能启动 gRPC 服务
|
# 这些模块必须存在才能启动 gRPC 服务
|
||||||
import analysis_pb2
|
import analysis_pb2
|
||||||
@@ -480,10 +488,31 @@ def serve(port, model_path):
|
|||||||
server.stop(0)
|
server.stop(0)
|
||||||
|
|
||||||
|
|
||||||
|
def discover_model(model_arg: str) -> str:
    """
    Pick the model file the service should load.

    Each entry of ``MODEL_SEARCH_PATHS`` is resolved relative to the
    directory of this script, in list order, and the first one that is an
    existing regular file is returned.  ``model_arg`` (the ``--model``
    command-line value) is only a fallback: it is used when none of the
    search paths matches, and is returned without checking that it exists.

    Args:
        model_arg: Fallback model path; a relative value is resolved against
            the script directory.

    Returns:
        Normalized path of the discovered model, or the resolved fallback.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))

    for rel_path, _ in MODEL_SEARCH_PATHS:
        full_path = os.path.normpath(os.path.join(script_dir, rel_path))
        # isfile rather than exists: a directory of that name (for instance
        # one created by a bind mount of a missing host file) is not a model.
        if os.path.isfile(full_path):
            slog.info(f"发现模型文件: {full_path}")
            return full_path

    # Fall back to the model given on the command line.
    if os.path.isabs(model_arg):
        return model_arg

    # Relative paths are resolved against the script directory.
    return os.path.normpath(os.path.join(script_dir, model_arg))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--port", type=int, default=50051)
|
parser.add_argument("--port", type=int, default=50051)
|
||||||
parser.add_argument("--model", type=str, default="yolo11n.onnx")
|
parser.add_argument("--model", type=str, default="owl.onnx")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--callback-url",
|
"--callback-url",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -503,11 +532,14 @@ def main():
|
|||||||
GLOBAL_CONFIG["callback_url"] = args.callback_url
|
GLOBAL_CONFIG["callback_url"] = args.callback_url
|
||||||
GLOBAL_CONFIG["callback_secret"] = args.callback_secret
|
GLOBAL_CONFIG["callback_secret"] = args.callback_secret
|
||||||
|
|
||||||
|
# 自动发现模型文件
|
||||||
|
model_path = discover_model(args.model)
|
||||||
|
|
||||||
slog.debug(
|
slog.debug(
|
||||||
f"log level: {args.log_level}, model: {args.model}, callback url: {args.callback_url}, callback secret: {args.callback_secret}"
|
f"log level: {args.log_level}, model: {model_path}, callback url: {args.callback_url}, callback secret: {args.callback_secret}"
|
||||||
)
|
)
|
||||||
|
|
||||||
serve(args.port, args.model)
|
serve(args.port, model_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# ONNX Runtime - 替代 PyTorch,大幅减小镜像体积
|
# ONNX Runtime - 用于加载 .onnx 模型
|
||||||
onnxruntime>=1.17.0
|
onnxruntime>=1.17.0
|
||||||
|
|
||||||
# gRPC 框架
|
# gRPC 框架
|
||||||
|
|||||||
@@ -1,3 +1,11 @@
|
|||||||
|
# 2026-01-11
|
||||||
|
|
||||||
|
修复 webrtc 播放地址
|
||||||
|
支持 ai 检测,支持 onnx 和 tflite
|
||||||
|
支持自定义模型,写入 configs/owl.onnx 或 configs/owl.tflite
|
||||||
|
支持修改国标信息免重启生效
|
||||||
|
修复 rtmp port 偶先值为 0
|
||||||
|
|
||||||
# 2025-06-14
|
# 2025-06-14
|
||||||
重构国标播放逻辑, play 接口仅返回播放地址,通过 webhook 来通知拉流
|
重构国标播放逻辑, play 接口仅返回播放地址,通过 webhook 来通知拉流
|
||||||
可能会导致首播慢一些
|
可能会导致首播慢一些
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ func (d *ZLMDriver) GetStreamLiveAddr(ctx context.Context, ms *MediaServer, http
|
|||||||
var out StreamLiveAddr
|
var out StreamLiveAddr
|
||||||
out.Label = "ZLM"
|
out.Label = "ZLM"
|
||||||
wsPrefix := strings.Replace(strings.Replace(httpPrefix, "https", "wss", 1), "http", "ws", 1)
|
wsPrefix := strings.Replace(strings.Replace(httpPrefix, "https", "wss", 1), "http", "ws", 1)
|
||||||
out.WSFLV = fmt.Sprintf("%s/proxy/sms/%s.live.flv", wsPrefix, stream)
|
out.WSFLV = fmt.Sprintf("%s/proxy/sms/%s/%s.live.flv", wsPrefix, app, stream)
|
||||||
out.HTTPFLV = fmt.Sprintf("%s/proxy/sms/%s.live.flv", httpPrefix, stream)
|
out.HTTPFLV = fmt.Sprintf("%s/proxy/sms/%s/%s.live.flv", httpPrefix, app, stream)
|
||||||
out.HLS = fmt.Sprintf("%s/proxy/sms/%s/hls.fmp4.m3u8", httpPrefix, stream)
|
out.HLS = fmt.Sprintf("%s/proxy/sms/%s/%s/hls.fmp4.m3u8", httpPrefix, app, stream)
|
||||||
rtcPrefix := strings.Replace(strings.Replace(httpPrefix, "https", "webrtc", 1), "http", "webrtc", 1)
|
rtcPrefix := strings.Replace(strings.Replace(httpPrefix, "https", "webrtc", 1), "http", "webrtc", 1)
|
||||||
out.WebRTC = fmt.Sprintf("%s/proxy/sms/index/api/webrtc?app=%s&stream=%s&type=play", rtcPrefix, app, stream)
|
out.WebRTC = fmt.Sprintf("%s/proxy/sms/index/api/webrtc?app=%s&stream=%s&type=play", rtcPrefix, app, stream)
|
||||||
out.RTMP = fmt.Sprintf("rtmp://%s:%d/%s", host, ms.Ports.RTMP, stream)
|
out.RTMP = fmt.Sprintf("rtmp://%s:%d/%s", host, ms.Ports.RTMP, stream)
|
||||||
@@ -81,9 +81,18 @@ func (d *ZLMDriver) Connect(ctx context.Context, ms *MediaServer) error {
|
|||||||
func (d *ZLMDriver) Setup(ctx context.Context, ms *MediaServer, webhookURL string) error {
|
func (d *ZLMDriver) Setup(ctx context.Context, ms *MediaServer, webhookURL string) error {
|
||||||
engine := d.withConfig(ms)
|
engine := d.withConfig(ms)
|
||||||
|
|
||||||
|
// 拼接 IP 但是不要空格
|
||||||
|
ips := make([]string, 0, 2)
|
||||||
|
for _, ip := range []string{ms.SDPIP, ms.IP} {
|
||||||
|
if ip != "" {
|
||||||
|
ips = append(ips, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = ips
|
||||||
// 构造配置请求
|
// 构造配置请求
|
||||||
req := zlm.SetServerConfigRequest{
|
req := zlm.SetServerConfigRequest{
|
||||||
RtcExternIP: new(ms.IP),
|
RtcExternIP: new(strings.Join(ips, ",")),
|
||||||
|
|
||||||
GeneralMediaServerID: new(ms.ID),
|
GeneralMediaServerID: new(ms.ID),
|
||||||
HookEnable: new("1"),
|
HookEnable: new("1"),
|
||||||
HookOnFlowReport: new(""),
|
HookOnFlowReport: new(""),
|
||||||
|
|||||||
@@ -136,7 +136,8 @@ func (a AIWebhookAPI) onEvents(c *gin.Context, in *AIDetectionInput) (AIWebhookO
|
|||||||
Score: float32(det.Confidence),
|
Score: float32(det.Confidence),
|
||||||
Zones: string(zonesJSON),
|
Zones: string(zonesJSON),
|
||||||
ImagePath: imagePath,
|
ImagePath: imagePath,
|
||||||
Model: "yolo11n",
|
// TODO: 模型名称可以根据模型自定义
|
||||||
|
Model: "default",
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := a.eventCore.AddEvent(ctx, eventInput); err != nil {
|
if _, err := a.eventCore.AddEvent(ctx, eventInput); err != nil {
|
||||||
@@ -259,7 +260,7 @@ func (a *AIWebhookAPI) StartAIDetection(ctx context.Context, ch *ipc.Channel, rt
|
|||||||
CameraId: ch.ID,
|
CameraId: ch.ID,
|
||||||
CameraName: ch.Name,
|
CameraName: ch.Name,
|
||||||
RtspUrl: rtspURL,
|
RtspUrl: rtspURL,
|
||||||
DetectFps: 5,
|
DetectFps: 1,
|
||||||
Labels: labels,
|
Labels: labels,
|
||||||
Threshold: 0.75,
|
Threshold: 0.75,
|
||||||
RoiPoints: roiPoints,
|
RoiPoints: roiPoints,
|
||||||
|
|||||||
@@ -62,10 +62,11 @@ func setupRouter(r *gin.Engine, uc *Usecase) {
|
|||||||
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"},
|
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"},
|
||||||
AllowHeaders: []string{
|
AllowHeaders: []string{
|
||||||
"Accept", "Content-Length", "Content-Type", "Range", "Accept-Language",
|
"Accept", "Content-Length", "Content-Type", "Range", "Accept-Language",
|
||||||
"Origin", "Authorization",
|
"Origin", "Authorization", "Referer", "User-Agent",
|
||||||
"Accept-Encoding",
|
"Accept-Encoding",
|
||||||
"Cache-Control", "Pragma", "X-Requested-With",
|
"Cache-Control", "Pragma", "X-Requested-With",
|
||||||
"Sec-Fetch-Mode", "Sec-Fetch-Site", "Sec-Fetch-Dest",
|
"Sec-Fetch-Mode", "Sec-Fetch-Site", "Sec-Fetch-Dest",
|
||||||
|
"Sec-Ch-Ua", "Sec-Ch-Ua-Mobile", "Sec-Ch-Ua-Platform",
|
||||||
"Dnt", "X-Forwarded-For", "X-Forwarded-Proto", "X-Forwarded-Host",
|
"Dnt", "X-Forwarded-For", "X-Forwarded-Proto", "X-Forwarded-Host",
|
||||||
"X-Real-IP", "X-Request-ID", "X-Request-Start", "X-Request-Time",
|
"X-Real-IP", "X-Request-ID", "X-Request-Start", "X-Request-Time",
|
||||||
},
|
},
|
||||||
@@ -105,6 +106,7 @@ func setupRouter(r *gin.Engine, uc *Usecase) {
|
|||||||
registerPushAPI(r, uc.MediaAPI, auth)
|
registerPushAPI(r, uc.MediaAPI, auth)
|
||||||
registerGB28181(r, uc.GB28181API, auth)
|
registerGB28181(r, uc.GB28181API, auth)
|
||||||
registerProxy(r, uc.ProxyAPI, auth)
|
registerProxy(r, uc.ProxyAPI, auth)
|
||||||
|
uc.ConfigAPI.uc = uc
|
||||||
registerConfig(r, uc.ConfigAPI, auth)
|
registerConfig(r, uc.ConfigAPI, auth)
|
||||||
registerSms(r, uc.SMSAPI, auth)
|
registerSms(r, uc.SMSAPI, auth)
|
||||||
RegisterUser(r, uc.UserAPI, auth)
|
RegisterUser(r, uc.UserAPI, auth)
|
||||||
@@ -327,6 +329,7 @@ func (uc *Usecase) proxySMS(c *gin.Context) {
|
|||||||
fullAddr, _ := url.Parse(addr)
|
fullAddr, _ := url.Parse(addr)
|
||||||
c.Request.URL.Path = ""
|
c.Request.URL.Path = ""
|
||||||
proxy := httputil.NewSingleHostReverseProxy(fullAddr)
|
proxy := httputil.NewSingleHostReverseProxy(fullAddr)
|
||||||
|
|
||||||
proxy.Director = func(req *http.Request) {
|
proxy.Director = func(req *http.Request) {
|
||||||
// 设置请求的URL
|
// 设置请求的URL
|
||||||
req.URL.Scheme = "http"
|
req.URL.Scheme = "http"
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
type ConfigAPI struct {
|
type ConfigAPI struct {
|
||||||
configCore config.Core
|
configCore config.Core
|
||||||
conf *conf.Bootstrap
|
conf *conf.Bootstrap
|
||||||
|
uc *Usecase
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfigAPI(db *gorm.DB, conf *conf.Bootstrap) ConfigAPI {
|
func NewConfigAPI(db *gorm.DB, conf *conf.Bootstrap) ConfigAPI {
|
||||||
@@ -83,5 +84,7 @@ func (a ConfigAPI) editSIP(_ *gin.Context, in *conf.SIP) (gin.H, error) {
|
|||||||
if err := conf.WriteConfig(a.conf, a.conf.ConfigPath); err != nil {
|
if err := conf.WriteConfig(a.conf, a.conf.ConfigPath); err != nil {
|
||||||
return nil, reason.ErrServer.SetMsg(err.Error())
|
return nil, reason.ErrServer.SetMsg(err.Error())
|
||||||
}
|
}
|
||||||
|
a.uc.SipServer.SetConfig()
|
||||||
|
|
||||||
return gin.H{"msg": "ok"}, nil
|
return gin.H{"msg": "ok"}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -281,7 +281,7 @@ func (a IPCAPI) play(c *gin.Context, _ *struct{}) (*playOutput, error) {
|
|||||||
host = h
|
host = h
|
||||||
}
|
}
|
||||||
|
|
||||||
item := a.uc.SMSAPI.smsCore.GetStreamLiveAddr(svr, prefix, host, app, stream)
|
item := a.uc.SMSAPI.smsCore.GetStreamLiveAddr(svr, prefix, host, app, appStream)
|
||||||
out := playOutput{
|
out := playOutput{
|
||||||
App: app,
|
App: app,
|
||||||
Stream: appStream,
|
Stream: appStream,
|
||||||
|
|||||||
@@ -43,13 +43,13 @@ func (a PushAPI) findStreamPush(c *gin.Context, in *push.FindStreamPushInput) (*
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheFn := hook.UseCache(func(s string) (*sms.MediaServer, error) {
|
cacheFn := func(s string) (*sms.MediaServer, error) {
|
||||||
v, err := a.smsCore.GetMediaServer(c.Request.Context(), s)
|
v, err := a.smsCore.GetMediaServer(c.Request.Context(), s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(c.Request.Context(), "GetMediaServer", "err", err)
|
slog.ErrorContext(c.Request.Context(), "GetMediaServer", "err", err)
|
||||||
}
|
}
|
||||||
return v, err
|
return v, err
|
||||||
})
|
}
|
||||||
|
|
||||||
out := make([]*push.FindStreamPushOutputItem, len(items))
|
out := make([]*push.FindStreamPushOutputItem, len(items))
|
||||||
for i, item := range items {
|
for i, item := range items {
|
||||||
@@ -58,8 +58,12 @@ func (a PushAPI) findStreamPush(c *gin.Context, in *push.FindStreamPushInput) (*
|
|||||||
if mediaID == "" {
|
if mediaID == "" {
|
||||||
mediaID = sms.DefaultMediaServerID
|
mediaID = sms.DefaultMediaServerID
|
||||||
}
|
}
|
||||||
if svr, _, _ := cacheFn(mediaID); svr != nil {
|
if svr, _ := cacheFn(mediaID); svr != nil {
|
||||||
addr := fmt.Sprintf("rtmp://%s:%d/%s/%s", web.GetHost(c.Request), svr.Ports.RTMP, item.App, item.Stream)
|
port := svr.Ports.RTMP
|
||||||
|
if port == 0 {
|
||||||
|
port = 1935
|
||||||
|
}
|
||||||
|
addr := fmt.Sprintf("rtmp://%s:%d/%s/%s", web.GetHost(c.Request), port, item.App, item.Stream)
|
||||||
if !item.IsAuthDisabled {
|
if !item.IsAuthDisabled {
|
||||||
addr += fmt.Sprintf("?sign=%s", hook.MD5(a.conf.Server.RTMPSecret))
|
addr += fmt.Sprintf("?sign=%s", hook.MD5(a.conf.Server.RTMPSecret))
|
||||||
}
|
}
|
||||||
|
|||||||
+15
-1
@@ -49,7 +49,7 @@ func NewServer(cfg *conf.Bootstrap, store ipc.Adapter, sc sms.Core) (*Server, fu
|
|||||||
iip := ip.InternalIP()
|
iip := ip.InternalIP()
|
||||||
uri, _ := sip.ParseSipURI(fmt.Sprintf("sip:%s@%s:%d", cfg.Sip.ID, iip, cfg.Sip.Port))
|
uri, _ := sip.ParseSipURI(fmt.Sprintf("sip:%s@%s:%d", cfg.Sip.ID, iip, cfg.Sip.Port))
|
||||||
from := sip.Address{
|
from := sip.Address{
|
||||||
DisplayName: sip.String{Str: "gowvp"},
|
DisplayName: sip.String{Str: "gowvp/owl"},
|
||||||
URI: &uri,
|
URI: &uri,
|
||||||
Params: sip.NewParams(),
|
Params: sip.NewParams(),
|
||||||
}
|
}
|
||||||
@@ -87,6 +87,20 @@ func NewServer(cfg *conf.Bootstrap, store ipc.Adapter, sc sms.Core) (*Server, fu
|
|||||||
return &c, c.Close
|
return &c, c.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetConfig 热更新 SIP 配置,用于配置变更时更新 from 地址而无需重启服务
|
||||||
|
func (s *Server) SetConfig() {
|
||||||
|
cfg := s.gb.cfg
|
||||||
|
iip := ip.InternalIP()
|
||||||
|
uri, _ := sip.ParseSipURI(fmt.Sprintf("sip:%s@%s:%d", cfg.ID, iip, cfg.Port))
|
||||||
|
from := sip.Address{
|
||||||
|
DisplayName: sip.String{Str: "gowvp/owl"},
|
||||||
|
URI: &uri,
|
||||||
|
Params: sip.NewParams(),
|
||||||
|
}
|
||||||
|
s.fromAddress = from
|
||||||
|
s.Server.SetFrom(&from)
|
||||||
|
}
|
||||||
|
|
||||||
// startTickerCheck 定时检查离线,通过心跳超时判断设备是否离线
|
// startTickerCheck 定时检查离线,通过心跳超时判断设备是否离线
|
||||||
func (s *Server) startTickerCheck() {
|
func (s *Server) startTickerCheck() {
|
||||||
conc.Timer(context.Background(), 60*time.Second, time.Second, func() {
|
conc.Timer(context.Background(), 60*time.Second, time.Second, func() {
|
||||||
|
|||||||
@@ -54,6 +54,11 @@ func NewServer(form *Address) *Server {
|
|||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFrom 热更新 SIP 源地址配置,用于配置变更时无需重启服务
|
||||||
|
func (s *Server) SetFrom(from *Address) {
|
||||||
|
*s.from = *from
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) addRoute(method string, handler ...HandlerFunc) {
|
func (s *Server) addRoute(method string, handler ...HandlerFunc) {
|
||||||
s.route.Store(strings.ToUpper(method), handler)
|
s.route.Store(strings.ToUpper(method), handler)
|
||||||
}
|
}
|
||||||
|
|||||||
+37
-36
@@ -310,42 +310,43 @@ type SetServerConfigRequest struct {
|
|||||||
RtcStartBitrate *string `json:"rtc.start_bitrate,omitempty"`
|
RtcStartBitrate *string `json:"rtc.start_bitrate,omitempty"`
|
||||||
RtcTCPPort *string `json:"rtc.tcpPort,omitempty"`
|
RtcTCPPort *string `json:"rtc.tcpPort,omitempty"`
|
||||||
RtcTimeoutSec *string `json:"rtc.timeoutSec,omitempty"`
|
RtcTimeoutSec *string `json:"rtc.timeoutSec,omitempty"`
|
||||||
RtmpDirectProxy *string `json:"rtmp.directProxy,omitempty"`
|
// RTCEnableTurn *string `json:"rtc.enableTurn,omitempty"`
|
||||||
RtmpEnhanced *string `json:"rtmp.enhanced,omitempty"`
|
RtmpDirectProxy *string `json:"rtmp.directProxy,omitempty"`
|
||||||
RtmpHandshakeSecond *string `json:"rtmp.handshakeSecond,omitempty"`
|
RtmpEnhanced *string `json:"rtmp.enhanced,omitempty"`
|
||||||
RtmpKeepAliveSecond *string `json:"rtmp.keepAliveSecond,omitempty"`
|
RtmpHandshakeSecond *string `json:"rtmp.handshakeSecond,omitempty"`
|
||||||
RtmpPort *string `json:"rtmp.port,omitempty"`
|
RtmpKeepAliveSecond *string `json:"rtmp.keepAliveSecond,omitempty"`
|
||||||
RtmpSslport *string `json:"rtmp.sslport,omitempty"`
|
RtmpPort *string `json:"rtmp.port,omitempty"`
|
||||||
RtpAudioMtuSize *string `json:"rtp.audioMtuSize,omitempty"`
|
RtmpSslport *string `json:"rtmp.sslport,omitempty"`
|
||||||
RtpH264StapA *string `json:"rtp.h264_stap_a,omitempty"`
|
RtpAudioMtuSize *string `json:"rtp.audioMtuSize,omitempty"`
|
||||||
RtpLowLatency *string `json:"rtp.lowLatency,omitempty"`
|
RtpH264StapA *string `json:"rtp.h264_stap_a,omitempty"`
|
||||||
RtpRtpMaxSize *string `json:"rtp.rtpMaxSize,omitempty"`
|
RtpLowLatency *string `json:"rtp.lowLatency,omitempty"`
|
||||||
RtpVideoMtuSize *string `json:"rtp.videoMtuSize,omitempty"`
|
RtpRtpMaxSize *string `json:"rtp.rtpMaxSize,omitempty"`
|
||||||
RtpProxyDumpDir *string `json:"rtp_proxy.dumpDir,omitempty"`
|
RtpVideoMtuSize *string `json:"rtp.videoMtuSize,omitempty"`
|
||||||
RtpProxyGopCache *string `json:"rtp_proxy.gop_cache,omitempty"`
|
RtpProxyDumpDir *string `json:"rtp_proxy.dumpDir,omitempty"`
|
||||||
RtpProxyH264Pt *string `json:"rtp_proxy.h264_pt,omitempty"`
|
RtpProxyGopCache *string `json:"rtp_proxy.gop_cache,omitempty"`
|
||||||
RtpProxyH265Pt *string `json:"rtp_proxy.h265_pt,omitempty"`
|
RtpProxyH264Pt *string `json:"rtp_proxy.h264_pt,omitempty"`
|
||||||
RtpProxyOpusPt *string `json:"rtp_proxy.opus_pt,omitempty"`
|
RtpProxyH265Pt *string `json:"rtp_proxy.h265_pt,omitempty"`
|
||||||
RtpProxyPort *string `json:"rtp_proxy.port,omitempty"`
|
RtpProxyOpusPt *string `json:"rtp_proxy.opus_pt,omitempty"`
|
||||||
RtpProxyPortRange *string `json:"rtp_proxy.port_range,omitempty"`
|
RtpProxyPort *string `json:"rtp_proxy.port,omitempty"`
|
||||||
RtpProxyPsPt *string `json:"rtp_proxy.ps_pt,omitempty"`
|
RtpProxyPortRange *string `json:"rtp_proxy.port_range,omitempty"`
|
||||||
RtpProxyRtpG711DurMs *string `json:"rtp_proxy.rtp_g711_dur_ms,omitempty"`
|
RtpProxyPsPt *string `json:"rtp_proxy.ps_pt,omitempty"`
|
||||||
RtpProxyTimeoutSec *string `json:"rtp_proxy.timeoutSec,omitempty"`
|
RtpProxyRtpG711DurMs *string `json:"rtp_proxy.rtp_g711_dur_ms,omitempty"`
|
||||||
RtpProxyUDPRecvSocketBuffer *string `json:"rtp_proxy.udp_recv_socket_buffer,omitempty"`
|
RtpProxyTimeoutSec *string `json:"rtp_proxy.timeoutSec,omitempty"`
|
||||||
RtspAuthBasic *string `json:"rtsp.authBasic,omitempty"`
|
RtpProxyUDPRecvSocketBuffer *string `json:"rtp_proxy.udp_recv_socket_buffer,omitempty"`
|
||||||
RtspDirectProxy *string `json:"rtsp.directProxy,omitempty"`
|
RtspAuthBasic *string `json:"rtsp.authBasic,omitempty"`
|
||||||
RtspHandshakeSecond *string `json:"rtsp.handshakeSecond,omitempty"`
|
RtspDirectProxy *string `json:"rtsp.directProxy,omitempty"`
|
||||||
RtspKeepAliveSecond *string `json:"rtsp.keepAliveSecond,omitempty"`
|
RtspHandshakeSecond *string `json:"rtsp.handshakeSecond,omitempty"`
|
||||||
RtspLowLatency *string `json:"rtsp.lowLatency,omitempty"`
|
RtspKeepAliveSecond *string `json:"rtsp.keepAliveSecond,omitempty"`
|
||||||
RtspPort *string `json:"rtsp.port,omitempty"`
|
RtspLowLatency *string `json:"rtsp.lowLatency,omitempty"`
|
||||||
RtspRtpTransportType *string `json:"rtsp.rtpTransportType,omitempty"`
|
RtspPort *string `json:"rtsp.port,omitempty"`
|
||||||
RtspSslport *string `json:"rtsp.sslport,omitempty"`
|
RtspRtpTransportType *string `json:"rtsp.rtpTransportType,omitempty"`
|
||||||
ShellMaxReqSize *string `json:"shell.maxReqSize,omitempty"`
|
RtspSslport *string `json:"rtsp.sslport,omitempty"`
|
||||||
ShellPort *string `json:"shell.port,omitempty"`
|
ShellMaxReqSize *string `json:"shell.maxReqSize,omitempty"`
|
||||||
SrtLatencyMul *string `json:"srt.latencyMul,omitempty"`
|
ShellPort *string `json:"shell.port,omitempty"`
|
||||||
SrtPktBufSize *string `json:"srt.pktBufSize,omitempty"`
|
SrtLatencyMul *string `json:"srt.latencyMul,omitempty"`
|
||||||
SrtPort *string `json:"srt.port,omitempty"`
|
SrtPktBufSize *string `json:"srt.pktBufSize,omitempty"`
|
||||||
SrtTimeoutSec *string `json:"srt.timeoutSec,omitempty"`
|
SrtPort *string `json:"srt.port,omitempty"`
|
||||||
|
SrtTimeoutSec *string `json:"srt.timeoutSec,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) GetServerConfig() (*GetServerConfigResponse, error) {
|
func (e *Engine) GetServerConfig() (*GetServerConfigResponse, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user