Implement OTA functionality for version checking

This commit is contained in:
xugo
2026-01-09 12:33:42 +08:00
parent 7153059e5b
commit 674e7dfb23
9 changed files with 619 additions and 8 deletions
+2 -1
View File
@@ -42,4 +42,5 @@ cover/
*.jpg
__pycache__/
*.pt
*.onnx
*.onnx
*.remember/
+1 -1
View File
@@ -59,7 +59,7 @@ def setup_logging(level_str: str = "INFO", retention_days: int = 3):
)
# 设置后缀格式,例如 app.log.2023-12-31
file_handler.suffix = "%Y-%m-%d"
file_handler.suffix = "%Y-%m-%d.log"
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)
+15 -5
View File
@@ -2,10 +2,18 @@
Debug = false
# rtmp 推流秘钥
RTMPSecret = '123'
# 登录用户名
Username = 'admin'
# 登录密码
Password = 'admin'
# ai 分析服务
[Server.AI]
# 是否禁用 ai 分析服务
Disabled = false
# 保留天数
RetainDays = 0
# 对外提供的服务,建议由 nginx 代理
[Server.HTTP]
# http 端口
@@ -13,7 +21,7 @@
# 请求超时时间
Timeout = '1m0s'
# jwt 秘钥,空串时,每次启动程序将随机赋值
JwtSecret = ''
JwtSecret = '6caOiETMs8SPWNHgEKA1Jhmn9wxpjAj9'
[Server.HTTP.PProf]
# 是否启用 pprof, 建议设置为 true
@@ -22,7 +30,7 @@
AccessIps = ['::1', '127.0.0.1']
[Data]
# 数据库支持 sqlite/postgres/mysql 使用 sqlite 时 dsn 应当填写文件存储路径
# 数据库支持 sqlite/postgres/mysql, 使用 sqlite 时 dsn 应当填写文件存储路径
# postgres://postgres:123456@127.0.0.1:5432/gb28181?sslmode=disable
# mysql://root:123456@127.0.0.1:5432/gb28181?sslmode=disable
[Data.Database]
@@ -61,9 +69,11 @@
HTTPPort = 8080
# 媒体服务器密钥
Secret = 'jvRqCAzEg7AszBi4gm1cfhwXpmnVmJMG'
# 媒体服务器类型 zlm/lalmax
Type = 'zlm'
# 用于流媒体 webhook 回调
WebHookIP = '192.168.10.10'
WebHookIP = '192.168.1.3'
# 媒体服务器 RTP 端口范围
RTPPortRange = '20000-20100'
# 媒体服务器 SDP IP
SDPIP = '192.168.10.10'
SDPIP = '192.168.1.3'
+1 -1
View File
@@ -3,7 +3,7 @@ services:
# 如果拉不到 docker hub 镜像,也可以尝试
# registry.cn-shanghai.aliyuncs.com/ixugo/homenvr:latest
image: gospace/gowvp:latest
restart: unless-stopped
restart: always
# linux 解开下行注释,并将 ports 全部注释
# network_mode: host
ports:
+106
View File
@@ -19,6 +19,7 @@ import (
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"
"github.com/gowvp/owl/internal/core/sms"
"github.com/gowvp/owl/pkg/ota"
"github.com/gowvp/owl/plugin/stat"
"github.com/gowvp/owl/plugin/stat/statapi"
"github.com/ixugo/goddd/domain/version/versionapi"
@@ -94,6 +95,8 @@ func setupRouter(r *gin.Engine, uc *Usecase) {
auth := web.AuthMiddleware(uc.Conf.Server.HTTP.JwtSecret)
r.GET("/health", web.WrapH(uc.getHealth))
r.GET("/app/metrics/api", web.WrapH(uc.getMetricsAPI))
r.GET("/app/version/check", web.WrapH(uc.checkVersion))
r.POST("/app/upgrade", auth, uc.upgradeApp)
versionapi.Register(r, uc.Version, auth)
statapi.Register(r)
@@ -202,6 +205,109 @@ func sortExpvarMap(data *expvar.Map, top int) []KV {
return kvs[:idx]
}
const repoName = "gowvp/owl"
type checkVersionOutput struct {
HasNewVersion bool `json:"has_new_version"`
CurrentVersion string `json:"current_version"`
NewVersion string `json:"new_version"`
Description string `json:"description"`
}
// checkVersion 检查是否有新版本
// 通过 GitHub API 获取最新 release 信息,与当前版本比较
func (uc *Usecase) checkVersion(_ *gin.Context, _ *struct{}) (checkVersionOutput, error) {
currentVersion := uc.Conf.BuildVersion
newVersion, body, err := ota.GetLastVersion(repoName)
if err != nil {
return checkVersionOutput{}, err
}
hasNew := compareVersion(currentVersion, newVersion) < 0
return checkVersionOutput{
HasNewVersion: hasNew,
CurrentVersion: currentVersion,
NewVersion: newVersion,
Description: body,
}, nil
}
// compareVersion 比较两个版本号
// 返回值: -1 表示 v1 < v2, 0 表示相等, 1 表示 v1 > v2
func compareVersion(v1, v2 string) int {
v1 = strings.TrimPrefix(v1, "v")
v2 = strings.TrimPrefix(v2, "v")
parts1 := strings.Split(v1, ".")
parts2 := strings.Split(v2, ".")
maxLen := len(parts1)
if len(parts2) > maxLen {
maxLen = len(parts2)
}
for i := 0; i < maxLen; i++ {
var n1, n2 int
if i < len(parts1) {
fmt.Sscanf(parts1[i], "%d", &n1)
}
if i < len(parts2) {
fmt.Sscanf(parts2[i], "%d", &n2)
}
if n1 < n2 {
return -1
}
if n1 > n2 {
return 1
}
}
return 0
}
// upgradeApp 执行应用升级
// 通过 SSE 返回下载进度,下载完成后由回调决定如何升级
func (uc *Usecase) upgradeApp(c *gin.Context) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"msg": "不支持 SSE"})
return
}
sendEvent := func(event, data string) {
fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", event, data)
flusher.Flush()
}
sendEvent("start", `{"msg":"开始下载升级包"}`)
filename := "linux_amd64"
if runtime.GOARCH == "arm64" {
filename = "linux_arm64"
}
o := ota.NewOTA(repoName, filename)
o.SetProgressCallback(func(current, total int64) {
percent := 0
if total > 0 {
percent = int(current * 100 / total)
}
sendEvent("progress", fmt.Sprintf(`{"current":%d,"total":%d,"percent":%d}`, current, total, percent))
})
if err := o.Download().Error(); err != nil {
sendEvent("error", fmt.Sprintf(`{"msg":"%s"}`, err.Error()))
return
}
sendEvent("complete", `{"msg":"下载完成,请手动重启服务"}`)
}
func (uc *Usecase) proxySMS(c *gin.Context) {
defer func() {
_ = recover()
+306
View File
@@ -0,0 +1,306 @@
package ota
const linuxTarPath = "upgrade.tar.gz"
// var _ Upgrader = &LinuxOTA{}
// type LinuxOTA struct {
// err error
// OnProgress func(current, total int64)
// }
// // Download implements Upgrader.
// func (l *LinuxOTA) Download(link string) Upgrader {
// if l.err != nil {
// return l
// }
// resp, err := http.Get(linuxPackage)
// if err != nil {
// l.err = err
// return l
// }
// defer resp.Body.Close()
// _ = os.RemoveAll(filepath.Join(system.Getwd(), linuxTarPath))
// f, err := os.OpenFile(filepath.Join(system.Getwd(), linuxTarPath), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
// if err != nil {
// l.err = err
// return l
// }
// defer f.Close()
// p := NewProgressReader(resp.ContentLength, resp.Body, l.OnProgress)
// defer p.Close()
// _, err = io.Copy(f, p)
// if err != nil {
// l.err = err
// }
// return l
// }
// // Unzip implements Upgrader.
// func (l *LinuxOTA) Unzip() Upgrader {
// if l.err != nil {
// return l
// }
// // 清理旧的升级目录
// upgradeDir := filepath.Join(system.Getwd(), "upgrade")
// _ = os.RemoveAll(upgradeDir)
// // 打开 tar.gz 文件
// file, err := os.Open(filepath.Join(system.Getwd(), linuxTarPath))
// if err != nil {
// l.err = err
// return l
// }
// defer file.Close()
// // 创建 gzip reader
// gzr, err := gzip.NewReader(file)
// if err != nil {
// l.err = err
// return l
// }
// defer gzr.Close()
// // 创建 tar reader
// tr := tar.NewReader(gzr)
// // 找到顶层目录名称
// var topLevelDir string
// for {
// header, err := tr.Next()
// if err == io.EOF {
// break
// }
// if err != nil {
// l.err = err
// return l
// }
// // 获取第一级目录名
// parts := strings.Split(header.Name, "/")
// if len(parts) > 0 && parts[0] != "" {
// topLevelDir = parts[0]
// break
// }
// }
// // 重新打开文件进行解压
// file.Close()
// file, err = os.Open(filepath.Join(system.Getwd(), linuxTarPath))
// if err != nil {
// l.err = err
// return l
// }
// defer file.Close()
// gzr, err = gzip.NewReader(file)
// if err != nil {
// l.err = err
// return l
// }
// defer gzr.Close()
// tr = tar.NewReader(gzr)
// // 解压所有文件
// for {
// header, err := tr.Next()
// if err == io.EOF {
// break
// }
// if err != nil {
// l.err = err
// return l
// }
// if err := l.extractFile(tr, header, upgradeDir, topLevelDir); err != nil {
// l.err = err
// return l
// }
// }
// return l
// }
// // Backup implements Upgrader.
// func (l *LinuxOTA) Backup() Upgrader {
// if l.err != nil {
// return l
// }
// execName := os.Args[0]
// backupName := execName + ".bak"
// if err := os.RemoveAll(backupName); err != nil {
// l.err = err
// return l
// }
// if err := os.Rename(execName, backupName); err != nil {
// l.err = err
// }
// return l
// }
// // Replace implements Upgrader.
// func (l *LinuxOTA) Replace() Upgrader {
// if l.err != nil {
// return l
// }
// upgradeDir := filepath.Join(system.Getwd(), "upgrade")
// currentDir := system.Getwd()
// // 获取当前可执行文件名
// execName := filepath.Base(os.Args[0])
// // 替换可执行文件
// newExecPath := filepath.Join(upgradeDir, execName)
// currentExecPath := filepath.Join(currentDir, execName)
// if _, err := os.Stat(newExecPath); err == nil {
// if err := l.copyFile(newExecPath, currentExecPath); err != nil {
// l.err = fmt.Errorf("替换可执行文件失败: %w", err)
// return l
// }
// // 设置可执行权限
// if err := os.Chmod(currentExecPath, 0o755); err != nil {
// l.err = fmt.Errorf("设置可执行权限失败: %w", err)
// return l
// }
// }
// // 替换 www 目录
// newWwwPath := filepath.Join(upgradeDir, "www")
// currentWwwPath := filepath.Join(currentDir, "www")
// if _, err := os.Stat(newWwwPath); err == nil {
// // 备份现有 www 目录
// backupWwwPath := filepath.Join(currentDir, "www.bak")
// _ = os.RemoveAll(backupWwwPath)
// if _, err := os.Stat(currentWwwPath); err == nil {
// if err := os.Rename(currentWwwPath, backupWwwPath); err != nil {
// l.err = fmt.Errorf("备份 www 目录失败: %w", err)
// return l
// }
// }
// // 复制新的 www 目录
// if err := l.copyDir(newWwwPath, currentWwwPath); err != nil {
// // 恢复备份
// _ = os.Rename(backupWwwPath, currentWwwPath)
// l.err = fmt.Errorf("替换 www 目录失败: %w", err)
// return l
// }
// // 删除备份
// _ = os.RemoveAll(backupWwwPath)
// }
// // 保留升级目录,下次升级的时候删除
// // _ = os.RemoveAll(upgradeDir)
// // 清理升级临时文件
// _ = os.RemoveAll(filepath.Join(currentDir, linuxTarPath))
// return l
// }
// // Error implements Upgrader.
// func (l *LinuxOTA) Error() error {
// return l.err
// }
// // extractFile 解压单个文件,跳过顶层目录
// func (l *LinuxOTA) extractFile(tr *tar.Reader, header *tar.Header, destDir, topLevelDir string) error {
// // 跳过顶层目录
// relativePath := header.Name
// if topLevelDir != "" && strings.HasPrefix(relativePath, topLevelDir+"/") {
// relativePath = strings.TrimPrefix(relativePath, topLevelDir+"/")
// }
// // 如果是顶层目录本身,跳过
// if relativePath == "" || relativePath == topLevelDir {
// return nil
// }
// target := filepath.Join(destDir, relativePath)
// switch header.Typeflag {
// case tar.TypeDir:
// if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil {
// return err
// }
// case tar.TypeReg:
// if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
// return err
// }
// f, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
// if err != nil {
// return err
// }
// defer f.Close()
// _, err = io.Copy(f, tr)
// if err != nil {
// return err
// }
// }
// return nil
// }
// // copyFile 复制文件
// func (l *LinuxOTA) copyFile(src, dst string) error {
// sourceFile, err := os.Open(src)
// if err != nil {
// return err
// }
// defer sourceFile.Close()
// destFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
// if err != nil {
// return err
// }
// defer destFile.Close()
// _, err = io.Copy(destFile, sourceFile)
// return err
// }
// // copyDir 递归复制目录
// func (l *LinuxOTA) copyDir(src, dst string) error {
// srcInfo, err := os.Stat(src)
// if err != nil {
// return err
// }
// if err := os.MkdirAll(dst, srcInfo.Mode()); err != nil {
// return err
// }
// entries, err := os.ReadDir(src)
// if err != nil {
// return err
// }
// for _, entry := range entries {
// srcPath := filepath.Join(src, entry.Name())
// dstPath := filepath.Join(dst, entry.Name())
// if entry.IsDir() {
// if err := l.copyDir(srcPath, dstPath); err != nil {
// return err
// }
// } else {
// if err := l.copyFile(srcPath, dstPath); err != nil {
// return err
// }
// }
// }
// return nil
// }
+122
View File
@@ -0,0 +1,122 @@
package ota
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
const (
linuxPackage = `/releases/latest/download/`
LastVersionURL = `https://api.github.com/repos/%s/releases/latest`
)
// ReleaseInfo GitHub Release 信息
type ReleaseInfo struct {
TagName string `json:"tag_name"`
Body string `json:"body"`
}
// OTA 提供版本检查和下载功能的结构体
// OTA 只负责下载,不关心后续的解压、备份、替换等操作
type OTA struct {
repoName string
filename string
err error
onProgress func(current, total int64)
}
// NewOTA 创建 OTA 实例
// repoName: GitHub 仓库名,如 "gowvp/owl",也支持 "github.com/gowvp/owl" 格式
// filename: 下载的文件名
func NewOTA(repoName, filename string) *OTA {
return &OTA{
repoName: cleanRepoName(repoName),
filename: filename,
}
}
// SetProgressCallback 设置下载进度回调
func (o *OTA) SetProgressCallback(callback func(current, total int64)) *OTA {
o.onProgress = callback
return o
}
// GetLastVersion 从 GitHub API 获取最新版本信息
// 返回 tag_name, body(release notes), error
func (o *OTA) GetLastVersion() (string, string, error) {
return GetLastVersion(o.repoName)
}
// Download 下载升级包到指定路径
func (o *OTA) Download() *OTA {
if o.err != nil {
return o
}
// link := o.getDownloadLink()
// linuxOTA := &LinuxOTA{OnProgress: o.onProgress}
// linuxOTA.Download(link)
// o.err = linuxOTA.Error()
return o
}
// Error 返回错误
func (o *OTA) Error() error {
return o.err
}
// getDownloadLink 获取下载链接
func (o *OTA) getDownloadLink() string {
repoLink := "https://github.com/" + o.repoName
link, _ := url.JoinPath(repoLink, linuxPackage, o.filename)
return link
}
// cleanRepoName 清理仓库名称,移除前缀
// 支持 "gowvp/owl"、"github.com/gowvp/owl" 等格式
func cleanRepoName(repoName string) string {
repoName = strings.TrimPrefix(repoName, "https://")
repoName = strings.TrimPrefix(repoName, "http://")
repoName = strings.TrimPrefix(repoName, "github.com/")
repoName = strings.TrimPrefix(repoName, "api.github.com/repos/")
return repoName
}
// GetLastVersion 从 GitHub API 获取最新版本信息
// repoName: GitHub 仓库名,如 "gowvp/owl"
// 返回 tag_name, body(release notes), error
func GetLastVersion(repoName string) (string, string, error) {
repoName = cleanRepoName(repoName)
apiURL := fmt.Sprintf(LastVersionURL, repoName)
client := &http.Client{
Timeout: 10 * time.Second,
}
req, err := http.NewRequest(http.MethodGet, apiURL, nil)
if err != nil {
return "", "", fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Accept", "application/vnd.github+json")
resp, err := client.Do(req)
if err != nil {
return "", "", fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", "", fmt.Errorf("请求失败,状态码: %d", resp.StatusCode)
}
var release ReleaseInfo
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return "", "", fmt.Errorf("解析响应失败: %w", err)
}
return release.TagName, release.Body, nil
}
+12
View File
@@ -0,0 +1,12 @@
package ota
import "testing"
func TestGetLastVersion(t *testing.T) {
version, desc, err := GetLastVersion("gowvp/owl")
if err != nil {
t.Fatalf("GetLastVersion() error = %v", err)
}
t.Logf("version = %s", version)
t.Logf("desc = %s", desc)
}
+54
View File
@@ -0,0 +1,54 @@
package ota
import (
"io"
"sync/atomic"
"time"
)
type ProgressReader struct {
Total int64
Current atomic.Int64
io.Reader
OnProgress func(current, total int64)
quit chan struct{}
}
func NewProgressReader(total int64, reader io.Reader, onProgress func(current, total int64)) *ProgressReader {
p := ProgressReader{
Total: total,
Reader: reader,
OnProgress: onProgress,
quit: make(chan struct{}, 1),
}
if onProgress != nil {
go p.Start()
}
return &p
}
func (p *ProgressReader) Close() {
close(p.quit)
}
func (p *ProgressReader) Start() {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
p.OnProgress(p.Current.Load(), p.Total)
case <-p.quit:
p.OnProgress(p.Current.Load(), p.Total)
return
}
}
}
func (p *ProgressReader) Read(b []byte) (int, error) {
n, err := p.Reader.Read(b)
p.Current.Add(int64(n))
return n, err
}