mirror of
https://github.com/gowvp/gb28181.git
synced 2026-04-22 23:17:19 +08:00
510 lines
17 KiB
Python
510 lines
17 KiB
Python
import os
|
||
import signal
|
||
|
||
# 解决 macOS 上 OpenMP 库冲突问题,必须在导入 torch/cv2 等库之前设置
|
||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||
|
||
import argparse
|
||
import base64
|
||
from concurrent import futures
|
||
from concurrent.futures import thread
|
||
import logging
|
||
import queue
|
||
import sys
|
||
import threading
|
||
import time
|
||
from typing import Any
|
||
import requests
|
||
|
||
import grpc
|
||
from torch.export.exported_program import PassType
|
||
|
||
# 添加当前目录到 path 以支持直接运行
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
import logger
|
||
from detect import MotionDetector, ObjectDetector
|
||
from frame_capture import FrameCapture
|
||
import cv2
|
||
|
||
# 导入生成的 proto 代码
|
||
try:
|
||
import analysis_pb2
|
||
import analysis_pb2_grpc
|
||
except ImportError:
|
||
pass
|
||
|
||
|
||
# Module logger; every record from this service is emitted under the "AI" name.
slog = logging.getLogger("AI")

# Process-wide callback settings. main() overwrites both entries from the
# command line; StartCamera falls back to them when a request omits its own.
GLOBAL_CONFIG = dict(callback_url="", callback_secret="")

# Parent PID captured once at import time; _watch_parent_process compares the
# live parent PID against it to notice when the parent has gone away.
_PARENT_PID = os.getppid()
|
||
|
||
|
||
def _watch_parent_process():
    """Terminate this process once its parent (the Go service) has exited.

    Loops forever, checking every 3 seconds. When the parent dies this process
    is re-parented, so ``os.getppid()`` stops matching the PID captured at
    import time. At that point a warning is logged and the process exits
    immediately via ``os._exit(0)`` so it does not linger as an orphan that
    keeps holding its port and other resources.
    """
    while True:
        time.sleep(3)
        ppid_now = os.getppid()
        if ppid_now == _PARENT_PID:
            # Parent still alive; keep polling.
            continue
        slog.warning(
            f"父进程已退出 (原 PID: {_PARENT_PID}, 当前 PPID: {ppid_now}),Python 进程退出"
        )
        os._exit(0)
|
||
|
||
|
||
class CameraTask:
    """Analysis pipeline for a single camera stream.

    A FrameCapture pushes frames into a one-slot queue; a daemon thread pulls
    them, runs motion detection and, when motion is found, object detection.
    Detections are drawn onto a JPEG snapshot and POSTed to
    ``<callback_url>/events``. After ``retry_limit`` consecutive processing
    failures the task marks itself ``"error"``, stops capturing and POSTs
    ``<callback_url>/stopped``.

    Attributes reported by the status RPC:
        status: one of "initializing", "running", "stopping", "error".
        frames_processed: number of frames taken from the queue so far.
        retry_count: current number of consecutive processing failures.
        last_error: text of the most recent processing exception.
    """

    # Used when config carries no positive retry_limit. proto3 cannot tell an
    # unset field from 0, and a limit of 0 would stop the task on its first error.
    _DEFAULT_RETRY_LIMIT = 10

    def __init__(
        self,
        camera_id: str,
        rtsp_url: str,
        config: dict[str, Any],
        detector: ObjectDetector,
        motion_detector: MotionDetector,
    ) -> None:
        """Create the task and its FrameCapture; nothing runs until start().

        Args:
            camera_id: identifier echoed back in every callback payload.
            rtsp_url: stream URL handed to FrameCapture.
            config: per-camera settings (detect_fps, labels, threshold,
                roi_points, retry_limit, callback_url, callback_secret).
            detector: object detector (the servicer shares one instance
                across all tasks).
            motion_detector: motion detector (likewise shared).
        """
        self.camera_id = camera_id
        self.rtsp_url = rtsp_url
        self.config = config
        self.detector = detector
        self.motion_detector = motion_detector

        self.status = "initializing"
        self.frames_processed = 0
        self.retry_count = 0
        self.last_error = ""
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None

        # One slot only: the analysis loop always works on a recent frame.
        self.frame_queue: queue.Queue = queue.Queue(maxsize=1)
        self.capture = FrameCapture(
            rtsp_url, self.frame_queue, config.get("detect_fps", 5)
        )

    def start(self):
        """Start frame capture and the daemon analysis thread."""
        self.status = "running"
        self.capture.start()
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._analysis_loop, daemon=True)
        self._thread.start()
        slog.info(f"CameraTask started for {self.camera_id}")

    def stop(self):
        """Signal the analysis thread to exit, stop capture and wait briefly.

        The join is bounded at 2 s; the thread is a daemon, so it cannot keep
        the process alive if it is still blocked when the timeout expires.
        """
        self.status = "stopping"
        self._stop_event.set()
        self.capture.stop()
        if self._thread:
            self._thread.join(timeout=2)
        slog.info(f"CameraTask stopped for {self.camera_id}")

    def _analysis_loop(self):
        """Thread body: process frames until stop() or too many consecutive errors."""
        retry_limit = int(self.config.get("retry_limit") or 0)
        if retry_limit <= 0:
            retry_limit = self._DEFAULT_RETRY_LIMIT

        # The config does not change during the task's lifetime, so resolve
        # the per-frame parameters once.
        roi_points = self.config.get("roi_points")
        labels = self.config.get("labels")
        if labels and isinstance(labels, list):
            label_filter = [str(label) for label in labels]
        else:
            label_filter = None
        threshold = self.config.get("threshold", 0.5)

        error_streak = 0
        while not self._stop_event.is_set():
            try:
                frame = self.frame_queue.get(timeout=2.0)
            except queue.Empty:
                slog.debug("CameraTask frame queue empty, skipping")
                continue

            try:
                self._process_frame(frame, roi_points, label_filter, threshold)
            except Exception as e:
                slog.error(f"CameraTask analysis loop error: {e}")
                error_streak += 1
                self.retry_count = error_streak
                self.last_error = str(e)
                if error_streak >= retry_limit:
                    self.status = "error"
                    self._send_stopped_callback("error", self.last_error)
                    self.capture.stop()
                    break
                # Keep the CPU from spinning while the fault persists.
                time.sleep(1)
            else:
                # Only a fully processed frame ends an error streak; resetting
                # on dequeue would keep a per-frame fault from ever reaching
                # retry_limit.
                error_streak = 0
                self.retry_count = 0

    def _process_frame(self, frame, roi_points, label_filter, threshold):
        """Run motion and object detection on one frame and report any hits.

        Exceptions propagate to _analysis_loop, which counts them toward
        retry_limit and records them in last_error.
        """
        self.frames_processed += 1

        _motion_boxes, has_motion = self.motion_detector.detect(
            frame, self.camera_id, roi_points
        )
        if not has_motion:
            return

        detections, _ = self.detector.detect(
            frame,
            threshold=threshold,
            label_filter=label_filter,
            # Whole-frame detection for now; restricting it to motion_boxes
            # is a possible future optimization.
            regions=None,
        )
        if detections:
            self._send_detection_callback(detections, frame)

    def _send_detection_callback(self, detections, frame):
        """Annotate a copy of ``frame`` with ``detections`` and POST it to /events.

        Each detection is read as a dict with "label", "confidence" and a "box"
        dict holding x_min/y_min/x_max/y_max. The original frame is not
        modified. The payload carries the detections unchanged, a base64 JPEG
        snapshot (empty string if encoding fails) and the frame's width/height.
        """
        timestamp = int(time.time() * 1000)
        draw_frame = frame.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        red = (0, 0, 255)
        white = (255, 255, 255)

        for det in detections:
            box = det["box"]
            label = f"{det['label']} {det['confidence']:.2f}"

            # OpenCV drawing functions require integer pixel coordinates.
            p1 = (int(box["x_min"]), int(box["y_min"]))
            p2 = (int(box["x_max"]), int(box["y_max"]))

            # Bounding box: red, 2 px wide.
            cv2.rectangle(draw_frame, p1, p2, red, 2)

            # Filled strip behind the label so the text stays readable.
            text_w, text_h = cv2.getTextSize(label, font, 0.5, 1)[0]
            p2_text = (p1[0] + text_w, p1[1] - text_h - 3)
            cv2.rectangle(draw_frame, p1, p2_text, red, -1)  # -1 = filled

            # Label text in white.
            cv2.putText(draw_frame, label, (p1[0], p1[1] - 2), font, 0.5, white, 1)

        success, buffer = cv2.imencode(".jpg", draw_frame)
        snapshot_b64 = ""
        if success:
            snapshot_b64 = base64.b64encode(buffer).decode("utf-8")

        payload = {
            "camera_id": self.camera_id,
            "timestamp": timestamp,
            "detections": detections,
            "snapshot": snapshot_b64,
            "snapshot_width": frame.shape[1],
            "snapshot_height": frame.shape[0],
        }
        send_callback(self.config, "/events", payload)

    def _send_stopped_callback(self, reason, message):
        """POST a task-stopped notice with ``reason`` and ``message`` to /stopped."""
        payload = {
            "camera_id": self.camera_id,
            "timestamp": int(time.time() * 1000),
            "reason": reason,
            "message": message,
        }
        send_callback(self.config, "/stopped", payload)
|
||
|
||
|
||
class HealthServicer(analysis_pb2_grpc.HealthServicer):
    """gRPC health service that reports SERVING only once the analysis servicer is ready."""

    def __init__(self, servicer):
        """Args:
        servicer: the analysis servicer whose readiness is reported.
        """
        self._servicer = servicer

    def Check(self, request, context):
        """Return SERVING when the analysis servicer is ready, else NOT_SERVING."""
        # is_ready is a method on AnalysisServiceServicer. Testing the bound
        # method object itself is always truthy, which made this report
        # SERVING while the model was still loading. Call it (and still accept
        # a plain attribute/property, should it ever become one).
        ready = self._servicer.is_ready
        if callable(ready):
            ready = ready()

        if not ready:
            return analysis_pb2.HealthCheckResponse(
                status=analysis_pb2.HealthCheckResponse.NOT_SERVING
            )
        return analysis_pb2.HealthCheckResponse(
            status=analysis_pb2.HealthCheckResponse.SERVING
        )
|
||
|
||
|
||
class AnalysisServiceServicer(analysis_pb2_grpc.AnalysisServiceServicer):
    """gRPC AnalysisService: manages one CameraTask per camera_id.

    One ObjectDetector and one MotionDetector are created here and shared by
    every CameraTask. The model is loaded by initialize(); until that succeeds
    StartCamera answers UNAVAILABLE. ``_lock`` guards ``_camera_tasks``.
    """

    def __init__(self, model_path):
        self._camera_tasks: dict[str, CameraTask] = {}
        self._lock = threading.Lock()
        self._is_ready = False
        self._start_time = time.time()

        self.object_detector = ObjectDetector(model_path)
        self.motion_detector = MotionDetector()

    def is_ready(self) -> bool:
        """Return True once the detection model has loaded successfully."""
        return self._is_ready

    def initialize(self):
        """Load the model; on success mark the servicer ready and announce startup.

        On failure the servicer stays not-ready, so StartCamera keeps
        answering UNAVAILABLE.
        """
        slog.info("AnalysisService initializing...")
        success = self.object_detector.load_model()
        self._is_ready = success

        if not success:
            slog.error("AnalysisService initialization failed")
            return
        slog.info("AnalysisService initialized")
        threading.Thread(target=send_started_callback).start()

    def StartCamera(self, request, context):
        """Start analysis for ``request.camera_id``.

        Idempotent for a running camera. A task that stopped itself after too
        many consecutive errors (status "error") is replaced, so the camera can
        be restarted. After starting, waits up to 5 s for the stream's
        width/height/fps to be reported (zeros if it never is).
        """
        if not self._is_ready:
            context.set_details("model loadding")
            context.set_code(grpc.StatusCode.UNAVAILABLE)
            return analysis_pb2.StartCameraResponse(
                success=False, message="model loadding"
            )

        camera_id = request.camera_id
        with self._lock:
            existing = self._camera_tasks.get(camera_id)
            if existing is not None and existing.status != "error":
                slog.info(
                    f"Camera {camera_id} already exists, status: {existing.status}"
                )
                return analysis_pb2.StartCameraResponse(
                    success=True, message="任务已运行"
                )

            cb_url = request.callback_url or GLOBAL_CONFIG["callback_url"]
            cb_secret = request.callback_secret or GLOBAL_CONFIG["callback_secret"]
            if not cb_url:
                context.set_details("callback url is required")
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                return analysis_pb2.StartCameraResponse(
                    success=False, message="callback url is required"
                )

            config = {
                "detect_fps": request.detect_fps,
                "labels": list(request.labels),
                "threshold": request.threshold,
                "roi_points": list(request.roi_points),
                "retry_limit": request.retry_limit,
                "callback_url": cb_url,
                "callback_secret": cb_secret,
            }

            if existing is not None:
                # The previous task gave up after repeated errors but was never
                # removed; without this it would block every restart attempt.
                slog.info(f"Camera {camera_id} stopped with error, restarting")
                self._camera_tasks.pop(camera_id, None)
                existing.stop()

            task = CameraTask(
                camera_id,
                rtsp_url=request.rtsp_url,
                config=config,
                detector=self.object_detector,
                motion_detector=self.motion_detector,
            )
            task.start()
            self._camera_tasks[camera_id] = task

        # Polled outside the lock so other RPCs are not blocked for up to 5 s.
        w, h, fps = self._wait_for_stream_info(task)
        return analysis_pb2.StartCameraResponse(
            success=True,
            message="任务已启动",
            source_width=w,
            source_height=h,
            source_fps=fps,
        )

    @staticmethod
    def _wait_for_stream_info(task, timeout=5.0, interval=0.5):
        """Poll ``task.capture`` until it reports a width > 0 or ``timeout`` expires.

        Returns the last (width, height, fps) reading, or (0, 0, 0.0) if the
        capture was never polled.
        """
        deadline = time.monotonic() + timeout
        w, h, fps = 0, 0, 0.0
        while time.monotonic() < deadline:
            w, h, fps = task.capture.get_stream_info()
            if w > 0:
                break
            time.sleep(interval)
        return w, h, fps

    def StopCamera(self, request, context):
        """Stop and forget the task for ``request.camera_id``, if there is one."""
        camera_id = request.camera_id
        with self._lock:
            task = self._camera_tasks.pop(camera_id, None)

        if task is None:
            return analysis_pb2.StopCameraResponse(
                success=False, message="Camera not found"
            )

        # stop() joins the analysis thread (up to 2 s); do it without holding
        # the lock so other RPCs are not blocked meanwhile.
        task.stop()
        return analysis_pb2.StopCameraResponse(success=True, message="任务已停止")

    def GetStatus(self, request, context):
        """Report readiness, uptime and per-camera status."""
        response = analysis_pb2.StatusResponse()
        response.is_ready = self._is_ready
        response.stats.uptime_seconds = int(time.time() - self._start_time)

        with self._lock:
            response.stats.active_streams = len(self._camera_tasks)
            for cid, task in self._camera_tasks.items():
                response.cameras.append(
                    analysis_pb2.CameraStatus(
                        camera_id=cid,
                        status=task.status,
                        frames_processed=task.frames_processed,
                        retry_count=task.retry_count,
                        last_error=task.last_error,
                    )
                )
        return response
|
||
|
||
|
||
def send_callback(config: dict, path: str, payload: dict):
    """POST ``payload`` as JSON to ``config["callback_url"] + path`` in the background.

    Example: callback_url=http://127.0.0.1:15123, path=/events
    -> POST http://127.0.0.1:15123/events

    Does nothing when callback_url is empty. When callback_secret is set it is
    sent verbatim as the Authorization header. The request runs on its own
    thread with a 5 s timeout, so the caller never blocks on the network;
    failures and non-2xx responses are logged, never raised.
    """
    url = config.get("callback_url", "")
    secret = config.get("callback_secret", "")
    if not url:
        return

    full_url = url.rstrip("/") + path
    headers = {"Content-Type": "application/json"}
    if secret:
        headers["Authorization"] = secret

    def _post():
        # Runs on the worker thread: errors raised here never reach the
        # caller's try/except below, so they must be handled here.
        try:
            resp = requests.post(full_url, json=payload, headers=headers, timeout=5.0)
        except Exception as e:
            slog.error(f"Failed to send callback to {path}: {e}")
            return
        if not resp.ok:
            slog.warning(f"Callback to {path} returned status {resp.status_code}")

    try:
        threading.Thread(target=_post).start()
    except Exception as e:
        # Thread creation itself failed (e.g. resource exhaustion).
        slog.error(f"Failed to send callback to {path}: {e}")
|
||
|
||
|
||
def send_started_callback():
    """Notify the Go service at ``<callback_url>/started`` that this process is up.

    Used to confirm the Python <-> Go link. Tries up to 3 times, 2 s apart,
    and returns on the first 2xx response. If the final attempt still gets a
    404, the callback route does not exist (the Go service has probably
    stopped), so the process exits with status 1 rather than linger as an
    orphan. Connection errors and other failures are only logged. Does
    nothing when no callback URL is configured.
    """
    url = GLOBAL_CONFIG["callback_url"]
    secret = GLOBAL_CONFIG["callback_secret"]
    if not url:
        return

    full_url = url.rstrip("/") + "/started"
    headers = {"Content-Type": "application/json"}
    if secret:
        headers["Authorization"] = secret

    payload = {
        "timestamp": int(time.time() * 1000),
        "message": "AI Analysis Service Started",
    }

    max_retries = 3
    retry_interval = 2

    for attempt in range(1, max_retries + 1):
        slog.info(f"Sending started callback (attempt {attempt}/{max_retries})...")
        try:
            resp = requests.post(full_url, json=payload, headers=headers, timeout=5)
            # attempt is 1-based, so the final attempt is attempt == max_retries
            # (the old "max_retries - 1" fired on the second attempt instead).
            if resp.status_code == 404 and attempt == max_retries:
                slog.error("回调接口返回 404,Go 服务可能已停止,退出 Python 进程")
                os._exit(1)
            if resp.ok:
                slog.info("启动通知发送成功")
                return
            slog.warning(f"启动通知返回非成功状态: {resp.status_code}")
        except requests.exceptions.ConnectionError as e:
            slog.warning(f"发送启动通知失败 (连接错误): {e}")
        except Exception as e:
            slog.error(f"发送启动通知失败: {e}")

        if attempt < max_retries:
            time.sleep(retry_interval)

    slog.error(f"启动通知发送失败,已重试 {max_retries} 次")
|
||
|
||
|
||
def send_keepalive_callback(stats: dict):
    """POST a heartbeat carrying ``stats`` to ``<callback_url>/keepalive``.

    Blocks for at most 5 s. Returns without sending when no callback URL is
    configured. Any request failure is logged at DEBUG level and otherwise
    ignored.
    """
    base_url = GLOBAL_CONFIG["callback_url"]
    auth = GLOBAL_CONFIG["callback_secret"]
    if not base_url:
        return

    req_headers = {"Content-Type": "application/json"}
    if auth:
        req_headers["Authorization"] = auth

    body = {
        "timestamp": int(time.time() * 1000),
        "stats": stats,
        "message": "Service running normally",
    }
    endpoint = f"{base_url.rstrip('/')}/keepalive"

    try:
        requests.post(endpoint, json=body, headers=req_headers, timeout=5)
    except Exception as e:
        slog.debug(f"Failed to send keepalive callback: {e}")
|
||
|
||
|
||
def serve(port, model_path):
    """Run the gRPC server on ``[::]:port`` until it terminates.

    Starts the parent-process watchdog, registers AnalysisService and Health
    on a 20-worker thread pool, then loads the model on a background thread so
    the server accepts RPCs immediately (StartCamera answers UNAVAILABLE until
    loading finishes). Returns early if the proto modules are missing or the
    port cannot be bound; Ctrl-C stops the server immediately.
    """
    if "analysis_pb2_grpc" not in sys.modules:
        slog.error("Proto 代码未加载,退出。")
        return

    # Watch the parent so this process exits when the Go service does.
    threading.Thread(target=_watch_parent_process, daemon=True).start()

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=20))
    servicer = AnalysisServiceServicer(model_path)
    analysis_pb2_grpc.add_AnalysisServiceServicer_to_server(servicer, server)

    health_servicer = HealthServicer(servicer)
    analysis_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

    bound_port = server.add_insecure_port(f"[::]:{port}")
    if not bound_port:
        # Some grpcio versions report a failed bind by returning 0 rather
        # than raising; without this check the process would sit idle with
        # nothing listening.
        slog.error(f"Failed to bind gRPC port {port}")
        return

    server.start()
    slog.info(f"AnalysisService started: 0.0.0.0:{bound_port}")

    threading.Thread(target=servicer.initialize).start()

    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        server.stop(0)
|
||
|
||
|
||
def main():
    """Parse CLI options, configure logging and the global callback target, then serve."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=50051)
    parser.add_argument("--model", type=str, default="yolo11n.pt")
    parser.add_argument(
        "--callback-url",
        type=str,
        default="http://127.0.0.1:15123",
        help="回调基础URL,各回调路由会自动拼接",
    )
    parser.add_argument("--callback-secret", type=str, default="", help="回调秘钥")
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        help="日志级别 (DEBUG/INFO/ERROR)",
    )
    args = parser.parse_args()
    logger.setup_logging(level_str=args.log_level)

    GLOBAL_CONFIG["callback_url"] = args.callback_url
    GLOBAL_CONFIG["callback_secret"] = args.callback_secret

    # Never write the secret itself to the log; only whether one was supplied.
    masked_secret = "******" if args.callback_secret else ""
    slog.debug(
        f"log level: {args.log_level}, model: {args.model}, callback url: {args.callback_url}, callback secret: {masked_secret}"
    )

    serve(args.port, args.model)


if __name__ == "__main__":
    main()
|