support tflite

This commit is contained in:
xugo
2026-01-13 12:17:26 +08:00
parent 674e7dfb23
commit 6715c07be3
18 changed files with 679 additions and 90 deletions
+9 -1
View File
@@ -43,8 +43,14 @@ RUN pip3 config set global.index-url https://mirrors.aliyun.com/pypi/simple/ \
# 安装 Python 依赖
# --break-system-packages: Debian Trixie 使用 PEP 668 保护系统 Python
# TFLite 支持:优先尝试 tflite-runtime,失败则用 ai-edge-litert(Google 新包名)
COPY ./analysis/requirements.txt /tmp/requirements.txt
RUN pip3 install --no-cache-dir --break-system-packages -r /tmp/requirements.txt \
&& (pip3 install --no-cache-dir --break-system-packages \
-i https://pypi.org/simple/ tflite-runtime 2>/dev/null \
|| pip3 install --no-cache-dir --break-system-packages \
-i https://pypi.org/simple/ ai-edge-litert 2>/dev/null \
|| echo "TFLite runtime not available, will use ONNX only") \
&& rm /tmp/requirements.txt \
&& rm -rf /root/.cache/pip
@@ -54,7 +60,9 @@ COPY --from=mwader/static-ffmpeg:6.1 /ffmpeg /usr/local/bin/ffmpeg
COPY --from=mwader/static-ffmpeg:6.1 /ffprobe /usr/local/bin/ffprobe
# 复制应用文件
ADD ./build/linux_amd64/bin ./gowvp
# 使用 TARGETARCH 自动识别目标架构(amd64 或 arm64)
ARG TARGETARCH
ADD ./build/linux_${TARGETARCH}/bin ./gowvp
ADD ./www ./www
ADD ./analysis ./analysis
+2 -1
View File
@@ -18,7 +18,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
COPY --from=zlmediakit/zlmediakit:master /opt/media/bin /opt/media/bin
COPY --from=mwader/static-ffmpeg:6.1 /ffmpeg /usr/local/bin/ffmpeg
ADD ./build/linux_amd64/bin ./gowvp
ARG TARGETARCH
ADD ./build/linux_${TARGETARCH}/bin ./gowvp
ADD ./www ./www
RUN mkdir -p configs
+5 -1
View File
@@ -135,11 +135,13 @@ postgres 和 mysql 的格式即:
> 如何关闭 AI?
ai 默认是开启状态,1 秒检测 5
ai 默认是开启状态,1 秒检测 1 帧
可以在 `configs/config.toml` 中修改 `disabledAI = true` 关闭 ai 检测
> 国标设备在线,通道离线?
属于 ipc 的问题,请检查 ipc 后台注册的 平台 sip_id 和 域是否与 gowvp/owl 一致。
## 文档
@@ -302,6 +304,8 @@ services:
[@joestarzxh](https://github.com/joestarzxh)
[@oldweipro](https://github.com/oldweipro)
[@beixiaocai](https://github.com/beixiaocai)
[@chencanfggz](https://github.com/chencanfggz)
[@zhangxuan1340](https://github.com/zhangxuan1340)
## 许可证
+19
View File
@@ -0,0 +1,19 @@
## ai/init: 安装 AI 分析模块依赖(本地开发用)
ai/init:
@echo "安装基础依赖..."
@pip install -r analysis/requirements.txt
@echo ""
@echo "TFLite 支持需要额外安装(二选一):"
@echo " Linux: make ai/init/tflite"
@echo " macOS/Windows: make ai/init/tensorflow"
@echo ""
@echo "如果使用 .tflite 模型,请根据系统选择安装"
## ai/init/tflite: 安装 TFLite 支持(Linux 轻量版)
ai/init/tflite:
@conda install tflite-runtime
## ai/init/tensorflow: 安装 TensorFlowmacOS/Windows 完整版)
ai/init/tensorflow:
@ pip install tensorflow -i https://pypi.org/simple
+361 -34
View File
@@ -1,11 +1,11 @@
import logging
import os
import time
from abc import ABC, abstractmethod
from typing import Any
import numpy as np
import cv2
import onnxruntime as ort
slog = logging.getLogger("Detector")
@@ -94,55 +94,77 @@ COCO_LABELS = [
]
class ObjectDetector:
class ModelBackend(ABC):
"""
模型推理后端抽象接口
不同模型格式(ONNX、TFLite)需实现此接口,确保上层调用逻辑统一
"""
def __init__(self, model_path: str = "yolo11n.onnx"):
self.model_path = model_path
self.session: ort.InferenceSession | None = None
@abstractmethod
def load(self, model_path: str) -> bool:
"""加载模型文件,返回是否成功"""
pass
@abstractmethod
def is_ready(self) -> bool:
"""检查模型是否已加载并可用"""
pass
@abstractmethod
def get_input_shape(self) -> tuple:
"""获取模型输入形状,用于预处理"""
pass
@abstractmethod
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
"""执行推理,返回原始输出"""
pass
class ONNXBackend(ModelBackend):
"""
ONNX Runtime 推理后端
使用 onnxruntime 库加载和执行 ONNX 格式模型
"""
def __init__(self):
self.session = None
self.input_name: str = ""
self.input_shape: tuple = (1, 3, 640, 640)
self._is_ready = False
self.names: dict[int, str] = {i: name for i, name in enumerate(COCO_LABELS)}
def load_model(self) -> bool:
"""加载 ONNX 模型并初始化推理会话"""
def load(self, model_path: str) -> bool:
try:
slog.info(f"加载 ONNX 模型: {self.model_path} ...")
import onnxruntime as ort
slog.info(f"加载 ONNX 模型: {model_path} ...")
start_time = time.time()
# 配置 ONNX Runtime 会话选项
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = (
ort.GraphOptimizationLevel.ORT_ENABLE_ALL
)
# 限制线程数,避免在容器中占用过多 CPU
sess_options.intra_op_num_threads = 4
sess_options.inter_op_num_threads = 2
# 优先使用 CPU 执行提供程序
providers = ["CPUExecutionProvider"]
self.session = ort.InferenceSession(
self.model_path, sess_options=sess_options, providers=providers
model_path, sess_options=sess_options, providers=providers
)
# 获取输入信息
input_info = self.session.get_inputs()[0]
self.input_name = input_info.name
self.input_shape = tuple(input_info.shape)
# 预热模型
dummy_img = np.zeros((640, 640, 3), dtype=np.uint8)
self._preprocess(dummy_img)
dummy_input = self._preprocess(dummy_img)
self.session.run(None, {self.input_name: dummy_input})
elapsed = time.time() - start_time
slog.info(
f"ONNX 模型加载完成 (耗时: {elapsed:.2f}s, 输入形状: {self.input_shape})"
)
self._is_ready = True
return True
except ImportError:
slog.error("未安装 onnxruntime,无法加载 ONNX 模型")
return False
except Exception as e:
slog.error(f"加载 ONNX 模型失败: {e}")
return False
@@ -150,29 +172,335 @@ class ObjectDetector:
def is_ready(self) -> bool:
return self._is_ready and self.session is not None
def get_input_shape(self) -> tuple:
return self.input_shape
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
if not self.session:
raise RuntimeError("ONNX 模型未加载")
outputs = self.session.run(None, {self.input_name: input_tensor})
return np.asarray(outputs[0])
class TFLiteBackend(ModelBackend):
    """
    TensorFlow Lite inference backend.

    Supports two model output layouts:
      1. YOLO style: a single output tensor, e.g. (1, 84, 8400)
      2. SSD style: multiple output tensors (boxes, classes, scores, num_detections)
    """

    def __init__(self):
        # Interpreter instance from tflite_runtime / ai_edge_litert / tensorflow;
        # created lazily in load().
        self.interpreter: Any = None
        self.input_details: list[dict[str, Any]] = []
        self.output_details: list[dict[str, Any]] = []
        # Default assumes NHWC float input; replaced by the real shape in load().
        self.input_shape: tuple = (1, 640, 640, 3)
        self._is_ready = False
        self._is_nhwc = True
        self._is_ssd_format = False  # distinguishes SSD vs YOLO output layout
        self._input_quantization: tuple[float, int] = (1.0, 0)  # scale, zero_point

    def load(self, model_path: str) -> bool:
        """Load a .tflite model and probe its input/output layout.

        Returns True on success; failures are logged and return False.
        """
        try:
            Interpreter = None
            # Prefer the lightweight runtimes, then fall back to full TensorFlow.
            try:
                from tflite_runtime.interpreter import Interpreter  # type: ignore
            except ImportError:
                try:
                    from ai_edge_litert.interpreter import Interpreter  # type: ignore
                except ImportError:
                    try:
                        import tensorflow as tf

                        Interpreter = tf.lite.Interpreter
                    except ImportError:
                        pass
            if Interpreter is None:
                raise ImportError("未找到 tflite_runtime、ai_edge_litert 或 tensorflow")
            slog.info(f"加载 TFLite 模型: {model_path} ...")
            start_time = time.time()
            self.interpreter = Interpreter(model_path=model_path)
            self.interpreter.allocate_tensors()
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()
            input_shape = self.input_details[0]["shape"]
            self.input_shape = tuple(input_shape)
            if len(self.input_shape) == 4:
                # Channels-last (N, H, W, 3) means NHWC; anything else is
                # treated as NCHW.
                self._is_nhwc = self.input_shape[3] == 3
            # Input quantization parameters (used for uint8-quantized models)
            quant_params = self.input_details[0].get("quantization_parameters", {})
            scales = quant_params.get("scales", np.array([1.0]))
            zero_points = quant_params.get("zero_points", np.array([0]))
            if len(scales) > 0 and len(zero_points) > 0:
                self._input_quantization = (float(scales[0]), int(zero_points[0]))
            # Detect layout: SSD models typically expose >= 3 outputs whose
            # names contain "TFLite_Detection_PostProcess".
            self._is_ssd_format = len(self.output_details) >= 3 and any(
                "Detection" in d.get("name", "") for d in self.output_details
            )
            elapsed = time.time() - start_time
            format_name = "SSD" if self._is_ssd_format else "YOLO"
            slog.info(
                f"TFLite 模型加载完成 (耗时: {elapsed:.2f}s, 输入: {self.input_shape}, "
                f"格式: {format_name}, 量化: {self._input_quantization})"
            )
            self._is_ready = True
            return True
        except ImportError:
            slog.error("未安装 tflite_runtime 或 tensorflow,无法加载 TFLite 模型")
            return False
        except Exception as e:
            slog.error(f"加载 TFLite 模型失败: {e}")
            return False

    def is_ready(self) -> bool:
        """Return True once load() has succeeded and an interpreter exists."""
        return self._is_ready and self.interpreter is not None

    def get_input_shape(self) -> tuple:
        """Return the model's input tensor shape as reported by the interpreter."""
        return self.input_shape

    def is_nhwc(self) -> bool:
        """Return True if the model expects NHWC input layout."""
        return self._is_nhwc

    def is_ssd_format(self) -> bool:
        """Return True if the model uses the SSD multi-tensor output layout."""
        return self._is_ssd_format

    def get_input_quantization(self) -> tuple[float, int]:
        """Return the input quantization parameters (scale, zero_point)."""
        return self._input_quantization

    def get_input_dtype(self) -> np.dtype:
        """Return the dtype the model expects for its input tensor."""
        return self.input_details[0]["dtype"]

    def infer(self, input_tensor: np.ndarray) -> np.ndarray:
        """Run inference and return the first output tensor (YOLO layout)."""
        if not self.interpreter or len(self.input_details) == 0:
            raise RuntimeError("TFLite 模型未加载")
        # Cast to the interpreter's expected dtype (e.g. uint8 for quantized models)
        input_dtype = self.input_details[0]["dtype"]
        if input_tensor.dtype != input_dtype:
            input_tensor = input_tensor.astype(input_dtype)
        self.interpreter.set_tensor(self.input_details[0]["index"], input_tensor)
        self.interpreter.invoke()
        output = self.interpreter.get_tensor(self.output_details[0]["index"])
        return np.asarray(output)

    def infer_ssd(
        self, input_tensor: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
        """
        Run inference for SSD-layout models and return parsed detections.

        SSD output layout (post-processing is built into the model):
          - boxes: (1, num_boxes, 4) normalized [y_min, x_min, y_max, x_max]
          - classes: (1, num_boxes) class ids
          - scores: (1, num_boxes) confidence scores
          - num_detections: number of valid detections
        """
        if not self.interpreter or len(self.input_details) == 0:
            raise RuntimeError("TFLite 模型未加载")
        input_dtype = self.input_details[0]["dtype"]
        if input_tensor.dtype != input_dtype:
            input_tensor = input_tensor.astype(input_dtype)
        self.interpreter.set_tensor(self.input_details[0]["index"], input_tensor)
        self.interpreter.invoke()
        # Match each output tensor by name or by shape heuristics.
        # NOTE(review): the shape heuristics below are order-dependent (the
        # "class" branch requires boxes to be matched first) — fragile for
        # models with unusual output ordering; verify against target models.
        boxes = None
        classes = None
        scores = None
        num_detections = 0
        for detail in self.output_details:
            name = detail.get("name", "")
            tensor = self.interpreter.get_tensor(detail["index"])
            if "boxes" in name.lower() or (
                detail["shape"][-1] == 4 and len(detail["shape"]) == 3
            ):
                boxes = np.asarray(tensor)
            elif "class" in name.lower() or (
                len(detail["shape"]) == 2
                and detail["shape"][1] > 1
                and boxes is not None
            ):
                classes = np.asarray(tensor)
            elif "score" in name.lower() or ":2" in name:
                scores = np.asarray(tensor)
            elif "num" in name.lower() or (
                len(detail["shape"]) == 1 and detail["shape"][0] == 1
            ):
                num_detections = int(tensor[0])
        # Fallback: assign by the conventional SSD output order.
        if boxes is None or classes is None or scores is None:
            outputs = [
                self.interpreter.get_tensor(d["index"]) for d in self.output_details
            ]
            if len(outputs) >= 4:
                boxes = np.asarray(outputs[0])
                classes = np.asarray(outputs[1])
                scores = np.asarray(outputs[2])
                num_detections = int(outputs[3][0]) if outputs[3].size > 0 else 0
        # Guarantee array returns even when nothing matched.
        if boxes is None:
            boxes = np.array([])
        if classes is None:
            classes = np.array([])
        if scores is None:
            scores = np.array([])
        return boxes, classes, scores, num_detections
def get_model_type(model_path: str) -> str:
    """Infer the model type ("tflite" or "onnx") from the file extension.

    Any extension other than .tflite (case-insensitive) is treated as ONNX.
    """
    _, extension = os.path.splitext(model_path)
    if extension.lower() == ".tflite":
        return "tflite"
    return "onnx"
def create_backend(model_type: str) -> ModelBackend:
    """Instantiate the inference backend matching *model_type*.

    "tflite" yields a TFLiteBackend; any other value falls back to ONNXBackend.
    """
    backend_cls = {"tflite": TFLiteBackend}.get(model_type, ONNXBackend)
    return backend_cls()
class ObjectDetector:
    """
    Object detector supporting multiple model formats (ONNX, TFLite).

    Detection logic is model-agnostic thanks to the unified ModelBackend
    interface; the concrete backend is chosen from the model file extension.
    """

    def __init__(self, model_path: str):
        # Path to the model file; its extension selects the backend type.
        self.model_path = model_path
        self.model_type = get_model_type(model_path)
        # Concrete backend, created in load_model().
        self.backend: ModelBackend | None = None
        # Default NCHW shape; overwritten with the backend's real shape on load.
        self.input_shape: tuple = (1, 3, 640, 640)
        self._is_ready = False
        # COCO class-id -> label mapping used to name detections.
        self.names: dict[int, str] = {i: name for i, name in enumerate(COCO_LABELS)}
    def load_model(self) -> bool:
        """Create the backend, load the model and warm it up.

        Returns True on success; failures are logged and return False.
        """
        try:
            start_time = time.time()
            # Create the backend matching the model file's extension
            self.backend = create_backend(self.model_type)
            # Load the model weights/session
            if not self.backend.load(self.model_path):
                return False
            self.input_shape = self.backend.get_input_shape()
            # Warm up so the first real inference is not slow
            self._warmup()
            elapsed = time.time() - start_time
            slog.info(f"模型预热完成 (总耗时: {elapsed:.2f}s)")
            self._is_ready = True
            return True
        except Exception as e:
            slog.error(f"加载模型失败: {e}")
            return False
    def _warmup(self) -> None:
        """Run one dummy inference to reduce first-frame latency."""
        if not self.backend:
            return
        # A black 640x640 frame exercises the full preprocess + infer path.
        dummy_img = np.zeros((640, 640, 3), dtype=np.uint8)
        dummy_input = self._preprocess(dummy_img)
        self.backend.infer(dummy_input)
        slog.info("模型预热完成")
    def is_ready(self) -> bool:
        """Ready only when both this wrapper and the backend report ready."""
        return self._is_ready and self.backend is not None and self.backend.is_ready()

    def _get_target_size(self) -> int:
        """Return the square input edge length (pixels) the model expects."""
        # NCHW: (1, 3, H, W) -> shape[2]
        # NHWC: (1, H, W, 3) -> shape[1]
        if self._is_nhwc_format():
            return int(self.input_shape[1])
        return int(self.input_shape[2])

    def _is_nhwc_format(self) -> bool:
        """True when the active backend consumes NHWC tensors (TFLite only)."""
        if isinstance(self.backend, TFLiteBackend):
            return self.backend.is_nhwc()
        return False

    def _is_ssd_format(self) -> bool:
        """True when the active backend is an SSD-layout TFLite model."""
        if isinstance(self.backend, TFLiteBackend):
            return self.backend.is_ssd_format()
        return False
def _preprocess(self, image: np.ndarray) -> np.ndarray:
"""
预处理图像:调整大小、归一化、转换格式
YOLO 期望输入格式: NCHW, float32, 归一化到 [0, 1]
根据后端类型自动选择 NCHW 或 NHWC 格式,并处理量化输入
"""
# 保持宽高比的 letterbox 缩放
target_size = self.input_shape[2] # 640
target_size = self._get_target_size()
h, w = image.shape[:2]
# SSD 模型使用直接缩放(不保持宽高比)
# YOLO 模型使用 letterbox 缩放(保持宽高比)
if self._is_ssd_format():
resized = cv2.resize(
image, (target_size, target_size), interpolation=cv2.INTER_LINEAR
)
rgb = resized[:, :, ::-1] # BGR -> RGB
# 检查是否需要量化为 uint8
if isinstance(self.backend, TFLiteBackend):
input_dtype = self.backend.get_input_dtype()
if input_dtype == np.uint8:
# 直接返回 uint8 格式
return np.expand_dims(rgb.astype(np.uint8), axis=0)
# float32 格式
rgb = rgb.astype(np.float32) / 255.0
return np.expand_dims(rgb, axis=0)
# YOLO letterbox 预处理
scale = min(target_size / h, target_size / w)
new_h, new_w = int(h * scale), int(w * scale)
# 缩放图像
resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
# 创建正方形画布并居中放置图像
canvas = np.full((target_size, target_size, 3), 114, dtype=np.uint8)
top = (target_size - new_h) // 2
left = (target_size - new_w) // 2
canvas[top : top + new_h, left : left + new_w] = resized
# BGR -> RGB, HWC -> CHW, 归一化
blob = canvas[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) / 255.0
return np.expand_dims(blob, axis=0)
rgb = canvas[:, :, ::-1].astype(np.float32) / 255.0
if self._is_nhwc_format():
return np.expand_dims(rgb, axis=0)
else:
blob = rgb.transpose(2, 0, 1)
return np.expand_dims(blob, axis=0)
def _postprocess(
self,
@@ -214,7 +542,7 @@ class ObjectDetector:
# 缩放坐标到原始图像尺寸
orig_h, orig_w = original_shape
target_size = self.input_shape[2]
target_size = self._get_target_size()
scale = min(target_size / orig_h, target_size / orig_w)
pad_h = (target_size - orig_h * scale) / 2
pad_w = (target_size - orig_w * scale) / 2
@@ -341,18 +669,17 @@ class ObjectDetector:
self, image: np.ndarray, threshold: float, label_filter: list[str] | None = None
) -> list[dict[str, Any]]:
"""对单张图像执行检测"""
if not self.session:
if not self.backend or not self.backend.is_ready():
return []
# 预处理
input_tensor = self._preprocess(image)
# 推理
outputs = self.session.run(None, {self.input_name: input_tensor})
output = self.backend.infer(input_tensor)
# 后处理
original_shape = image.shape[:2]
output: np.ndarray = np.asarray(outputs[0])
return self._postprocess(output, original_shape, threshold, label_filter)
+150
View File
@@ -0,0 +1,150 @@
#!/usr/bin/env python3
"""
测试脚本 - 用于诊断 TFLite 和 ONNX 模型的检测问题
使用方法: python detect_test.py <image_path> [model_path]
"""
import sys
import os
import cv2
import numpy as np
# 添加当前目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from detect import ObjectDetector, TFLiteBackend, ONNXBackend
def draw_detections(image: np.ndarray, detections: list) -> np.ndarray:
    """Return a copy of *image* with one labelled green box per detection."""
    annotated = image.copy()
    for detection in detections:
        bbox = detection["box"]
        x1, y1 = bbox["x_min"], bbox["y_min"]
        x2, y2 = bbox["x_max"], bbox["y_max"]
        # Bounding box outline
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Caption ("label: confidence") on a filled strip above the box
        caption = f"{detection['label']}: {detection['confidence']:.2%}"
        (text_w, text_h), _ = cv2.getTextSize(
            caption, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1
        )
        cv2.rectangle(
            annotated, (x1, y1 - text_h - 10), (x1 + text_w, y1), (0, 255, 0), -1
        )
        cv2.putText(
            annotated,
            caption,
            (x1, y1 - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 0),
            1,
        )
    return annotated
def test_model(model_path: str, image_path: str):
    """Load *model_path*, run detection on *image_path* and print diagnostics.

    Saves an annotated copy (result_<model>.jpg) next to the input image.
    Previously the result was written to a hard-coded developer-machine path,
    which broke on any other machine.
    """
    print(f"\n{'='*60}")
    print(f"测试模型: {model_path}")
    print(f"测试图片: {image_path}")
    print(f"{'='*60}")

    # Read the test image
    image = cv2.imread(image_path)
    if image is None:
        print(f"错误: 无法读取图像 {image_path}")
        return
    print(f"图像尺寸: {image.shape}")

    # Build and load the detector
    detector = ObjectDetector(model_path)
    if not detector.load_model():
        print("错误: 模型加载失败")
        return
    print(f"模型类型: {detector.model_type}")
    print(f"输入形状: {detector.input_shape}")

    # Backend-specific diagnostics
    if isinstance(detector.backend, TFLiteBackend):
        print(f"TFLite 格式: {'NHWC' if detector.backend.is_nhwc() else 'NCHW'}")
        # Debug: dump raw input/output tensor details
        print(f"输入详情: {detector.backend.input_details}")
        print(f"输出详情: {detector.backend.output_details}")

    # Run detection
    print("\n执行检测...")

    # Debug: run preprocessing manually and inspect the tensor
    input_tensor = detector._preprocess(image)
    print(f"预处理后张量形状: {input_tensor.shape}")
    print(f"预处理后张量范围: [{input_tensor.min():.4f}, {input_tensor.max():.4f}]")
    print(f"预处理后张量类型: {input_tensor.dtype}")

    # Raw inference
    if detector.backend is None:
        print("错误: 后端未初始化")
        return
    output = detector.backend.infer(input_tensor)
    print(f"原始输出形状: {output.shape}")
    print(f"原始输出范围: [{output.min():.4f}, {output.max():.4f}]")

    # Inspect the output layout
    if len(output.shape) == 3:
        print(
            f"输出维度分析: batch={output.shape[0]}, dim1={output.shape[1]}, dim2={output.shape[2]}"
        )
        # (1, 8400, 84) needs no transpose; (1, 84, 8400) does
        if output.shape[1] == 84 and output.shape[2] == 8400:
            print("输出格式: (1, 84, 8400) - YOLO 标准格式")
        elif output.shape[1] == 8400 and output.shape[2] == 84:
            print("输出格式: (1, 8400, 84) - 需要调整后处理逻辑")

    # Full detection pipeline.
    # NOTE(review): assumes detect() returns (detections, elapsed_ms); the
    # detect() signature visible elsewhere returns only a list — confirm.
    detections, inference_time = detector.detect(image, threshold=0.25)
    print(f"\n检测结果 (阈值=0.25):")
    print(f"推理耗时: {inference_time:.2f} ms")
    print(f"检测到 {len(detections)} 个目标:")
    for i, det in enumerate(detections):
        print(f"  {i+1}. {det['label']}: {det['confidence']:.2%} at {det['box']}")
        # Sanity check: confidences above 100% indicate a postprocessing bug
        if det["confidence"] > 1.0:
            print(f"     ⚠️ 警告: 置信度超过100%! 原始值: {det['confidence']}")

    # Save the annotated result next to the input image (portable path)
    model_name = os.path.splitext(os.path.basename(model_path))[0]
    output_dir = os.path.dirname(os.path.abspath(image_path))
    output_path = os.path.join(output_dir, f"result_{model_name}.jpg")
    result_img = draw_detections(image, detections)
    cv2.imwrite(output_path, result_img)
    print(f"\n结果图像已保存: {output_path}")
def main():
    """CLI entry point: python detect_test.py [image_path] [model_path].

    Defaults were hard-coded absolute paths on the author's machine
    (/Users/xugo/...); they are now resolved relative to this script so the
    diagnostic runs anywhere, preserving the original project layout
    (<project>/out.png, <project>/configs/owl.tflite, <project>/analysis/owl.onnx).
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_dir = os.path.dirname(script_dir)

    # Default test image at the project root
    image_path = os.path.join(project_dir, "out.png")

    # First CLI argument overrides the image path
    if len(sys.argv) > 1:
        image_path = sys.argv[1]

    # Candidate model locations
    tflite_model = os.path.join(project_dir, "configs", "owl.tflite")
    onnx_model = os.path.join(script_dir, "owl.onnx")

    if len(sys.argv) > 2:
        # Test only the explicitly given model
        test_model(sys.argv[2], image_path)
    else:
        # Test every model that exists on disk
        if os.path.exists(onnx_model):
            test_model(onnx_model, image_path)
        if os.path.exists(tflite_model):
            test_model(tflite_model, image_path)
if __name__ == "__main__":
main()
+35 -3
View File
@@ -25,6 +25,14 @@ from detect import MotionDetector, ObjectDetector
from frame_capture import FrameCapture
import cv2
# 模型文件搜索候选路径(按优先级排序)
MODEL_SEARCH_PATHS = [
("../configs/owl.tflite", "tflite"),
("../configs/owl.onnx", "onnx"),
("./owl.tflite", "tflite"),
("./owl.onnx", "onnx"),
]
# 导入生成的 proto 代码
# 这些模块必须存在才能启动 gRPC 服务
import analysis_pb2
@@ -480,10 +488,31 @@ def serve(port, model_path):
server.stop(0)
def discover_model(model_arg: str) -> str:
    """Locate the model file to load.

    Candidates from MODEL_SEARCH_PATHS (relative to this script) are probed in
    order and the first existing file wins: ../configs/owl.tflite >
    ../configs/owl.onnx > ./owl.tflite > ./owl.onnx. If none exists, fall back
    to the command-line path (resolved against the script directory when
    relative).
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))

    for candidate, _fmt in MODEL_SEARCH_PATHS:
        resolved = os.path.normpath(os.path.join(base_dir, candidate))
        if os.path.exists(resolved):
            slog.info(f"发现模型文件: {resolved}")
            return resolved

    # Nothing discovered: honour the CLI argument instead
    if os.path.isabs(model_arg):
        return model_arg
    # Relative paths are interpreted against the script directory
    return os.path.normpath(os.path.join(base_dir, model_arg))
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=50051)
parser.add_argument("--model", type=str, default="yolo11n.onnx")
parser.add_argument("--model", type=str, default="owl.onnx")
parser.add_argument(
"--callback-url",
type=str,
@@ -503,11 +532,14 @@ def main():
GLOBAL_CONFIG["callback_url"] = args.callback_url
GLOBAL_CONFIG["callback_secret"] = args.callback_secret
# 自动发现模型文件
model_path = discover_model(args.model)
slog.debug(
f"log level: {args.log_level}, model: {args.model}, callback url: {args.callback_url}, callback secret: {args.callback_secret}"
f"log level: {args.log_level}, model: {model_path}, callback url: {args.callback_url}, callback secret: {args.callback_secret}"
)
serve(args.port, args.model)
serve(args.port, model_path)
if __name__ == "__main__":
+1 -1
View File
@@ -1,4 +1,4 @@
# ONNX Runtime - 替代 PyTorch,大幅减小镜像体积
# ONNX Runtime - 用于加载 .onnx 模型
onnxruntime>=1.17.0
# gRPC 框架
+8
View File
@@ -1,3 +1,11 @@
# 2026-01-11
修复 webrtc 播放地址
支持 ai 检测,支持 onnx 和 tflite
支持自定义模型,写入 configs/owl.onnx 或 configs/owl.tflite
支持修改国标信息免重启生效
修复 rtmp port 偶现值为 0
# 2025-06-14
重构国标播放逻辑, play 接口仅返回播放地址,通过 webhook 来通知拉流
可能会导致首播慢一些
+13 -4
View File
@@ -20,9 +20,9 @@ func (d *ZLMDriver) GetStreamLiveAddr(ctx context.Context, ms *MediaServer, http
var out StreamLiveAddr
out.Label = "ZLM"
wsPrefix := strings.Replace(strings.Replace(httpPrefix, "https", "wss", 1), "http", "ws", 1)
out.WSFLV = fmt.Sprintf("%s/proxy/sms/%s.live.flv", wsPrefix, stream)
out.HTTPFLV = fmt.Sprintf("%s/proxy/sms/%s.live.flv", httpPrefix, stream)
out.HLS = fmt.Sprintf("%s/proxy/sms/%s/hls.fmp4.m3u8", httpPrefix, stream)
out.WSFLV = fmt.Sprintf("%s/proxy/sms/%s/%s.live.flv", wsPrefix, app, stream)
out.HTTPFLV = fmt.Sprintf("%s/proxy/sms/%s/%s.live.flv", httpPrefix, app, stream)
out.HLS = fmt.Sprintf("%s/proxy/sms/%s/%s/hls.fmp4.m3u8", httpPrefix, app, stream)
rtcPrefix := strings.Replace(strings.Replace(httpPrefix, "https", "webrtc", 1), "http", "webrtc", 1)
out.WebRTC = fmt.Sprintf("%s/proxy/sms/index/api/webrtc?app=%s&stream=%s&type=play", rtcPrefix, app, stream)
out.RTMP = fmt.Sprintf("rtmp://%s:%d/%s", host, ms.Ports.RTMP, stream)
@@ -81,9 +81,18 @@ func (d *ZLMDriver) Connect(ctx context.Context, ms *MediaServer) error {
func (d *ZLMDriver) Setup(ctx context.Context, ms *MediaServer, webhookURL string) error {
engine := d.withConfig(ms)
// 拼接 IP 但是不要空格
ips := make([]string, 0, 2)
for _, ip := range []string{ms.SDPIP, ms.IP} {
if ip != "" {
ips = append(ips, ip)
}
}
_ = ips
// 构造配置请求
req := zlm.SetServerConfigRequest{
RtcExternIP: new(ms.IP),
RtcExternIP: new(strings.Join(ips, ",")),
GeneralMediaServerID: new(ms.ID),
HookEnable: new("1"),
HookOnFlowReport: new(""),
+3 -2
View File
@@ -136,7 +136,8 @@ func (a AIWebhookAPI) onEvents(c *gin.Context, in *AIDetectionInput) (AIWebhookO
Score: float32(det.Confidence),
Zones: string(zonesJSON),
ImagePath: imagePath,
Model: "yolo11n",
// TODO: 模型名称可以根据模型自定义
Model: "default",
}
if _, err := a.eventCore.AddEvent(ctx, eventInput); err != nil {
@@ -259,7 +260,7 @@ func (a *AIWebhookAPI) StartAIDetection(ctx context.Context, ch *ipc.Channel, rt
CameraId: ch.ID,
CameraName: ch.Name,
RtspUrl: rtspURL,
DetectFps: 5,
DetectFps: 1,
Labels: labels,
Threshold: 0.75,
RoiPoints: roiPoints,
+4 -1
View File
@@ -62,10 +62,11 @@ func setupRouter(r *gin.Engine, uc *Usecase) {
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"},
AllowHeaders: []string{
"Accept", "Content-Length", "Content-Type", "Range", "Accept-Language",
"Origin", "Authorization",
"Origin", "Authorization", "Referer", "User-Agent",
"Accept-Encoding",
"Cache-Control", "Pragma", "X-Requested-With",
"Sec-Fetch-Mode", "Sec-Fetch-Site", "Sec-Fetch-Dest",
"Sec-Ch-Ua", "Sec-Ch-Ua-Mobile", "Sec-Ch-Ua-Platform",
"Dnt", "X-Forwarded-For", "X-Forwarded-Proto", "X-Forwarded-Host",
"X-Real-IP", "X-Request-ID", "X-Request-Start", "X-Request-Time",
},
@@ -105,6 +106,7 @@ func setupRouter(r *gin.Engine, uc *Usecase) {
registerPushAPI(r, uc.MediaAPI, auth)
registerGB28181(r, uc.GB28181API, auth)
registerProxy(r, uc.ProxyAPI, auth)
uc.ConfigAPI.uc = uc
registerConfig(r, uc.ConfigAPI, auth)
registerSms(r, uc.SMSAPI, auth)
RegisterUser(r, uc.UserAPI, auth)
@@ -327,6 +329,7 @@ func (uc *Usecase) proxySMS(c *gin.Context) {
fullAddr, _ := url.Parse(addr)
c.Request.URL.Path = ""
proxy := httputil.NewSingleHostReverseProxy(fullAddr)
proxy.Director = func(req *http.Request) {
// 设置请求的URL
req.URL.Scheme = "http"
+3
View File
@@ -18,6 +18,7 @@ import (
type ConfigAPI struct {
configCore config.Core
conf *conf.Bootstrap
uc *Usecase
}
func NewConfigAPI(db *gorm.DB, conf *conf.Bootstrap) ConfigAPI {
@@ -83,5 +84,7 @@ func (a ConfigAPI) editSIP(_ *gin.Context, in *conf.SIP) (gin.H, error) {
if err := conf.WriteConfig(a.conf, a.conf.ConfigPath); err != nil {
return nil, reason.ErrServer.SetMsg(err.Error())
}
a.uc.SipServer.SetConfig()
return gin.H{"msg": "ok"}, nil
}
+1 -1
View File
@@ -281,7 +281,7 @@ func (a IPCAPI) play(c *gin.Context, _ *struct{}) (*playOutput, error) {
host = h
}
item := a.uc.SMSAPI.smsCore.GetStreamLiveAddr(svr, prefix, host, app, stream)
item := a.uc.SMSAPI.smsCore.GetStreamLiveAddr(svr, prefix, host, app, appStream)
out := playOutput{
App: app,
Stream: appStream,
+8 -4
View File
@@ -43,13 +43,13 @@ func (a PushAPI) findStreamPush(c *gin.Context, in *push.FindStreamPushInput) (*
return nil, err
}
cacheFn := hook.UseCache(func(s string) (*sms.MediaServer, error) {
cacheFn := func(s string) (*sms.MediaServer, error) {
v, err := a.smsCore.GetMediaServer(c.Request.Context(), s)
if err != nil {
slog.ErrorContext(c.Request.Context(), "GetMediaServer", "err", err)
}
return v, err
})
}
out := make([]*push.FindStreamPushOutputItem, len(items))
for i, item := range items {
@@ -58,8 +58,12 @@ func (a PushAPI) findStreamPush(c *gin.Context, in *push.FindStreamPushInput) (*
if mediaID == "" {
mediaID = sms.DefaultMediaServerID
}
if svr, _, _ := cacheFn(mediaID); svr != nil {
addr := fmt.Sprintf("rtmp://%s:%d/%s/%s", web.GetHost(c.Request), svr.Ports.RTMP, item.App, item.Stream)
if svr, _ := cacheFn(mediaID); svr != nil {
port := svr.Ports.RTMP
if port == 0 {
port = 1935
}
addr := fmt.Sprintf("rtmp://%s:%d/%s/%s", web.GetHost(c.Request), port, item.App, item.Stream)
if !item.IsAuthDisabled {
addr += fmt.Sprintf("?sign=%s", hook.MD5(a.conf.Server.RTMPSecret))
}
+15 -1
View File
@@ -49,7 +49,7 @@ func NewServer(cfg *conf.Bootstrap, store ipc.Adapter, sc sms.Core) (*Server, fu
iip := ip.InternalIP()
uri, _ := sip.ParseSipURI(fmt.Sprintf("sip:%s@%s:%d", cfg.Sip.ID, iip, cfg.Sip.Port))
from := sip.Address{
DisplayName: sip.String{Str: "gowvp"},
DisplayName: sip.String{Str: "gowvp/owl"},
URI: &uri,
Params: sip.NewParams(),
}
@@ -87,6 +87,20 @@ func NewServer(cfg *conf.Bootstrap, store ipc.Adapter, sc sms.Core) (*Server, fu
return &c, c.Close
}
// SetConfig 热更新 SIP 配置,用于配置变更时更新 from 地址而无需重启服务
func (s *Server) SetConfig() {
cfg := s.gb.cfg
iip := ip.InternalIP()
uri, _ := sip.ParseSipURI(fmt.Sprintf("sip:%s@%s:%d", cfg.ID, iip, cfg.Port))
from := sip.Address{
DisplayName: sip.String{Str: "gowvp/owl"},
URI: &uri,
Params: sip.NewParams(),
}
s.fromAddress = from
s.Server.SetFrom(&from)
}
// startTickerCheck 定时检查离线,通过心跳超时判断设备是否离线
func (s *Server) startTickerCheck() {
conc.Timer(context.Background(), 60*time.Second, time.Second, func() {
+5
View File
@@ -54,6 +54,11 @@ func NewServer(form *Address) *Server {
return srv
}
// SetFrom 热更新 SIP 源地址配置,用于配置变更时无需重启服务
func (s *Server) SetFrom(from *Address) {
*s.from = *from
}
func (s *Server) addRoute(method string, handler ...HandlerFunc) {
s.route.Store(strings.ToUpper(method), handler)
}
+37 -36
View File
@@ -310,42 +310,43 @@ type SetServerConfigRequest struct {
RtcStartBitrate *string `json:"rtc.start_bitrate,omitempty"`
RtcTCPPort *string `json:"rtc.tcpPort,omitempty"`
RtcTimeoutSec *string `json:"rtc.timeoutSec,omitempty"`
RtmpDirectProxy *string `json:"rtmp.directProxy,omitempty"`
RtmpEnhanced *string `json:"rtmp.enhanced,omitempty"`
RtmpHandshakeSecond *string `json:"rtmp.handshakeSecond,omitempty"`
RtmpKeepAliveSecond *string `json:"rtmp.keepAliveSecond,omitempty"`
RtmpPort *string `json:"rtmp.port,omitempty"`
RtmpSslport *string `json:"rtmp.sslport,omitempty"`
RtpAudioMtuSize *string `json:"rtp.audioMtuSize,omitempty"`
RtpH264StapA *string `json:"rtp.h264_stap_a,omitempty"`
RtpLowLatency *string `json:"rtp.lowLatency,omitempty"`
RtpRtpMaxSize *string `json:"rtp.rtpMaxSize,omitempty"`
RtpVideoMtuSize *string `json:"rtp.videoMtuSize,omitempty"`
RtpProxyDumpDir *string `json:"rtp_proxy.dumpDir,omitempty"`
RtpProxyGopCache *string `json:"rtp_proxy.gop_cache,omitempty"`
RtpProxyH264Pt *string `json:"rtp_proxy.h264_pt,omitempty"`
RtpProxyH265Pt *string `json:"rtp_proxy.h265_pt,omitempty"`
RtpProxyOpusPt *string `json:"rtp_proxy.opus_pt,omitempty"`
RtpProxyPort *string `json:"rtp_proxy.port,omitempty"`
RtpProxyPortRange *string `json:"rtp_proxy.port_range,omitempty"`
RtpProxyPsPt *string `json:"rtp_proxy.ps_pt,omitempty"`
RtpProxyRtpG711DurMs *string `json:"rtp_proxy.rtp_g711_dur_ms,omitempty"`
RtpProxyTimeoutSec *string `json:"rtp_proxy.timeoutSec,omitempty"`
RtpProxyUDPRecvSocketBuffer *string `json:"rtp_proxy.udp_recv_socket_buffer,omitempty"`
RtspAuthBasic *string `json:"rtsp.authBasic,omitempty"`
RtspDirectProxy *string `json:"rtsp.directProxy,omitempty"`
RtspHandshakeSecond *string `json:"rtsp.handshakeSecond,omitempty"`
RtspKeepAliveSecond *string `json:"rtsp.keepAliveSecond,omitempty"`
RtspLowLatency *string `json:"rtsp.lowLatency,omitempty"`
RtspPort *string `json:"rtsp.port,omitempty"`
RtspRtpTransportType *string `json:"rtsp.rtpTransportType,omitempty"`
RtspSslport *string `json:"rtsp.sslport,omitempty"`
ShellMaxReqSize *string `json:"shell.maxReqSize,omitempty"`
ShellPort *string `json:"shell.port,omitempty"`
SrtLatencyMul *string `json:"srt.latencyMul,omitempty"`
SrtPktBufSize *string `json:"srt.pktBufSize,omitempty"`
SrtPort *string `json:"srt.port,omitempty"`
SrtTimeoutSec *string `json:"srt.timeoutSec,omitempty"`
// RTCEnableTurn *string `json:"rtc.enableTurn,omitempty"`
RtmpDirectProxy *string `json:"rtmp.directProxy,omitempty"`
RtmpEnhanced *string `json:"rtmp.enhanced,omitempty"`
RtmpHandshakeSecond *string `json:"rtmp.handshakeSecond,omitempty"`
RtmpKeepAliveSecond *string `json:"rtmp.keepAliveSecond,omitempty"`
RtmpPort *string `json:"rtmp.port,omitempty"`
RtmpSslport *string `json:"rtmp.sslport,omitempty"`
RtpAudioMtuSize *string `json:"rtp.audioMtuSize,omitempty"`
RtpH264StapA *string `json:"rtp.h264_stap_a,omitempty"`
RtpLowLatency *string `json:"rtp.lowLatency,omitempty"`
RtpRtpMaxSize *string `json:"rtp.rtpMaxSize,omitempty"`
RtpVideoMtuSize *string `json:"rtp.videoMtuSize,omitempty"`
RtpProxyDumpDir *string `json:"rtp_proxy.dumpDir,omitempty"`
RtpProxyGopCache *string `json:"rtp_proxy.gop_cache,omitempty"`
RtpProxyH264Pt *string `json:"rtp_proxy.h264_pt,omitempty"`
RtpProxyH265Pt *string `json:"rtp_proxy.h265_pt,omitempty"`
RtpProxyOpusPt *string `json:"rtp_proxy.opus_pt,omitempty"`
RtpProxyPort *string `json:"rtp_proxy.port,omitempty"`
RtpProxyPortRange *string `json:"rtp_proxy.port_range,omitempty"`
RtpProxyPsPt *string `json:"rtp_proxy.ps_pt,omitempty"`
RtpProxyRtpG711DurMs *string `json:"rtp_proxy.rtp_g711_dur_ms,omitempty"`
RtpProxyTimeoutSec *string `json:"rtp_proxy.timeoutSec,omitempty"`
RtpProxyUDPRecvSocketBuffer *string `json:"rtp_proxy.udp_recv_socket_buffer,omitempty"`
RtspAuthBasic *string `json:"rtsp.authBasic,omitempty"`
RtspDirectProxy *string `json:"rtsp.directProxy,omitempty"`
RtspHandshakeSecond *string `json:"rtsp.handshakeSecond,omitempty"`
RtspKeepAliveSecond *string `json:"rtsp.keepAliveSecond,omitempty"`
RtspLowLatency *string `json:"rtsp.lowLatency,omitempty"`
RtspPort *string `json:"rtsp.port,omitempty"`
RtspRtpTransportType *string `json:"rtsp.rtpTransportType,omitempty"`
RtspSslport *string `json:"rtsp.sslport,omitempty"`
ShellMaxReqSize *string `json:"shell.maxReqSize,omitempty"`
ShellPort *string `json:"shell.port,omitempty"`
SrtLatencyMul *string `json:"srt.latencyMul,omitempty"`
SrtPktBufSize *string `json:"srt.pktBufSize,omitempty"`
SrtPort *string `json:"srt.port,omitempty"`
SrtTimeoutSec *string `json:"srt.timeoutSec,omitempty"`
}
func (e *Engine) GetServerConfig() (*GetServerConfigResponse, error) {