// Mirror of https://github.com/csznet/goForward.git
// (synced 2026-04-22 16:17:31 +08:00). Original file: 94 lines, 2.3 KiB, Go.

package main
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"log"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
|
|
"csz.net/goForward/conf"
|
|
"csz.net/goForward/forward"
|
|
"csz.net/goForward/sql"
|
|
"csz.net/goForward/web"
|
|
)
|
|
|
|
func main() {
|
|
go web.Run()
|
|
if conf.TcpTimeout < 5 {
|
|
conf.TcpTimeout = 5
|
|
}
|
|
|
|
forwardList := sql.GetAction()
|
|
if len(forwardList) == 0 {
|
|
// 添加测试数据
|
|
testData := conf.ConnectionStats{
|
|
LocalPort: conf.WebPort,
|
|
RemotePort: conf.WebPort,
|
|
RemoteAddr: "127.0.0.1",
|
|
Protocol: "udp",
|
|
}
|
|
sql.AddForward(testData)
|
|
forwardList = sql.GetForwardList()
|
|
}
|
|
var largeStats forward.LargeConnectionStats
|
|
largeStats.Connections = make([]*forward.ConnectionStats, len(forwardList))
|
|
for i := range forwardList {
|
|
connectionStats := &forward.ConnectionStats{
|
|
ConnectionStats: conf.ConnectionStats{
|
|
Id: forwardList[i].Id,
|
|
Protocol: forwardList[i].Protocol,
|
|
LocalPort: forwardList[i].LocalPort,
|
|
RemotePort: forwardList[i].RemotePort,
|
|
RemoteAddr: sql.NormalizeAddr(forwardList[i].RemoteAddr), // 兼容旧数据中可能存在的方括号
|
|
TotalBytes: forwardList[i].TotalBytes,
|
|
},
|
|
TotalBytesOld: forwardList[i].TotalBytes,
|
|
TotalBytesLock: sync.Mutex{},
|
|
}
|
|
largeStats.Connections[i] = connectionStats
|
|
}
|
|
|
|
// 监听系统退出信号,优雅关闭所有转发
|
|
go func() {
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
|
<-sigCh
|
|
log.Println("收到终止信号,正在关闭所有转发...")
|
|
conf.StopFuncs.Range(func(key, value interface{}) bool {
|
|
if cancelFn, ok := value.(context.CancelFunc); ok {
|
|
cancelFn()
|
|
}
|
|
return true
|
|
})
|
|
}()
|
|
|
|
// 设置 WaitGroup 计数为连接数
|
|
conf.Wg.Add(len(largeStats.Connections))
|
|
// 并发执行多个转发
|
|
for _, stats := range largeStats.Connections {
|
|
go func(s *forward.ConnectionStats) {
|
|
forward.Run(s)
|
|
conf.Wg.Done()
|
|
}(stats)
|
|
}
|
|
conf.Wg.Wait()
|
|
sql.Close()
|
|
}
|
|
|
|
func init() {
|
|
flag.StringVar(&conf.WebPort, "port", "8889", "Web Port")
|
|
flag.StringVar(&conf.Db, "db", "goForward.db", "Db Path")
|
|
flag.StringVar(&conf.WebIP, "ip", "0.0.0.0", "Web IP")
|
|
flag.StringVar(&conf.WebPass, "pass", "", "Web Password")
|
|
flag.IntVar(&conf.TcpTimeout, "tt", 60, "Tcp Timeout")
|
|
flag.Parse()
|
|
if !strings.HasSuffix(conf.Db, ".db") {
|
|
conf.Db += ".db"
|
|
}
|
|
sql.Once()
|
|
}
|