diff --git a/Dockerfile_ai b/Dockerfile_ai index 7c1d4b9..f33d44f 100644 --- a/Dockerfile_ai +++ b/Dockerfile_ai @@ -43,8 +43,14 @@ RUN pip3 config set global.index-url https://mirrors.aliyun.com/pypi/simple/ \ # 安装 Python 依赖 # --break-system-packages: Debian Trixie 使用 PEP 668 保护系统 Python +# TFLite 支持:优先尝试 tflite-runtime,失败则用 ai-edge-litert(Google 新包名) COPY ./analysis/requirements.txt /tmp/requirements.txt RUN pip3 install --no-cache-dir --break-system-packages -r /tmp/requirements.txt \ + && (pip3 install --no-cache-dir --break-system-packages \ + -i https://pypi.org/simple/ tflite-runtime 2>/dev/null \ + || pip3 install --no-cache-dir --break-system-packages \ + -i https://pypi.org/simple/ ai-edge-litert 2>/dev/null \ + || echo "TFLite runtime not available, will use ONNX only") \ && rm /tmp/requirements.txt \ && rm -rf /root/.cache/pip @@ -54,7 +60,9 @@ COPY --from=mwader/static-ffmpeg:6.1 /ffmpeg /usr/local/bin/ffmpeg COPY --from=mwader/static-ffmpeg:6.1 /ffprobe /usr/local/bin/ffprobe # 复制应用文件 -ADD ./build/linux_amd64/bin ./gowvp +# 使用 TARGETARCH 自动识别目标架构(amd64 或 arm64) +ARG TARGETARCH +ADD ./build/linux_${TARGETARCH}/bin ./gowvp ADD ./www ./www ADD ./analysis ./analysis diff --git a/Dockerfile_zlm b/Dockerfile_zlm index a3aa208..0874cf9 100644 --- a/Dockerfile_zlm +++ b/Dockerfile_zlm @@ -18,7 +18,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ COPY --from=zlmediakit/zlmediakit:master /opt/media/bin /opt/media/bin COPY --from=mwader/static-ffmpeg:6.1 /ffmpeg /usr/local/bin/ffmpeg -ADD ./build/linux_amd64/bin ./gowvp +ARG TARGETARCH +ADD ./build/linux_${TARGETARCH}/bin ./gowvp ADD ./www ./www RUN mkdir -p configs diff --git a/README_CN.md b/README_CN.md index 8a470f8..5210dd6 100644 --- a/README_CN.md +++ b/README_CN.md @@ -135,11 +135,13 @@ postgres 和 mysql 的格式即: > 如何关闭 AI? -ai 默认是开启状态,1 秒检测 5 帧 +ai 默认是开启状态,1 秒检测 1 帧 可以在 `configs/config.toml` 中修改 `disabledAI = true` 关闭 ai 检测 +> 国标设备在线,通道离线? 
+属于 ipc 的问题,请检查 ipc 后台注册的 平台 sip_id 和 域是否与 gowvp/owl 一致。 ## 文档 @@ -302,6 +304,8 @@ services: [@joestarzxh](https://github.com/joestarzxh) [@oldweipro](https://github.com/oldweipro) [@beixiaocai](https://github.com/beixiaocai) +[@chencanfggz](https://github.com/chencanfggz) +[@zhangxuan1340](https://github.com/zhangxuan1340) ## 许可证 diff --git a/analysis/Makefile b/analysis/Makefile index e69de29..c743b4a 100644 --- a/analysis/Makefile +++ b/analysis/Makefile @@ -0,0 +1,19 @@ + +## ai/init: 安装 AI 分析模块依赖(本地开发用) +ai/init: + @echo "安装基础依赖..." + @pip install -r analysis/requirements.txt + @echo "" + @echo "TFLite 支持需要额外安装(二选一):" + @echo " Linux: make ai/init/tflite" + @echo " macOS/Windows: make ai/init/tensorflow" + @echo "" + @echo "如果使用 .tflite 模型,请根据系统选择安装" + +## ai/init/tflite: 安装 TFLite 支持(Linux 轻量版) +ai/init/tflite: + @conda install tflite-runtime + +## ai/init/tensorflow: 安装 TensorFlow(macOS/Windows 完整版) +ai/init/tensorflow: + @ pip install tensorflow -i https://pypi.org/simple \ No newline at end of file diff --git a/analysis/detect.py b/analysis/detect.py index fa45c91..47452b3 100644 --- a/analysis/detect.py +++ b/analysis/detect.py @@ -1,11 +1,11 @@ import logging +import os import time +from abc import ABC, abstractmethod from typing import Any import numpy as np import cv2 -import onnxruntime as ort - slog = logging.getLogger("Detector") @@ -94,55 +94,77 @@ COCO_LABELS = [ ] -class ObjectDetector: +class ModelBackend(ABC): + """ + 模型推理后端抽象接口 + 不同模型格式(ONNX、TFLite)需实现此接口,确保上层调用逻辑统一 + """ - def __init__(self, model_path: str = "yolo11n.onnx"): - self.model_path = model_path - self.session: ort.InferenceSession | None = None + @abstractmethod + def load(self, model_path: str) -> bool: + """加载模型文件,返回是否成功""" + pass + + @abstractmethod + def is_ready(self) -> bool: + """检查模型是否已加载并可用""" + pass + + @abstractmethod + def get_input_shape(self) -> tuple: + """获取模型输入形状,用于预处理""" + pass + + @abstractmethod + def infer(self, input_tensor: np.ndarray) -> np.ndarray: + 
"""执行推理,返回原始输出""" + pass + + +class ONNXBackend(ModelBackend): + """ + ONNX Runtime 推理后端 + 使用 onnxruntime 库加载和执行 ONNX 格式模型 + """ + + def __init__(self): + self.session = None self.input_name: str = "" self.input_shape: tuple = (1, 3, 640, 640) self._is_ready = False - self.names: dict[int, str] = {i: name for i, name in enumerate(COCO_LABELS)} - def load_model(self) -> bool: - """加载 ONNX 模型并初始化推理会话""" + def load(self, model_path: str) -> bool: try: - slog.info(f"加载 ONNX 模型: {self.model_path} ...") + import onnxruntime as ort + + slog.info(f"加载 ONNX 模型: {model_path} ...") start_time = time.time() - # 配置 ONNX Runtime 会话选项 sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ( ort.GraphOptimizationLevel.ORT_ENABLE_ALL ) - # 限制线程数,避免在容器中占用过多 CPU sess_options.intra_op_num_threads = 4 sess_options.inter_op_num_threads = 2 - # 优先使用 CPU 执行提供程序 providers = ["CPUExecutionProvider"] - self.session = ort.InferenceSession( - self.model_path, sess_options=sess_options, providers=providers + model_path, sess_options=sess_options, providers=providers ) - # 获取输入信息 input_info = self.session.get_inputs()[0] self.input_name = input_info.name self.input_shape = tuple(input_info.shape) - # 预热模型 - dummy_img = np.zeros((640, 640, 3), dtype=np.uint8) - self._preprocess(dummy_img) - dummy_input = self._preprocess(dummy_img) - self.session.run(None, {self.input_name: dummy_input}) - elapsed = time.time() - start_time slog.info( f"ONNX 模型加载完成 (耗时: {elapsed:.2f}s, 输入形状: {self.input_shape})" ) self._is_ready = True return True + except ImportError: + slog.error("未安装 onnxruntime,无法加载 ONNX 模型") + return False except Exception as e: slog.error(f"加载 ONNX 模型失败: {e}") return False @@ -150,29 +172,335 @@ class ObjectDetector: def is_ready(self) -> bool: return self._is_ready and self.session is not None + def get_input_shape(self) -> tuple: + return self.input_shape + + def infer(self, input_tensor: np.ndarray) -> np.ndarray: + if not self.session: + raise RuntimeError("ONNX 
模型未加载") + outputs = self.session.run(None, {self.input_name: input_tensor}) + return np.asarray(outputs[0]) + + +class TFLiteBackend(ModelBackend): + """ + TensorFlow Lite 推理后端 + 支持两种模型格式: + 1. YOLO 格式:单输出张量 (1, 84, 8400) + 2. SSD 格式:多输出张量(boxes, classes, scores, num_detections) + """ + + def __init__(self): + self.interpreter: Any = None + self.input_details: list[dict[str, Any]] = [] + self.output_details: list[dict[str, Any]] = [] + self.input_shape: tuple = (1, 640, 640, 3) + self._is_ready = False + self._is_nhwc = True + self._is_ssd_format = False # 区分 SSD 和 YOLO 格式 + self._input_quantization: tuple[float, int] = (1.0, 0) # scale, zero_point + + def load(self, model_path: str) -> bool: + try: + Interpreter = None + try: + from tflite_runtime.interpreter import Interpreter # type: ignore + except ImportError: + try: + from ai_edge_litert.interpreter import Interpreter # type: ignore + except ImportError: + try: + import tensorflow as tf + + Interpreter = tf.lite.Interpreter + except ImportError: + pass + + if Interpreter is None: + raise ImportError("未找到 tflite_runtime、ai_edge_litert 或 tensorflow") + + slog.info(f"加载 TFLite 模型: {model_path} ...") + start_time = time.time() + + self.interpreter = Interpreter(model_path=model_path) + self.interpreter.allocate_tensors() + + self.input_details = self.interpreter.get_input_details() + self.output_details = self.interpreter.get_output_details() + + input_shape = self.input_details[0]["shape"] + self.input_shape = tuple(input_shape) + + if len(self.input_shape) == 4: + self._is_nhwc = self.input_shape[3] == 3 + + # 获取输入量化参数(用于 uint8 量化模型) + quant_params = self.input_details[0].get("quantization_parameters", {}) + scales = quant_params.get("scales", np.array([1.0])) + zero_points = quant_params.get("zero_points", np.array([0])) + if len(scales) > 0 and len(zero_points) > 0: + self._input_quantization = (float(scales[0]), int(zero_points[0])) + + # 检测模型格式:SSD 模型通常有4个输出(boxes, classes, scores, num) + # 且输出名称包含 
"TFLite_Detection_PostProcess" + self._is_ssd_format = len(self.output_details) >= 3 and any( + "Detection" in d.get("name", "") for d in self.output_details + ) + + elapsed = time.time() - start_time + format_name = "SSD" if self._is_ssd_format else "YOLO" + slog.info( + f"TFLite 模型加载完成 (耗时: {elapsed:.2f}s, 输入: {self.input_shape}, " + f"格式: {format_name}, 量化: {self._input_quantization})" + ) + self._is_ready = True + return True + except ImportError: + slog.error("未安装 tflite_runtime 或 tensorflow,无法加载 TFLite 模型") + return False + except Exception as e: + slog.error(f"加载 TFLite 模型失败: {e}") + return False + + def is_ready(self) -> bool: + return self._is_ready and self.interpreter is not None + + def get_input_shape(self) -> tuple: + return self.input_shape + + def is_nhwc(self) -> bool: + """返回模型是否使用 NHWC 格式""" + return self._is_nhwc + + def is_ssd_format(self) -> bool: + """返回是否为 SSD 格式(多输出张量)""" + return self._is_ssd_format + + def get_input_quantization(self) -> tuple[float, int]: + """返回输入量化参数 (scale, zero_point)""" + return self._input_quantization + + def get_input_dtype(self) -> np.dtype: + """返回模型期望的输入数据类型""" + return self.input_details[0]["dtype"] + + def infer(self, input_tensor: np.ndarray) -> np.ndarray: + """执行推理,返回第一个输出张量(用于 YOLO 格式)""" + if not self.interpreter or len(self.input_details) == 0: + raise RuntimeError("TFLite 模型未加载") + + input_dtype = self.input_details[0]["dtype"] + if input_tensor.dtype != input_dtype: + input_tensor = input_tensor.astype(input_dtype) + + self.interpreter.set_tensor(self.input_details[0]["index"], input_tensor) + self.interpreter.invoke() + + output = self.interpreter.get_tensor(self.output_details[0]["index"]) + return np.asarray(output) + + def infer_ssd( + self, input_tensor: np.ndarray + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: + """ + 执行 SSD 格式推理,返回解析后的检测结果 + SSD 输出格式(已内置后处理): + - boxes: (1, num_boxes, 4) 归一化坐标 [y_min, x_min, y_max, x_max] + - classes: (1, num_boxes) 类别 ID + - scores: (1, num_boxes) 
置信度分数 + - num_detections: 有效检测数量 + """ + if not self.interpreter or len(self.input_details) == 0: + raise RuntimeError("TFLite 模型未加载") + + input_dtype = self.input_details[0]["dtype"] + if input_tensor.dtype != input_dtype: + input_tensor = input_tensor.astype(input_dtype) + + self.interpreter.set_tensor(self.input_details[0]["index"], input_tensor) + self.interpreter.invoke() + + # 按名称或索引获取各输出张量 + boxes = None + classes = None + scores = None + num_detections = 0 + + for detail in self.output_details: + name = detail.get("name", "") + tensor = self.interpreter.get_tensor(detail["index"]) + + if "boxes" in name.lower() or ( + detail["shape"][-1] == 4 and len(detail["shape"]) == 3 + ): + boxes = np.asarray(tensor) + elif "class" in name.lower() or ( + len(detail["shape"]) == 2 + and detail["shape"][1] > 1 + and boxes is not None + ): + classes = np.asarray(tensor) + elif "score" in name.lower() or ":2" in name: + scores = np.asarray(tensor) + elif "num" in name.lower() or ( + len(detail["shape"]) == 1 and detail["shape"][0] == 1 + ): + num_detections = int(tensor[0]) + + # 兜底处理:按输出顺序分配 + if boxes is None or classes is None or scores is None: + outputs = [ + self.interpreter.get_tensor(d["index"]) for d in self.output_details + ] + if len(outputs) >= 4: + boxes = np.asarray(outputs[0]) + classes = np.asarray(outputs[1]) + scores = np.asarray(outputs[2]) + num_detections = int(outputs[3][0]) if outputs[3].size > 0 else 0 + + if boxes is None: + boxes = np.array([]) + if classes is None: + classes = np.array([]) + if scores is None: + scores = np.array([]) + + return boxes, classes, scores, num_detections + + +def get_model_type(model_path: str) -> str: + """根据模型文件后缀判断模型类型""" + ext = os.path.splitext(model_path)[1].lower() + return "tflite" if ext == ".tflite" else "onnx" + + +def create_backend(model_type: str) -> ModelBackend: + """ + 根据模型类型创建对应的推理后端 + """ + if model_type == "tflite": + return TFLiteBackend() + else: + return ONNXBackend() + + +class ObjectDetector: 
+ """ + 目标检测器 - 支持多种模型格式(ONNX、TFLite) + 通过统一的 ModelBackend 接口实现模型无关的检测逻辑 + """ + + def __init__(self, model_path: str): + self.model_path = model_path + self.model_type = get_model_type(model_path) + self.backend: ModelBackend | None = None + self.input_shape: tuple = (1, 3, 640, 640) + self._is_ready = False + self.names: dict[int, str] = {i: name for i, name in enumerate(COCO_LABELS)} + + def load_model(self) -> bool: + """加载模型并初始化推理后端""" + try: + start_time = time.time() + + # 创建对应类型的后端 + self.backend = create_backend(self.model_type) + + # 加载模型 + if not self.backend.load(self.model_path): + return False + + self.input_shape = self.backend.get_input_shape() + + # 预热模型 + self._warmup() + + elapsed = time.time() - start_time + slog.info(f"模型预热完成 (总耗时: {elapsed:.2f}s)") + self._is_ready = True + return True + except Exception as e: + slog.error(f"加载模型失败: {e}") + return False + + def _warmup(self) -> None: + """预热模型,减少首次推理延迟""" + if not self.backend: + return + + dummy_img = np.zeros((640, 640, 3), dtype=np.uint8) + dummy_input = self._preprocess(dummy_img) + self.backend.infer(dummy_input) + slog.info("模型预热完成") + + def is_ready(self) -> bool: + return self._is_ready and self.backend is not None and self.backend.is_ready() + + def _get_target_size(self) -> int: + """获取模型期望的输入尺寸""" + # NCHW: (1, 3, H, W) -> shape[2] + # NHWC: (1, H, W, 3) -> shape[1] + if self._is_nhwc_format(): + return int(self.input_shape[1]) + return int(self.input_shape[2]) + + def _is_nhwc_format(self) -> bool: + """判断当前后端是否使用 NHWC 格式""" + if isinstance(self.backend, TFLiteBackend): + return self.backend.is_nhwc() + return False + + def _is_ssd_format(self) -> bool: + """判断当前后端是否为 SSD 格式""" + if isinstance(self.backend, TFLiteBackend): + return self.backend.is_ssd_format() + return False + def _preprocess(self, image: np.ndarray) -> np.ndarray: """ 预处理图像:调整大小、归一化、转换格式 - YOLO 期望输入格式: NCHW, float32, 归一化到 [0, 1] + 根据后端类型自动选择 NCHW 或 NHWC 格式,并处理量化输入 """ - # 保持宽高比的 letterbox 缩放 - target_size = 
self.input_shape[2] # 640 + target_size = self._get_target_size() h, w = image.shape[:2] + + # SSD 模型使用直接缩放(不保持宽高比) + # YOLO 模型使用 letterbox 缩放(保持宽高比) + if self._is_ssd_format(): + resized = cv2.resize( + image, (target_size, target_size), interpolation=cv2.INTER_LINEAR + ) + rgb = resized[:, :, ::-1] # BGR -> RGB + + # 检查是否需要量化为 uint8 + if isinstance(self.backend, TFLiteBackend): + input_dtype = self.backend.get_input_dtype() + if input_dtype == np.uint8: + # 直接返回 uint8 格式 + return np.expand_dims(rgb.astype(np.uint8), axis=0) + + # float32 格式 + rgb = rgb.astype(np.float32) / 255.0 + return np.expand_dims(rgb, axis=0) + + # YOLO letterbox 预处理 scale = min(target_size / h, target_size / w) new_h, new_w = int(h * scale), int(w * scale) - # 缩放图像 resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR) - # 创建正方形画布并居中放置图像 canvas = np.full((target_size, target_size, 3), 114, dtype=np.uint8) top = (target_size - new_h) // 2 left = (target_size - new_w) // 2 canvas[top : top + new_h, left : left + new_w] = resized - # BGR -> RGB, HWC -> CHW, 归一化 - blob = canvas[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) / 255.0 - return np.expand_dims(blob, axis=0) + rgb = canvas[:, :, ::-1].astype(np.float32) / 255.0 + + if self._is_nhwc_format(): + return np.expand_dims(rgb, axis=0) + else: + blob = rgb.transpose(2, 0, 1) + return np.expand_dims(blob, axis=0) def _postprocess( self, @@ -214,7 +542,7 @@ class ObjectDetector: # 缩放坐标到原始图像尺寸 orig_h, orig_w = original_shape - target_size = self.input_shape[2] + target_size = self._get_target_size() scale = min(target_size / orig_h, target_size / orig_w) pad_h = (target_size - orig_h * scale) / 2 pad_w = (target_size - orig_w * scale) / 2 @@ -341,18 +669,17 @@ class ObjectDetector: self, image: np.ndarray, threshold: float, label_filter: list[str] | None = None ) -> list[dict[str, Any]]: """对单张图像执行检测""" - if not self.session: + if not self.backend or not self.backend.is_ready(): return [] # 预处理 input_tensor = 
self._preprocess(image) # 推理 - outputs = self.session.run(None, {self.input_name: input_tensor}) + output = self.backend.infer(input_tensor) # 后处理 original_shape = image.shape[:2] - output: np.ndarray = np.asarray(outputs[0]) return self._postprocess(output, original_shape, threshold, label_filter) diff --git a/analysis/detect_test.py b/analysis/detect_test.py new file mode 100644 index 0000000..fe42ddc --- /dev/null +++ b/analysis/detect_test.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +""" +测试脚本 - 用于诊断 TFLite 和 ONNX 模型的检测问题 +使用方法: python detect_test.py [model_path] +""" + +import sys +import os +import cv2 +import numpy as np + +# 添加当前目录到路径 +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from detect import ObjectDetector, TFLiteBackend, ONNXBackend + + +def draw_detections(image: np.ndarray, detections: list) -> np.ndarray: + """在图像上绘制检测结果""" + img = image.copy() + for det in detections: + box = det["box"] + label = det["label"] + conf = det["confidence"] + + x1, y1 = box["x_min"], box["y_min"] + x2, y2 = box["x_max"], box["y_max"] + + # 绘制边框 + cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2) + + # 绘制标签 + text = f"{label}: {conf:.2%}" + (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1) + cv2.rectangle(img, (x1, y1 - th - 10), (x1 + tw, y1), (0, 255, 0), -1) + cv2.putText( + img, text, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1 + ) + + return img + + +def test_model(model_path: str, image_path: str): + """测试单个模型""" + print(f"\n{'='*60}") + print(f"测试模型: {model_path}") + print(f"测试图片: {image_path}") + print(f"{'='*60}") + + # 读取图像 + image = cv2.imread(image_path) + if image is None: + print(f"错误: 无法读取图像 {image_path}") + return + + print(f"图像尺寸: {image.shape}") + + # 创建检测器 + detector = ObjectDetector(model_path) + if not detector.load_model(): + print("错误: 模型加载失败") + return + + print(f"模型类型: {detector.model_type}") + print(f"输入形状: {detector.input_shape}") + + # 检查后端类型 + if isinstance(detector.backend, 
TFLiteBackend): + print(f"TFLite 格式: {'NHWC' if detector.backend.is_nhwc() else 'NCHW'}") + + # 调试: 查看输入输出详情 + print(f"输入详情: {detector.backend.input_details}") + print(f"输出详情: {detector.backend.output_details}") + + # 执行检测 + print("\n执行检测...") + + # 调试: 手动执行预处理并检查 + input_tensor = detector._preprocess(image) + print(f"预处理后张量形状: {input_tensor.shape}") + print(f"预处理后张量范围: [{input_tensor.min():.4f}, {input_tensor.max():.4f}]") + print(f"预处理后张量类型: {input_tensor.dtype}") + + # 执行推理 + if detector.backend is None: + print("错误: 后端未初始化") + return + output = detector.backend.infer(input_tensor) + print(f"原始输出形状: {output.shape}") + print(f"原始输出范围: [{output.min():.4f}, {output.max():.4f}]") + + # 检查输出格式 + if len(output.shape) == 3: + print( + f"输出维度分析: batch={output.shape[0]}, dim1={output.shape[1]}, dim2={output.shape[2]}" + ) + # 如果是 (1, 8400, 84) 格式,不需要转置 + # 如果是 (1, 84, 8400) 格式,需要转置 + if output.shape[1] == 84 and output.shape[2] == 8400: + print("输出格式: (1, 84, 8400) - YOLO 标准格式") + elif output.shape[1] == 8400 and output.shape[2] == 84: + print("输出格式: (1, 8400, 84) - 需要调整后处理逻辑") + + # 使用标准检测流程 + detections, inference_time = detector.detect(image, threshold=0.25) + + print(f"\n检测结果 (阈值=0.25):") + print(f"推理耗时: {inference_time:.2f} ms") + print(f"检测到 {len(detections)} 个目标:") + + for i, det in enumerate(detections): + print(f" {i+1}. {det['label']}: {det['confidence']:.2%} at {det['box']}") + # 检查置信度是否异常 + if det["confidence"] > 1.0: + print(f" ⚠️ 警告: 置信度超过100%! 
原始值: {det['confidence']}") + + # 保存结果图像 + model_name = os.path.splitext(os.path.basename(model_path))[0] + output_path = f"/Users/xugo/Desktop/gowvp/gb28181/analysis/result_{model_name}.jpg" + result_img = draw_detections(image, detections) + cv2.imwrite(output_path, result_img) + print(f"\n结果图像已保存: {output_path}") + + +def main(): + # 默认图像路径 + image_path = "/Users/xugo/Desktop/gowvp/gb28181/out.png" + + # 命令行参数 + if len(sys.argv) > 1: + image_path = sys.argv[1] + + # 模型路径 + tflite_model = "/Users/xugo/Desktop/gowvp/gb28181/configs/owl.tflite" + onnx_model = "/Users/xugo/Desktop/gowvp/gb28181/analysis/owl.onnx" + + if len(sys.argv) > 2: + # 只测试指定模型 + test_model(sys.argv[2], image_path) + else: + # 测试所有可用模型 + if os.path.exists(onnx_model): + test_model(onnx_model, image_path) + + if os.path.exists(tflite_model): + test_model(tflite_model, image_path) + + +if __name__ == "__main__": + main() diff --git a/analysis/main.py b/analysis/main.py index 288c489..98ee87c 100644 --- a/analysis/main.py +++ b/analysis/main.py @@ -25,6 +25,14 @@ from detect import MotionDetector, ObjectDetector from frame_capture import FrameCapture import cv2 +# 模型文件搜索候选路径(按优先级排序) +MODEL_SEARCH_PATHS = [ + ("../configs/owl.tflite", "tflite"), + ("../configs/owl.onnx", "onnx"), + ("./owl.tflite", "tflite"), + ("./owl.onnx", "onnx"), +] + # 导入生成的 proto 代码 # 这些模块必须存在才能启动 gRPC 服务 import analysis_pb2 @@ -480,10 +488,31 @@ def serve(port, model_path): server.stop(0) +def discover_model(model_arg: str) -> str: + """ + 自动发现可用模型文件 + 优先级:../configs/owl.tflite > ../configs/owl.onnx > ./owl.tflite > ./owl.onnx > 命令行参数 + """ + script_dir = os.path.dirname(os.path.abspath(__file__)) + + for rel_path, _ in MODEL_SEARCH_PATHS: + full_path = os.path.normpath(os.path.join(script_dir, rel_path)) + if os.path.exists(full_path): + slog.info(f"发现模型文件: {full_path}") + return full_path + + # 回退到命令行参数指定的模型 + if os.path.isabs(model_arg): + return model_arg + + # 相对路径基于脚本目录解析 + return 
os.path.normpath(os.path.join(script_dir, model_arg)) + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--port", type=int, default=50051) - parser.add_argument("--model", type=str, default="yolo11n.onnx") + parser.add_argument("--model", type=str, default="owl.onnx") parser.add_argument( "--callback-url", type=str, @@ -503,11 +532,14 @@ def main(): GLOBAL_CONFIG["callback_url"] = args.callback_url GLOBAL_CONFIG["callback_secret"] = args.callback_secret + # 自动发现模型文件 + model_path = discover_model(args.model) + slog.debug( - f"log level: {args.log_level}, model: {args.model}, callback url: {args.callback_url}, callback secret: {args.callback_secret}" + f"log level: {args.log_level}, model: {model_path}, callback url: {args.callback_url}, callback secret: {args.callback_secret}" ) - serve(args.port, args.model) + serve(args.port, model_path) if __name__ == "__main__": diff --git a/analysis/requirements.txt b/analysis/requirements.txt index c66a054..900adb5 100644 --- a/analysis/requirements.txt +++ b/analysis/requirements.txt @@ -1,4 +1,4 @@ -# ONNX Runtime - 替代 PyTorch,大幅减小镜像体积 +# ONNX Runtime - 用于加载 .onnx 模型 onnxruntime>=1.17.0 # gRPC 框架 diff --git a/changelog b/changelog index c66b37d..e69474c 100644 --- a/changelog +++ b/changelog @@ -1,3 +1,11 @@ +# 2026-01-11 + +修复 webrtc 播放地址 +支持 ai 检测,支持 onnx 和 tflite +支持自定义模型,写入 configs/owl.onnx 或 configs/owl.tflite +支持修改国标信息免重启生效 +修复 rtmp port 偶现值为 0 + # 2025-06-14 重构国标播放逻辑, play 接口仅返回播放地址,通过 webhook 来通知拉流 可能会导致首播慢一些 diff --git a/internal/core/sms/driver_zlm.go b/internal/core/sms/driver_zlm.go index 42b0d2d..5d6e00d 100644 --- a/internal/core/sms/driver_zlm.go +++ b/internal/core/sms/driver_zlm.go @@ -20,9 +20,9 @@ func (d *ZLMDriver) GetStreamLiveAddr(ctx context.Context, ms *MediaServer, http var out StreamLiveAddr out.Label = "ZLM" wsPrefix := strings.Replace(strings.Replace(httpPrefix, "https", "wss", 1), "http", "ws", 1) - out.WSFLV = fmt.Sprintf("%s/proxy/sms/%s.live.flv", wsPrefix, stream) - 
out.HTTPFLV = fmt.Sprintf("%s/proxy/sms/%s.live.flv", httpPrefix, stream) - out.HLS = fmt.Sprintf("%s/proxy/sms/%s/hls.fmp4.m3u8", httpPrefix, stream) + out.WSFLV = fmt.Sprintf("%s/proxy/sms/%s/%s.live.flv", wsPrefix, app, stream) + out.HTTPFLV = fmt.Sprintf("%s/proxy/sms/%s/%s.live.flv", httpPrefix, app, stream) + out.HLS = fmt.Sprintf("%s/proxy/sms/%s/%s/hls.fmp4.m3u8", httpPrefix, app, stream) rtcPrefix := strings.Replace(strings.Replace(httpPrefix, "https", "webrtc", 1), "http", "webrtc", 1) out.WebRTC = fmt.Sprintf("%s/proxy/sms/index/api/webrtc?app=%s&stream=%s&type=play", rtcPrefix, app, stream) out.RTMP = fmt.Sprintf("rtmp://%s:%d/%s", host, ms.Ports.RTMP, stream) @@ -81,9 +81,18 @@ func (d *ZLMDriver) Connect(ctx context.Context, ms *MediaServer) error { func (d *ZLMDriver) Setup(ctx context.Context, ms *MediaServer, webhookURL string) error { engine := d.withConfig(ms) + // 拼接 IP 但是不要空格 + ips := make([]string, 0, 2) + for _, ip := range []string{ms.SDPIP, ms.IP} { + if ip != "" { + ips = append(ips, ip) + } + } + _ = ips // 构造配置请求 req := zlm.SetServerConfigRequest{ - RtcExternIP: new(ms.IP), + RtcExternIP: new(strings.Join(ips, ",")), + GeneralMediaServerID: new(ms.ID), HookEnable: new("1"), HookOnFlowReport: new(""), diff --git a/internal/web/api/ai_webhook.go b/internal/web/api/ai_webhook.go index 59ee3ae..55ef8db 100644 --- a/internal/web/api/ai_webhook.go +++ b/internal/web/api/ai_webhook.go @@ -136,7 +136,8 @@ func (a AIWebhookAPI) onEvents(c *gin.Context, in *AIDetectionInput) (AIWebhookO Score: float32(det.Confidence), Zones: string(zonesJSON), ImagePath: imagePath, - Model: "yolo11n", + // TODO: 模型名称可以根据模型自定义 + Model: "default", } if _, err := a.eventCore.AddEvent(ctx, eventInput); err != nil { @@ -259,7 +260,7 @@ func (a *AIWebhookAPI) StartAIDetection(ctx context.Context, ch *ipc.Channel, rt CameraId: ch.ID, CameraName: ch.Name, RtspUrl: rtspURL, - DetectFps: 5, + DetectFps: 1, Labels: labels, Threshold: 0.75, RoiPoints: roiPoints, diff --git 
a/internal/web/api/api.go b/internal/web/api/api.go index 3ee0131..d7efa34 100644 --- a/internal/web/api/api.go +++ b/internal/web/api/api.go @@ -62,10 +62,11 @@ func setupRouter(r *gin.Engine, uc *Usecase) { AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}, AllowHeaders: []string{ "Accept", "Content-Length", "Content-Type", "Range", "Accept-Language", - "Origin", "Authorization", + "Origin", "Authorization", "Referer", "User-Agent", "Accept-Encoding", "Cache-Control", "Pragma", "X-Requested-With", "Sec-Fetch-Mode", "Sec-Fetch-Site", "Sec-Fetch-Dest", + "Sec-Ch-Ua", "Sec-Ch-Ua-Mobile", "Sec-Ch-Ua-Platform", "Dnt", "X-Forwarded-For", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Real-IP", "X-Request-ID", "X-Request-Start", "X-Request-Time", }, @@ -105,6 +106,7 @@ func setupRouter(r *gin.Engine, uc *Usecase) { registerPushAPI(r, uc.MediaAPI, auth) registerGB28181(r, uc.GB28181API, auth) registerProxy(r, uc.ProxyAPI, auth) + uc.ConfigAPI.uc = uc registerConfig(r, uc.ConfigAPI, auth) registerSms(r, uc.SMSAPI, auth) RegisterUser(r, uc.UserAPI, auth) @@ -327,6 +329,7 @@ func (uc *Usecase) proxySMS(c *gin.Context) { fullAddr, _ := url.Parse(addr) c.Request.URL.Path = "" proxy := httputil.NewSingleHostReverseProxy(fullAddr) + proxy.Director = func(req *http.Request) { // 设置请求的URL req.URL.Scheme = "http" diff --git a/internal/web/api/config.go b/internal/web/api/config.go index cf10865..fbeab9b 100755 --- a/internal/web/api/config.go +++ b/internal/web/api/config.go @@ -18,6 +18,7 @@ import ( type ConfigAPI struct { configCore config.Core conf *conf.Bootstrap + uc *Usecase } func NewConfigAPI(db *gorm.DB, conf *conf.Bootstrap) ConfigAPI { @@ -83,5 +84,7 @@ func (a ConfigAPI) editSIP(_ *gin.Context, in *conf.SIP) (gin.H, error) { if err := conf.WriteConfig(a.conf, a.conf.ConfigPath); err != nil { return nil, reason.ErrServer.SetMsg(err.Error()) } + a.uc.SipServer.SetConfig() + return gin.H{"msg": "ok"}, nil } diff --git 
a/internal/web/api/ipc.go b/internal/web/api/ipc.go index f4da1e8..5daa76b 100755 --- a/internal/web/api/ipc.go +++ b/internal/web/api/ipc.go @@ -281,7 +281,7 @@ func (a IPCAPI) play(c *gin.Context, _ *struct{}) (*playOutput, error) { host = h } - item := a.uc.SMSAPI.smsCore.GetStreamLiveAddr(svr, prefix, host, app, stream) + item := a.uc.SMSAPI.smsCore.GetStreamLiveAddr(svr, prefix, host, app, appStream) out := playOutput{ App: app, Stream: appStream, diff --git a/internal/web/api/push.go b/internal/web/api/push.go index 954b86d..76e509c 100755 --- a/internal/web/api/push.go +++ b/internal/web/api/push.go @@ -43,13 +43,13 @@ func (a PushAPI) findStreamPush(c *gin.Context, in *push.FindStreamPushInput) (* return nil, err } - cacheFn := hook.UseCache(func(s string) (*sms.MediaServer, error) { + cacheFn := func(s string) (*sms.MediaServer, error) { v, err := a.smsCore.GetMediaServer(c.Request.Context(), s) if err != nil { slog.ErrorContext(c.Request.Context(), "GetMediaServer", "err", err) } return v, err - }) + } out := make([]*push.FindStreamPushOutputItem, len(items)) for i, item := range items { @@ -58,8 +58,12 @@ func (a PushAPI) findStreamPush(c *gin.Context, in *push.FindStreamPushInput) (* if mediaID == "" { mediaID = sms.DefaultMediaServerID } - if svr, _, _ := cacheFn(mediaID); svr != nil { - addr := fmt.Sprintf("rtmp://%s:%d/%s/%s", web.GetHost(c.Request), svr.Ports.RTMP, item.App, item.Stream) + if svr, _ := cacheFn(mediaID); svr != nil { + port := svr.Ports.RTMP + if port == 0 { + port = 1935 + } + addr := fmt.Sprintf("rtmp://%s:%d/%s/%s", web.GetHost(c.Request), port, item.App, item.Stream) if !item.IsAuthDisabled { addr += fmt.Sprintf("?sign=%s", hook.MD5(a.conf.Server.RTMPSecret)) } diff --git a/pkg/gbs/server.go b/pkg/gbs/server.go index cbf017e..e06615e 100644 --- a/pkg/gbs/server.go +++ b/pkg/gbs/server.go @@ -49,7 +49,7 @@ func NewServer(cfg *conf.Bootstrap, store ipc.Adapter, sc sms.Core) (*Server, fu iip := ip.InternalIP() uri, _ := 
sip.ParseSipURI(fmt.Sprintf("sip:%s@%s:%d", cfg.Sip.ID, iip, cfg.Sip.Port)) from := sip.Address{ - DisplayName: sip.String{Str: "gowvp"}, + DisplayName: sip.String{Str: "gowvp/owl"}, URI: &uri, Params: sip.NewParams(), } @@ -87,6 +87,20 @@ func NewServer(cfg *conf.Bootstrap, store ipc.Adapter, sc sms.Core) (*Server, fu return &c, c.Close } +// SetConfig 热更新 SIP 配置,用于配置变更时更新 from 地址而无需重启服务 +func (s *Server) SetConfig() { + cfg := s.gb.cfg + iip := ip.InternalIP() + uri, _ := sip.ParseSipURI(fmt.Sprintf("sip:%s@%s:%d", cfg.ID, iip, cfg.Port)) + from := sip.Address{ + DisplayName: sip.String{Str: "gowvp/owl"}, + URI: &uri, + Params: sip.NewParams(), + } + s.fromAddress = from + s.Server.SetFrom(&from) +} + // startTickerCheck 定时检查离线,通过心跳超时判断设备是否离线 func (s *Server) startTickerCheck() { conc.Timer(context.Background(), 60*time.Second, time.Second, func() { diff --git a/pkg/gbs/sip/server.go b/pkg/gbs/sip/server.go index e3a4e91..3d560e7 100644 --- a/pkg/gbs/sip/server.go +++ b/pkg/gbs/sip/server.go @@ -54,6 +54,11 @@ func NewServer(form *Address) *Server { return srv } +// SetFrom 热更新 SIP 源地址配置,用于配置变更时无需重启服务 +func (s *Server) SetFrom(from *Address) { + *s.from = *from +} + func (s *Server) addRoute(method string, handler ...HandlerFunc) { s.route.Store(strings.ToUpper(method), handler) } diff --git a/pkg/zlm/config.go b/pkg/zlm/config.go index be9965c..0bba479 100644 --- a/pkg/zlm/config.go +++ b/pkg/zlm/config.go @@ -310,42 +310,43 @@ type SetServerConfigRequest struct { RtcStartBitrate *string `json:"rtc.start_bitrate,omitempty"` RtcTCPPort *string `json:"rtc.tcpPort,omitempty"` RtcTimeoutSec *string `json:"rtc.timeoutSec,omitempty"` - RtmpDirectProxy *string `json:"rtmp.directProxy,omitempty"` - RtmpEnhanced *string `json:"rtmp.enhanced,omitempty"` - RtmpHandshakeSecond *string `json:"rtmp.handshakeSecond,omitempty"` - RtmpKeepAliveSecond *string `json:"rtmp.keepAliveSecond,omitempty"` - RtmpPort *string `json:"rtmp.port,omitempty"` - RtmpSslport *string 
`json:"rtmp.sslport,omitempty"` - RtpAudioMtuSize *string `json:"rtp.audioMtuSize,omitempty"` - RtpH264StapA *string `json:"rtp.h264_stap_a,omitempty"` - RtpLowLatency *string `json:"rtp.lowLatency,omitempty"` - RtpRtpMaxSize *string `json:"rtp.rtpMaxSize,omitempty"` - RtpVideoMtuSize *string `json:"rtp.videoMtuSize,omitempty"` - RtpProxyDumpDir *string `json:"rtp_proxy.dumpDir,omitempty"` - RtpProxyGopCache *string `json:"rtp_proxy.gop_cache,omitempty"` - RtpProxyH264Pt *string `json:"rtp_proxy.h264_pt,omitempty"` - RtpProxyH265Pt *string `json:"rtp_proxy.h265_pt,omitempty"` - RtpProxyOpusPt *string `json:"rtp_proxy.opus_pt,omitempty"` - RtpProxyPort *string `json:"rtp_proxy.port,omitempty"` - RtpProxyPortRange *string `json:"rtp_proxy.port_range,omitempty"` - RtpProxyPsPt *string `json:"rtp_proxy.ps_pt,omitempty"` - RtpProxyRtpG711DurMs *string `json:"rtp_proxy.rtp_g711_dur_ms,omitempty"` - RtpProxyTimeoutSec *string `json:"rtp_proxy.timeoutSec,omitempty"` - RtpProxyUDPRecvSocketBuffer *string `json:"rtp_proxy.udp_recv_socket_buffer,omitempty"` - RtspAuthBasic *string `json:"rtsp.authBasic,omitempty"` - RtspDirectProxy *string `json:"rtsp.directProxy,omitempty"` - RtspHandshakeSecond *string `json:"rtsp.handshakeSecond,omitempty"` - RtspKeepAliveSecond *string `json:"rtsp.keepAliveSecond,omitempty"` - RtspLowLatency *string `json:"rtsp.lowLatency,omitempty"` - RtspPort *string `json:"rtsp.port,omitempty"` - RtspRtpTransportType *string `json:"rtsp.rtpTransportType,omitempty"` - RtspSslport *string `json:"rtsp.sslport,omitempty"` - ShellMaxReqSize *string `json:"shell.maxReqSize,omitempty"` - ShellPort *string `json:"shell.port,omitempty"` - SrtLatencyMul *string `json:"srt.latencyMul,omitempty"` - SrtPktBufSize *string `json:"srt.pktBufSize,omitempty"` - SrtPort *string `json:"srt.port,omitempty"` - SrtTimeoutSec *string `json:"srt.timeoutSec,omitempty"` + // RTCEnableTurn *string `json:"rtc.enableTurn,omitempty"` + RtmpDirectProxy *string 
`json:"rtmp.directProxy,omitempty"` + RtmpEnhanced *string `json:"rtmp.enhanced,omitempty"` + RtmpHandshakeSecond *string `json:"rtmp.handshakeSecond,omitempty"` + RtmpKeepAliveSecond *string `json:"rtmp.keepAliveSecond,omitempty"` + RtmpPort *string `json:"rtmp.port,omitempty"` + RtmpSslport *string `json:"rtmp.sslport,omitempty"` + RtpAudioMtuSize *string `json:"rtp.audioMtuSize,omitempty"` + RtpH264StapA *string `json:"rtp.h264_stap_a,omitempty"` + RtpLowLatency *string `json:"rtp.lowLatency,omitempty"` + RtpRtpMaxSize *string `json:"rtp.rtpMaxSize,omitempty"` + RtpVideoMtuSize *string `json:"rtp.videoMtuSize,omitempty"` + RtpProxyDumpDir *string `json:"rtp_proxy.dumpDir,omitempty"` + RtpProxyGopCache *string `json:"rtp_proxy.gop_cache,omitempty"` + RtpProxyH264Pt *string `json:"rtp_proxy.h264_pt,omitempty"` + RtpProxyH265Pt *string `json:"rtp_proxy.h265_pt,omitempty"` + RtpProxyOpusPt *string `json:"rtp_proxy.opus_pt,omitempty"` + RtpProxyPort *string `json:"rtp_proxy.port,omitempty"` + RtpProxyPortRange *string `json:"rtp_proxy.port_range,omitempty"` + RtpProxyPsPt *string `json:"rtp_proxy.ps_pt,omitempty"` + RtpProxyRtpG711DurMs *string `json:"rtp_proxy.rtp_g711_dur_ms,omitempty"` + RtpProxyTimeoutSec *string `json:"rtp_proxy.timeoutSec,omitempty"` + RtpProxyUDPRecvSocketBuffer *string `json:"rtp_proxy.udp_recv_socket_buffer,omitempty"` + RtspAuthBasic *string `json:"rtsp.authBasic,omitempty"` + RtspDirectProxy *string `json:"rtsp.directProxy,omitempty"` + RtspHandshakeSecond *string `json:"rtsp.handshakeSecond,omitempty"` + RtspKeepAliveSecond *string `json:"rtsp.keepAliveSecond,omitempty"` + RtspLowLatency *string `json:"rtsp.lowLatency,omitempty"` + RtspPort *string `json:"rtsp.port,omitempty"` + RtspRtpTransportType *string `json:"rtsp.rtpTransportType,omitempty"` + RtspSslport *string `json:"rtsp.sslport,omitempty"` + ShellMaxReqSize *string `json:"shell.maxReqSize,omitempty"` + ShellPort *string `json:"shell.port,omitempty"` + SrtLatencyMul *string 
`json:"srt.latencyMul,omitempty"` + SrtPktBufSize *string `json:"srt.pktBufSize,omitempty"` + SrtPort *string `json:"srt.port,omitempty"` + SrtTimeoutSec *string `json:"srt.timeoutSec,omitempty"` } func (e *Engine) GetServerConfig() (*GetServerConfigResponse, error) {