Add AI detection management and synchronization features

This commit is contained in:
xugo
2026-01-08 21:25:44 +08:00
parent 338eb89699
commit 079aca7601
10 changed files with 367 additions and 43 deletions
+23 -4
View File
@@ -19,6 +19,7 @@ class LogPipe(threading.Thread):
self.deque: Deque[str] = deque(maxlen=100)
self.fd_read, self.fd_write = os.pipe()
self.pipe_reader = os.fdopen(self.fd_read)
self._closed = False
self.start()
def fileno(self):
@@ -27,16 +28,34 @@ class LogPipe(threading.Thread):
def run(self):
# 使用 iter() 包装 self.pipe_reader.readline 方法和空字符串""作为哨兵,使其不断读取管道内容。
# iter(self.pipe_reader.readline, "") 会不断调用 readline(),直到返回空字符串(代表 EOF),循环终止。
for line in iter(self.pipe_reader.readline, ""):
self.deque.append(line)
self.pipe_reader.close()
try:
for line in iter(self.pipe_reader.readline, ""):
self.deque.append(line)
except (OSError, ValueError):
# 管道已关闭,忽略错误
pass
finally:
try:
if not self._closed:
self.pipe_reader.close()
except (OSError, ValueError):
pass
def dump(self):
while len(self.deque) > 0:
self.logger.error(self.deque.popleft())
def close(self):
os.close(self.fd_read)
# 先关闭写端,让读端线程收到 EOF 并退出
if self._closed:
return
self._closed = True
try:
os.close(self.fd_write)
except OSError:
pass
# 等待读线程结束
self.join(timeout=1)
class FrameCapture:
+1 -1
View File
@@ -23,7 +23,7 @@ func DefaultConfig() Bootstrap {
},
AI: ServerAI{
Disabled: false,
RetainDays: 10,
RetainDays: 7,
},
},
Data: Data{
+12
View File
@@ -142,3 +142,15 @@ func (c *Core) GetZones(ctx context.Context, channelID string) ([]Zone, error) {
}
return out.Ext.Zones, nil
}
// SetAIEnabled 设置通道的 AI 检测开关状态,同时返回更新后的完整通道信息供调用方使用
func (c *Core) SetAIEnabled(ctx context.Context, channelID string, enabled bool) (*Channel, error) {
var out Channel
if err := c.store.Channel().Edit(ctx, &out, func(b *Channel) error {
b.Ext.EnabledAI = enabled
return nil
}, orm.Where("id=?", channelID)); err != nil {
return nil, reason.ErrDB.Withf(`Edit err[%s]`, err.Error())
}
return &out, nil
}
+1
View File
@@ -39,6 +39,7 @@ type DeviceExt struct {
Name string `json:"name"` // 设备名
GBVersion string `json:"gb_version"` // GB版本
Zones []Zone `json:"zones"` // 区域
EnabledAI bool `json:"enabled_ai"` // 是否启用 AI
}
// Scan implements orm.Scaner.
+142
View File
@@ -1,6 +1,7 @@
package api
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -14,7 +15,9 @@ import (
"github.com/gowvp/owl/internal/conf"
"github.com/gowvp/owl/internal/core/event"
"github.com/gowvp/owl/internal/core/ipc"
"github.com/gowvp/owl/internal/core/sms"
"github.com/gowvp/owl/internal/rpc"
"github.com/gowvp/owl/protos"
"github.com/ixugo/goddd/pkg/conc"
"github.com/ixugo/goddd/pkg/orm"
"github.com/ixugo/goddd/pkg/system"
@@ -159,6 +162,145 @@ func (a AIWebhookAPI) onStopped(c *gin.Context, in *AIStoppedInput) (AIWebhookOu
return newAIWebhookOutputOK(), nil
}
// StartAISyncLoop 启动 AI 任务同步协程,每 5 分钟检测一次数据库中 enabled_ai 状态与内存 aiTasks 的差异并同步
func (a *AIWebhookAPI) StartAISyncLoop(ctx context.Context, smsCore sms.Core) {
go func() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
a.log.Info("AI sync loop stopped")
return
case <-ticker.C:
a.syncAITasks(ctx, smsCore)
}
}
}()
}
// syncAITasks 同步 AI 任务状态,确保数据库中 enabled_ai=true 的通道正在运行检测,enabled_ai=false 的已停止
func (a *AIWebhookAPI) syncAITasks(ctx context.Context, smsCore sms.Core) {
if a.conf.Server.AI.Disabled || a.ai == nil {
return
}
// 查询所有通道
channels, _, err := a.ipcCore.FindChannel(ctx, &ipc.FindChannelInput{
PagerFilter: web.PagerFilter{Page: 1, Size: 999},
})
if err != nil {
a.log.ErrorContext(ctx, "sync ai tasks: find channels failed", "err", err)
return
}
// 构建数据库中 enabled_ai=true 的通道集合
dbEnabledSet := make(map[string]*ipc.Channel)
for _, ch := range channels {
if ch.Ext.EnabledAI {
dbEnabledSet[ch.ID] = ch
}
}
// 收集内存中正在运行的任务
memoryTasks := make(map[string]struct{})
a.aiTasks.Range(func(key string, _ struct{}) bool {
memoryTasks[key] = struct{}{}
return true
})
// 需要启动的任务:数据库中 enabled 但内存中没有
for channelID, ch := range dbEnabledSet {
if _, exists := memoryTasks[channelID]; !exists {
a.log.Info("sync: starting AI task", "channel_id", channelID)
if err := a.startAITask(ctx, smsCore, ch); err != nil {
a.log.ErrorContext(ctx, "sync: start AI task failed", "channel_id", channelID, "err", err)
}
}
}
// 需要停止的任务:内存中有但数据库中 disabled
for channelID := range memoryTasks {
if _, exists := dbEnabledSet[channelID]; !exists {
a.log.Info("sync: stopping AI task", "channel_id", channelID)
if err := a.stopAITask(ctx, channelID); err != nil {
a.log.ErrorContext(ctx, "sync: stop AI task failed", "channel_id", channelID, "err", err)
}
}
}
}
// startAITask 启动单个通道的 AI 检测任务(内部使用,自动构建 RTSP URL)
func (a *AIWebhookAPI) startAITask(ctx context.Context, smsCore sms.Core, ch *ipc.Channel) error {
svr, err := smsCore.GetMediaServer(ctx, sms.DefaultMediaServerID)
if err != nil {
return fmt.Errorf("get media server: %w", err)
}
rtspURL := fmt.Sprintf("rtsp://127.0.0.1:%d/rtp/%s", svr.Ports.RTSP, ch.ID)
_, err = a.StartAIDetection(ctx, ch, rtspURL)
return err
}
// stopAITask 停止单个通道的 AI 检测任务(内部使用)
func (a *AIWebhookAPI) stopAITask(ctx context.Context, channelID string) error {
return a.StopAIDetection(ctx, channelID)
}
// StartAIDetection 启动 AI 检测任务,供外部调用(如 ipc.go 中的 enableAI
func (a *AIWebhookAPI) StartAIDetection(ctx context.Context, ch *ipc.Channel, rtspURL string) (*protos.StartCameraResponse, error) {
if a.ai == nil {
return nil, fmt.Errorf("AI service not initialized")
}
roiPoints, labels := a.extractZoneConfig(ch)
resp, err := a.ai.StartCamera(ctx, &protos.StartCameraRequest{
CameraId: ch.ID,
CameraName: ch.Name,
RtspUrl: rtspURL,
DetectFps: 5,
Labels: labels,
Threshold: 0.75,
RoiPoints: roiPoints,
RetryLimit: 10,
CallbackUrl: fmt.Sprintf("http://127.0.0.1:%d/ai", a.conf.Server.HTTP.Port),
CallbackSecret: "Basic 1234567890",
})
if err != nil {
return nil, err
}
a.aiTasks.Store(ch.ID, struct{}{})
return resp, nil
}
// StopAIDetection 停止 AI 检测任务,供外部调用(如 ipc.go 中的 disableAI
func (a *AIWebhookAPI) StopAIDetection(ctx context.Context, channelID string) error {
if a.ai == nil {
return nil
}
_, err := a.ai.StopCamera(ctx, &protos.StopCameraRequest{
CameraId: channelID,
})
// 无论是否成功都从内存中删除,避免重复尝试停止已不存在的任务
a.aiTasks.Delete(channelID)
return err
}
// extractZoneConfig 从通道配置中提取区域和标签信息
func (a *AIWebhookAPI) extractZoneConfig(ch *ipc.Channel) (roiPoints []float32, labels []string) {
if len(ch.Ext.Zones) > 0 {
zone := ch.Ext.Zones[0]
roiPoints = zone.Coordinates
labels = zone.Labels
}
if len(labels) == 0 {
labels = []string{"person", "car", "cat", "dog"}
}
return
}
// saveEventSnapshot 将 Base64 编码的快照保存到 configs/events/{cid}/ 目录
// 返回相对路径: cid/年月日时分秒_随机6位.jpg
func saveEventSnapshot(cid string, t orm.Time, snapshotB64 string) (string, error) {
+3
View File
@@ -1,6 +1,7 @@
package api
import (
"context"
"expvar"
"fmt"
"log/slog"
@@ -110,6 +111,8 @@ func setupRouter(r *gin.Engine, uc *Usecase) {
// 注册 AI 分析服务回调接口
registerAIWebhookAPI(r, uc.AIWebhookAPI)
// 启动 AI 任务同步协程,每 5 分钟检测一次数据库与内存状态差异
uc.AIWebhookAPI.StartAISyncLoop(context.Background(), uc.SMSAPI.smsCore)
// TODO: 待补充中间件
RegisterEvent(r, uc.EventAPI)
}
+4 -2
View File
@@ -29,8 +29,10 @@ func NewEventCore(db *gorm.DB, conf *conf.Bootstrap) event.Core {
core := event.NewCore(store)
// 启动定时清理协程
days := max(conf.Server.AI.RetainDays, 1)
days := conf.Server.AI.RetainDays
if days <= 0 {
days = 7
}
go core.StartCleanupWorker(days)
return core
+140 -27
View File
@@ -22,7 +22,6 @@ import (
"github.com/gowvp/owl/internal/core/push"
"github.com/gowvp/owl/internal/core/sms"
"github.com/gowvp/owl/pkg/zlm"
"github.com/gowvp/owl/protos"
"github.com/ixugo/goddd/domain/uniqueid"
"github.com/ixugo/goddd/pkg/hook"
"github.com/ixugo/goddd/pkg/orm"
@@ -107,6 +106,8 @@ func registerGB28181(g gin.IRouter, api IPCAPI, handler ...gin.HandlerFunc) {
group.GET("/:id/snapshot", api.getSnapshot) // 获取图像(所有协议)
group.POST("/:id/zones", web.WrapH(api.addZone)) // 添加区域(所有协议)
group.GET("/:id/zones", web.WrapH(api.getZones)) // 获取区域(所有协议)
group.POST("/:id/ai/enable", web.WrapH(api.enableAI)) // 启用 AI 检测
group.POST("/:id/ai/disable", web.WrapH(api.disableAI)) // 禁用 AI 检测
}
}
@@ -310,33 +311,33 @@ func (a IPCAPI) play(c *gin.Context, _ *struct{}) (*playOutput, error) {
}
break
}
if a.uc.Conf.Server.AI.Disabled || a.uc.AIWebhookAPI.ai == nil {
return
}
// if a.uc.Conf.Server.AI.Disabled || a.uc.AIWebhookAPI.ai == nil {
// return
// }
if _, ok := a.uc.AIWebhookAPI.aiTasks.LoadOrStore(channelID, struct{}{}); !ok {
resp, err := a.uc.AIWebhookAPI.ai.StartCamera(context.Background(), &protos.StartCameraRequest{
CameraId: appStream,
CameraName: appStream,
RtspUrl: rtsp,
DetectFps: 5,
Labels: []string{"person", "car", "cat", "dog"},
Threshold: 0.65,
RetryLimit: 10,
CallbackUrl: fmt.Sprintf("http://127.0.0.1:%d/ai", a.uc.Conf.Server.HTTP.Port),
CallbackSecret: "Basic 1234567890",
})
if err != nil {
slog.Error("start camera", "err", err)
return
}
slog.Debug("start camera", "resp", resp,
"msg", resp.GetMessage(),
"source_width", resp.GetSourceWidth(),
"source_height", resp.GetSourceHeight(),
"source_fps", resp.GetSourceFps(),
)
}
// if _, ok := a.uc.AIWebhookAPI.aiTasks.LoadOrStore(channelID, struct{}{}); !ok {
// resp, err := a.uc.AIWebhookAPI.ai.StartCamera(context.Background(), &protos.StartCameraRequest{
// CameraId: appStream,
// CameraName: appStream,
// RtspUrl: rtsp,
// DetectFps: 5,
// Labels: []string{"person", "car", "cat", "dog"},
// Threshold: 0.65,
// RetryLimit: 10,
// CallbackUrl: fmt.Sprintf("http://127.0.0.1:%d/ai", a.uc.Conf.Server.HTTP.Port),
// CallbackSecret: "Basic 1234567890",
// })
// if err != nil {
// slog.Error("start camera", "err", err)
// return
// }
// slog.Debug("start camera", "resp", resp,
// "msg", resp.GetMessage(),
// "source_width", resp.GetSourceWidth(),
// "source_height", resp.GetSourceHeight(),
// "source_fps", resp.GetSourceFps(),
// )
// }
}()
return &out, nil
}
@@ -460,3 +461,115 @@ type IOWriter struct {
func (w IOWriter) Write(b []byte) (int, error) {
return w.fn(b)
}
var (
ErrAIGlobalDisabled = reason.NewError("ErrAIGlobalDisabled", "AI 功能已在全局配置中禁用")
ErrAIServiceNotReady = reason.NewError("ErrAIServiceNotReady", "AI 服务未初始化或连接失败")
ErrChannelNotSupported = reason.NewError("ErrChannelNotSupported", "不支持的通道类型")
)
// enableAI 启用指定通道的 AI 检测功能,需要先确保全局 AI 服务已启用且连接正常
func (a IPCAPI) enableAI(c *gin.Context, _ *struct{}) (gin.H, error) {
channelID := c.Param("id")
ctx := c.Request.Context()
// 检查全局 AI 配置
if a.uc.Conf.Server.AI.Disabled {
return nil, ErrAIGlobalDisabled
}
if a.uc.AIWebhookAPI.ai == nil {
return nil, ErrAIServiceNotReady
}
// 更新数据库中的 AI 启用状态
channel, err := a.ipc.SetAIEnabled(ctx, channelID, true)
if err != nil {
return nil, err
}
// 构建 RTSP 地址
rtspURL, err := a.buildRTSPURL(ctx, channelID)
if err != nil {
return nil, err
}
// 启动 AI 检测任务
resp, err := a.uc.AIWebhookAPI.StartAIDetection(ctx, channel, rtspURL)
if err != nil {
slog.ErrorContext(ctx, "start camera AI", "err", err)
return nil, reason.ErrUsedLogic.SetMsg("启动 AI 检测失败: " + err.Error())
}
return gin.H{
"enabled": true,
"message": resp.GetMessage(),
"source_width": resp.GetSourceWidth(),
"source_height": resp.GetSourceHeight(),
"source_fps": resp.GetSourceFps(),
}, nil
}
// disableAI 禁用指定通道的 AI 检测功能,会同时停止正在运行的检测任务
func (a IPCAPI) disableAI(c *gin.Context, _ *struct{}) (gin.H, error) {
channelID := c.Param("id")
ctx := c.Request.Context()
// 检查全局 AI 配置
if a.uc.Conf.Server.AI.Disabled {
return nil, ErrAIGlobalDisabled
}
if a.uc.AIWebhookAPI.ai == nil {
return nil, ErrAIServiceNotReady
}
// 更新数据库中的 AI 启用状态
if _, err := a.ipc.SetAIEnabled(ctx, channelID, false); err != nil {
return nil, err
}
// 停止 AI 检测任务
if err := a.uc.AIWebhookAPI.StopAIDetection(ctx, channelID); err != nil {
slog.ErrorContext(ctx, "stop camera AI", "err", err)
}
return gin.H{
"enabled": false,
"message": "AI 检测已停止",
}, nil
}
// buildRTSPURL 根据通道类型构建对应的 RTSP 播放地址
func (a IPCAPI) buildRTSPURL(ctx context.Context, channelID string) (string, error) {
svr, err := a.uc.SMSAPI.smsCore.GetMediaServer(ctx, sms.DefaultMediaServerID)
if err != nil {
return "", err
}
var app, stream string
if bz.IsGB28181(channelID) {
app = "rtp"
stream = channelID
} else if bz.IsOnvif(channelID) {
app = "live"
stream = channelID
} else if bz.IsRTSP(channelID) {
proxy, err := a.uc.ProxyAPI.proxyCore.GetStreamProxy(ctx, channelID)
if err != nil {
return "", err
}
app = "live"
stream = proxy.Stream
} else if bz.IsRTMP(channelID) {
pu, err := a.uc.MediaAPI.pushCore.GetStreamPush(ctx, channelID)
if err != nil {
return "", err
}
app = pu.App
stream = pu.Stream
} else {
return "", ErrChannelNotSupported
}
return fmt.Sprintf("rtsp://%s:%d/%s/%s", "127.0.0.1", svr.Ports.RTSP, app, stream), nil
}
+1 -1
View File
@@ -182,7 +182,7 @@ func (w WebHookAPI) onStreamNotFound(c *gin.Context, in *onStreamNotFoundInput)
protocol, ok := w.protocols[r]
if ok {
if err := protocol.OnStreamNotFound(c.Request.Context(), app, stream); err != nil {
slog.ErrorContext(c.Request.Context(), "webhook onStreamNotFound", "err", err)
slog.InfoContext(c.Request.Context(), "webhook onStreamNotFound", "err", err)
}
}
+40 -8
View File
@@ -3,6 +3,7 @@ package gbs
import (
"context"
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
@@ -11,7 +12,6 @@ import (
"time"
"github.com/gowvp/owl/internal/conf"
"github.com/gowvp/owl/internal/core/bz"
"github.com/gowvp/owl/internal/core/ipc"
"github.com/gowvp/owl/internal/core/sms"
"github.com/gowvp/owl/pkg/gbs/m"
@@ -87,7 +87,7 @@ func NewServer(cfg *conf.Bootstrap, store ipc.Adapter, sc sms.Core) (*Server, fu
return &c, c.Close
}
// startTickerCheck 定时检查离线
// startTickerCheck 定时检查离线,通过心跳超时判断设备是否离线
func (s *Server) startTickerCheck() {
conc.Timer(context.Background(), 60*time.Second, time.Second, func() {
now := time.Now()
@@ -95,19 +95,51 @@ func (s *Server) startTickerCheck() {
if !dev.IsOnline {
return true
}
if !bz.IsGB28181(key) {
if len(key) < 18 {
return true
}
timeout := time.Duration(dev.keepaliveTimeout) * time.Duration(dev.keepaliveInterval) * time.Second
if timeout <= 0 {
timeout = 3 * 60 * time.Second
// 计算超时时间:心跳间隔 * 超时次数
// 默认心跳间隔 60s,超时次数 3 次,即 3 分钟无心跳判定离线
interval := dev.keepaliveInterval
if interval == 0 {
interval = 60
}
timeoutCount := dev.keepaliveTimeout
if timeoutCount == 0 {
timeoutCount = 3
}
timeout := time.Duration(interval) * time.Duration(timeoutCount) * time.Second
// 跳过未收到过心跳的设备(LastKeepaliveAt 为零值),这类设备依赖注册超时处理
if dev.LastKeepaliveAt.IsZero() {
// 如果注册时间也超过了超时时间,则判定离线
if !dev.LastRegisterAt.IsZero() && now.Sub(dev.LastRegisterAt) >= timeout {
if err := s.gb.logout(key, func(d *ipc.Device) error {
d.IsOnline = false
return nil
}); err != nil {
slog.Error("logout device failed", "device_id", key, "err", err)
}
}
return true
}
// 心跳超时或连接丢失,判定设备离线
if sub := now.Sub(dev.LastKeepaliveAt); sub >= timeout || dev.conn == nil {
s.gb.logout(key, func(d *ipc.Device) error {
slog.Info("device offline detected",
"device_id", key,
"last_keepalive", dev.LastKeepaliveAt,
"timeout", timeout,
"elapsed", sub,
"conn_nil", dev.conn == nil,
)
if err := s.gb.logout(key, func(d *ipc.Device) error {
d.IsOnline = false
return nil
})
}); err != nil {
slog.Error("logout device failed", "device_id", key, "err", err)
}
}
return true
})