mirror of
https://github.com/csznet/goForward.git
synced 2026-04-22 16:17:31 +08:00
重构转发管理,优化停止函数存储与信号处理,改进日志输出
This commit is contained in:
+2
-2
@@ -27,8 +27,8 @@ type IpBan struct {
|
||||
// 全局转发协程等待组
|
||||
var Wg sync.WaitGroup
|
||||
|
||||
// 全局协程通道 未初始化默认为nil
|
||||
var Ch chan string
|
||||
// StopFuncs 存储每个转发的停止函数,key 为 localPort+protocol
|
||||
var StopFuncs sync.Map
|
||||
|
||||
// Web管理面板端口
|
||||
var WebPort string
|
||||
|
||||
+160
-146
@@ -3,7 +3,6 @@ package forward
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@@ -15,12 +14,19 @@ import (
|
||||
"csz.net/goForward/sql"
|
||||
)
|
||||
|
||||
const (
|
||||
bufSize = 4096
|
||||
statsInterval = 10 * time.Second
|
||||
gbSize = uint64(1 << 30) // 1 GiB
|
||||
tcpTimeStep = 5 // TcpTime 每个统计周期增加的值
|
||||
)
|
||||
|
||||
type ConnectionStats struct {
|
||||
conf.ConnectionStats
|
||||
TotalBytesOld uint64 `gorm:"-"`
|
||||
TotalBytesLock sync.Mutex `gorm:"-"`
|
||||
TCPConnections []net.Conn `gorm:"-"` // 用于存储 TCP 连接
|
||||
TcpTime int `gorm:"-"` // TCP无传输时间
|
||||
TotalBytesLock sync.Mutex `gorm:"-"` // 保护 TotalBytes 相关字段及 TCPConnections
|
||||
TCPConnections []net.Conn `gorm:"-"`
|
||||
TcpTime int `gorm:"-"`
|
||||
}
|
||||
|
||||
// 保存多个连接信息
|
||||
@@ -31,144 +37,144 @@ type LargeConnectionStats struct {
|
||||
// 复用缓冲区
|
||||
var bufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 4096)
|
||||
return make([]byte, bufSize)
|
||||
},
|
||||
}
|
||||
|
||||
// 开启转发,负责分发具体转发
|
||||
func Run(stats *ConnectionStats) {
|
||||
defer releaseResources(stats) // 在函数返回时释放资源
|
||||
var ctx, cancel = context.WithCancel(context.Background())
|
||||
var innerWg sync.WaitGroup
|
||||
defer releaseResources(stats)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// 注册停止函数,外部通过 conf.StopFuncs 触发停止
|
||||
key := stats.LocalPort + stats.Protocol
|
||||
conf.StopFuncs.Store(key, cancel)
|
||||
defer conf.StopFuncs.Delete(key)
|
||||
|
||||
var innerWg sync.WaitGroup
|
||||
innerWg.Add(1)
|
||||
go func() {
|
||||
defer innerWg.Done()
|
||||
stats.printStats(ctx)
|
||||
innerWg.Done()
|
||||
}()
|
||||
fmt.Printf("【%s】监听端口 %s 转发至 %s:%s\n", stats.Protocol, stats.LocalPort, stats.RemoteAddr, stats.RemotePort)
|
||||
|
||||
log.Printf("【%s】监听端口 %s 转发至 %s:%s", stats.Protocol, stats.LocalPort, stats.RemoteAddr, stats.RemotePort)
|
||||
|
||||
if stats.Protocol == "udp" {
|
||||
// UDP转发
|
||||
localAddr, err := net.ResolveUDPAddr("udp", ":"+stats.LocalPort)
|
||||
if err != nil {
|
||||
log.Fatalln("解析本地地址时发生错误:", err)
|
||||
log.Println("解析本地地址时发生错误:", err)
|
||||
return
|
||||
}
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", stats.RemoteAddr+":"+stats.RemotePort)
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(stats.RemoteAddr, stats.RemotePort))
|
||||
if err != nil {
|
||||
log.Fatalln("解析远程地址时发生错误:", err)
|
||||
log.Println("解析远程地址时发生错误:", err)
|
||||
return
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
log.Fatalln("监听时发生错误:", err)
|
||||
log.Println("监听时发生错误:", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// context 取消时关闭 UDP 连接
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case stopPort := <-conf.Ch:
|
||||
if stopPort == stats.LocalPort+stats.Protocol {
|
||||
fmt.Printf("【%s】停止监听端口 %s\n", stats.Protocol, stats.LocalPort)
|
||||
conn.Close()
|
||||
cancel()
|
||||
return
|
||||
} else {
|
||||
conf.Ch <- stopPort
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
default:
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
innerWg.Add(1)
|
||||
go func() {
|
||||
defer innerWg.Done()
|
||||
stats.handleUDPConnection(conn, remoteAddr, ctx)
|
||||
innerWg.Done()
|
||||
}()
|
||||
} else {
|
||||
// TCP转发
|
||||
listener, err := net.Listen("tcp", ":"+stats.LocalPort)
|
||||
if err != nil {
|
||||
log.Fatalln("监听时发生错误:", err)
|
||||
log.Println("监听时发生错误:", err)
|
||||
return
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// context 取消时关闭监听器及所有活跃连接
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case stopPort := <-conf.Ch:
|
||||
fmt.Println("通道信息:" + stopPort)
|
||||
fmt.Println("当前端口:" + stats.LocalPort)
|
||||
if stopPort == stats.LocalPort+stats.Protocol {
|
||||
fmt.Printf("【%s】停止监听端口 %s\n", stats.Protocol, stats.LocalPort)
|
||||
listener.Close()
|
||||
cancel()
|
||||
// 遍历并关闭所有 TCP 连接
|
||||
for _, conn := range stats.TCPConnections {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
} else {
|
||||
conf.Ch <- stopPort
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
default:
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
<-ctx.Done()
|
||||
listener.Close()
|
||||
// 加锁复制连接列表,释放锁后再关闭,避免长时间持锁
|
||||
stats.TotalBytesLock.Lock()
|
||||
conns := make([]net.Conn, len(stats.TCPConnections))
|
||||
copy(conns, stats.TCPConnections)
|
||||
stats.TotalBytesLock.Unlock()
|
||||
for _, c := range conns {
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
clientConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println("【"+stats.LocalPort+"】接受连接时发生错误:", err)
|
||||
cancel()
|
||||
log.Printf("【%s】停止接受新连接", stats.LocalPort)
|
||||
break
|
||||
}
|
||||
innerWg.Add(1)
|
||||
go func() {
|
||||
stats.handleTCPConnection(clientConn, ctx, cancel)
|
||||
innerWg.Done()
|
||||
}()
|
||||
go func(conn net.Conn) {
|
||||
defer innerWg.Done()
|
||||
stats.handleTCPConnection(conn)
|
||||
}(clientConn)
|
||||
}
|
||||
}
|
||||
innerWg.Wait()
|
||||
}
|
||||
|
||||
// TCP转发
|
||||
func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn, ctx context.Context, cancel context.CancelFunc) {
|
||||
func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn) {
|
||||
defer clientConn.Close()
|
||||
remoteConn, err := net.Dial("tcp", cs.RemoteAddr+":"+cs.RemotePort)
|
||||
remoteConn, err := net.Dial("tcp", net.JoinHostPort(cs.RemoteAddr, cs.RemotePort))
|
||||
if err != nil {
|
||||
log.Println("【"+cs.LocalPort+"】连接远程地址时发生错误:", err)
|
||||
log.Printf("【%s】连接远程地址时发生错误: %v", cs.LocalPort, err)
|
||||
return
|
||||
}
|
||||
defer remoteConn.Close()
|
||||
cs.TCPConnections = append(cs.TCPConnections, clientConn, remoteConn) // 添加连接到列表
|
||||
|
||||
// 加锁后添加连接,避免竞态
|
||||
cs.TotalBytesLock.Lock()
|
||||
cs.TCPConnections = append(cs.TCPConnections, clientConn, remoteConn)
|
||||
cs.TotalBytesLock.Unlock()
|
||||
|
||||
var copyWG sync.WaitGroup
|
||||
copyWG.Add(2)
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
if err := cs.copyBytes(clientConn, remoteConn); err != nil {
|
||||
log.Println("复制字节时发生错误:", err)
|
||||
cancel() // Assuming `cancel` is the cancel function from the context
|
||||
}
|
||||
cs.copyBytes(clientConn, remoteConn)
|
||||
}()
|
||||
go func() {
|
||||
defer copyWG.Done()
|
||||
if err := cs.copyBytes(remoteConn, clientConn); err != nil {
|
||||
log.Println("复制字节时发生错误:", err)
|
||||
cancel() // Assuming `cancel` is the cancel function from the context
|
||||
}
|
||||
cs.copyBytes(remoteConn, clientConn)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 如果上级 context 被取消,停止接收新连接
|
||||
return
|
||||
default:
|
||||
copyWG.Wait()
|
||||
return
|
||||
copyWG.Wait()
|
||||
|
||||
// 连接结束后从跟踪列表中移除
|
||||
cs.TotalBytesLock.Lock()
|
||||
cs.removeConns(clientConn, remoteConn)
|
||||
cs.TotalBytesLock.Unlock()
|
||||
}
|
||||
|
||||
// removeConns 从 TCPConnections 中移除指定连接。调用时须持有 TotalBytesLock。
|
||||
func (cs *ConnectionStats) removeConns(conns ...net.Conn) {
|
||||
set := make(map[net.Conn]struct{}, len(conns))
|
||||
for _, c := range conns {
|
||||
set[c] = struct{}{}
|
||||
}
|
||||
result := cs.TCPConnections[:0]
|
||||
for _, c := range cs.TCPConnections {
|
||||
if _, found := set[c]; !found {
|
||||
result = append(result, c)
|
||||
}
|
||||
}
|
||||
cs.TCPConnections = result
|
||||
}
|
||||
|
||||
// UDP转发
|
||||
@@ -178,41 +184,46 @@ func (cs *ConnectionStats) handleUDPConnection(localConn *net.UDPConn, remoteAdd
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
buf := bufPool.Get().([]byte)
|
||||
n, _, err := localConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
log.Println("【"+cs.LocalPort+"】从源读取时发生错误:", err)
|
||||
buf, ok := bufPool.Get().([]byte)
|
||||
if !ok {
|
||||
log.Println("缓冲区类型断言失败")
|
||||
return
|
||||
}
|
||||
n, _, err := localConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
bufPool.Put(buf)
|
||||
log.Printf("【%s】从源读取时发生错误: %v", cs.LocalPort, err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("收到长度为 %d 的UDP数据包\n", n)
|
||||
cs.TotalBytesLock.Lock()
|
||||
cs.TotalBytes += uint64(n)
|
||||
cs.TotalBytesLock.Unlock()
|
||||
|
||||
// 处理消息的边界和错误情况
|
||||
go func() {
|
||||
cs.forwardUDPMessage(localConn, remoteAddr, buf[:n])
|
||||
bufPool.Put(buf)
|
||||
}()
|
||||
// WriteToUDP 写内核缓冲区后立即返回,直接调用无需额外 goroutine,
|
||||
// 避免高流量时频繁 goroutine 创建和 GC 带来的 CPU 压力
|
||||
data := make([]byte, n+2)
|
||||
binary.BigEndian.PutUint16(data[:2], uint16(n))
|
||||
copy(data[2:], buf[:n])
|
||||
bufPool.Put(buf)
|
||||
cs.forwardUDPMessage(localConn, remoteAddr, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *ConnectionStats) forwardUDPMessage(localConn *net.UDPConn, remoteAddr *net.UDPAddr, message []byte) {
|
||||
// 在消息前面添加消息长度信息
|
||||
length := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(length, uint16(len(message)))
|
||||
// 组合消息长度和实际消息
|
||||
data := append(length, message...)
|
||||
_, err := localConn.WriteToUDP(data, remoteAddr)
|
||||
if err != nil {
|
||||
log.Println("【"+cs.LocalPort+"】写入目标时发生错误:", err)
|
||||
func (cs *ConnectionStats) forwardUDPMessage(localConn *net.UDPConn, remoteAddr *net.UDPAddr, data []byte) {
|
||||
if _, err := localConn.WriteToUDP(data, remoteAddr); err != nil {
|
||||
log.Printf("【%s】写入目标时发生错误: %v", cs.LocalPort, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (cs *ConnectionStats) copyBytes(dst, src net.Conn) error {
|
||||
buf := bufPool.Get().([]byte)
|
||||
func (cs *ConnectionStats) copyBytes(dst, src net.Conn) {
|
||||
buf, ok := bufPool.Get().([]byte)
|
||||
if !ok {
|
||||
log.Println("缓冲区类型断言失败")
|
||||
dst.Close()
|
||||
src.Close()
|
||||
return
|
||||
}
|
||||
defer bufPool.Put(buf)
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
@@ -220,88 +231,91 @@ func (cs *ConnectionStats) copyBytes(dst, src net.Conn) error {
|
||||
cs.TotalBytesLock.Lock()
|
||||
cs.TotalBytes += uint64(n)
|
||||
cs.TotalBytesLock.Unlock()
|
||||
_, err := dst.Write(buf[:n])
|
||||
if err != nil {
|
||||
log.Println("【"+cs.LocalPort+"】写入目标时发生错误:", err)
|
||||
return err
|
||||
if _, werr := dst.Write(buf[:n]); werr != nil {
|
||||
log.Printf("【%s】写入目标时发生错误: %v", cs.LocalPort, werr)
|
||||
break
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Println("【"+cs.LocalPort+"】从源读取时发生错误:", err)
|
||||
if err != io.EOF {
|
||||
log.Printf("【%s】从源读取时发生错误: %v", cs.LocalPort, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
// 关闭连接
|
||||
// 关闭双向连接,通知对端 goroutine 退出
|
||||
dst.Close()
|
||||
src.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 定时打印和处理流量变化
|
||||
func (cs *ConnectionStats) printStats(ctx context.Context) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop() // 在函数结束时停止定时器
|
||||
ticker := time.NewTicker(statsInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 在锁内只做内存状态更新,捕获需要的值后立即释放
|
||||
cs.TotalBytesLock.Lock()
|
||||
if cs.TotalBytes > cs.TotalBytesOld {
|
||||
hasTraffic := cs.TotalBytes > cs.TotalBytesOld
|
||||
var totalStr string
|
||||
var bytesToSave, gbToSave uint64
|
||||
var needUpdateGb bool
|
||||
var connCount int
|
||||
|
||||
if hasTraffic {
|
||||
if cs.Protocol == "tcp" {
|
||||
cs.TcpTime = 0
|
||||
}
|
||||
var total string
|
||||
if cs.TotalBytes > 0 && float64(cs.TotalBytes)/(1024*1024) < 0.5 {
|
||||
total = strconv.FormatFloat(float64(cs.TotalBytes)/(1024), 'f', 2, 64) + "KB"
|
||||
if float64(cs.TotalBytes)/(1024*1024) < 0.5 {
|
||||
totalStr = strconv.FormatFloat(float64(cs.TotalBytes)/1024, 'f', 2, 64) + "KB"
|
||||
} else {
|
||||
total = strconv.FormatFloat(float64(cs.TotalBytes)/(1024*1024), 'f', 2, 64) + "MB"
|
||||
totalStr = strconv.FormatFloat(float64(cs.TotalBytes)/(1024*1024), 'f', 2, 64) + "MB"
|
||||
}
|
||||
fmt.Printf("【%s】端口 %s 统计流量: %s\n", cs.Protocol, cs.LocalPort, total)
|
||||
//统计更换单位
|
||||
var gb uint64 = 1073741824
|
||||
if cs.TotalBytes >= gb {
|
||||
cs.TotalGigabyte = cs.TotalGigabyte + 1
|
||||
sql.UpdateForwardGb(cs.Id, cs.TotalGigabyte)
|
||||
cs.TotalBytes = cs.TotalBytes - gb
|
||||
if cs.TotalBytes >= gbSize {
|
||||
cs.TotalGigabyte++
|
||||
cs.TotalBytes -= gbSize
|
||||
needUpdateGb = true
|
||||
gbToSave = cs.TotalGigabyte
|
||||
}
|
||||
cs.TotalBytesOld = cs.TotalBytes
|
||||
sql.UpdateForwardBytes(cs.Id, cs.TotalBytes)
|
||||
fmt.Printf("【%s】端口 %s 当前连接数: %d\n", cs.Protocol, cs.LocalPort, len(cs.TCPConnections))
|
||||
} else {
|
||||
if cs.Protocol == "tcp" {
|
||||
// fmt.Printf("【%s】端口 %s 当前超时秒: %d\n", cs.Protocol, cs.LocalPort, cs.TcpTime)
|
||||
if cs.TcpTime >= conf.TcpTimeout {
|
||||
// fmt.Printf("【%s】端口 %s 超时关闭\n", cs.Protocol, cs.LocalPort)
|
||||
for i := len(cs.TCPConnections) - 1; i >= 0; i-- {
|
||||
conn := cs.TCPConnections[i]
|
||||
conn.Close()
|
||||
// 从连接列表中移除关闭的连接
|
||||
cs.TCPConnections = append(cs.TCPConnections[:i], cs.TCPConnections[i+1:]...)
|
||||
}
|
||||
} else {
|
||||
cs.TcpTime = cs.TcpTime + 5
|
||||
bytesToSave = cs.TotalBytes
|
||||
connCount = len(cs.TCPConnections)
|
||||
} else if cs.Protocol == "tcp" {
|
||||
if cs.TcpTime >= conf.TcpTimeout {
|
||||
for i := len(cs.TCPConnections) - 1; i >= 0; i-- {
|
||||
cs.TCPConnections[i].Close()
|
||||
}
|
||||
cs.TCPConnections = nil
|
||||
} else {
|
||||
cs.TcpTime += tcpTimeStep
|
||||
}
|
||||
}
|
||||
cs.TotalBytesLock.Unlock()
|
||||
//当协程退出时执行
|
||||
|
||||
// 锁外执行日志输出和 SQL 写入,避免阻塞 copyBytes 的流量统计
|
||||
if hasTraffic {
|
||||
log.Printf("【%s】端口 %s 统计流量: %s", cs.Protocol, cs.LocalPort, totalStr)
|
||||
if needUpdateGb {
|
||||
sql.UpdateForwardGb(cs.Id, gbToSave)
|
||||
}
|
||||
sql.UpdateForwardBytes(cs.Id, bytesToSave)
|
||||
log.Printf("【%s】端口 %s 当前连接数: %d", cs.Protocol, cs.LocalPort, connCount)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭 TCP 连接并从切片中移除
|
||||
// 关闭所有 TCP 连接并清空切片
|
||||
func closeTCPConnections(stats *ConnectionStats) {
|
||||
stats.TotalBytesLock.Lock()
|
||||
defer stats.TotalBytesLock.Unlock()
|
||||
for i, conn := range stats.TCPConnections {
|
||||
for _, conn := range stats.TCPConnections {
|
||||
conn.Close()
|
||||
stats.TCPConnections[i] = nil
|
||||
}
|
||||
stats.TCPConnections = nil // 清空切片
|
||||
stats.TCPConnections = nil
|
||||
}
|
||||
|
||||
// 释放资源
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"csz.net/goForward/conf"
|
||||
"csz.net/goForward/forward"
|
||||
@@ -16,11 +21,10 @@ func main() {
|
||||
if conf.TcpTimeout < 5 {
|
||||
conf.TcpTimeout = 5
|
||||
}
|
||||
// 初始化通道
|
||||
conf.Ch = make(chan string)
|
||||
|
||||
forwardList := sql.GetAction()
|
||||
if len(forwardList) == 0 {
|
||||
//添加测试数据
|
||||
// 添加测试数据
|
||||
testData := conf.ConnectionStats{
|
||||
LocalPort: conf.WebPort,
|
||||
RemotePort: conf.WebPort,
|
||||
@@ -39,15 +43,29 @@ func main() {
|
||||
Protocol: forwardList[i].Protocol,
|
||||
LocalPort: forwardList[i].LocalPort,
|
||||
RemotePort: forwardList[i].RemotePort,
|
||||
RemoteAddr: forwardList[i].RemoteAddr,
|
||||
RemoteAddr: sql.NormalizeAddr(forwardList[i].RemoteAddr), // 兼容旧数据中可能存在的方括号
|
||||
TotalBytes: forwardList[i].TotalBytes,
|
||||
},
|
||||
TotalBytesOld: forwardList[i].TotalBytes,
|
||||
TotalBytesLock: sync.Mutex{},
|
||||
}
|
||||
|
||||
largeStats.Connections[i] = connectionStats
|
||||
}
|
||||
|
||||
// 监听系统退出信号,优雅关闭所有转发
|
||||
go func() {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
||||
<-sigCh
|
||||
log.Println("收到终止信号,正在关闭所有转发...")
|
||||
conf.StopFuncs.Range(func(key, value interface{}) bool {
|
||||
if cancelFn, ok := value.(context.CancelFunc); ok {
|
||||
cancelFn()
|
||||
}
|
||||
return true
|
||||
})
|
||||
}()
|
||||
|
||||
// 设置 WaitGroup 计数为连接数
|
||||
conf.Wg.Add(len(largeStats.Connections))
|
||||
// 并发执行多个转发
|
||||
@@ -58,8 +76,9 @@ func main() {
|
||||
}(stats)
|
||||
}
|
||||
conf.Wg.Wait()
|
||||
defer close(conf.Ch)
|
||||
sql.Close()
|
||||
}
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&conf.WebPort, "port", "8889", "Web Port")
|
||||
flag.StringVar(&conf.Db, "db", "goForward.db", "Db Path")
|
||||
|
||||
+38
-21
@@ -1,7 +1,6 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -31,16 +30,30 @@ func Once() {
|
||||
} else {
|
||||
dbPath = conf.Db
|
||||
}
|
||||
fmt.Println("Data:", dbPath)
|
||||
log.Printf("Data: %s", dbPath)
|
||||
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||||
if err != nil {
|
||||
log.Println("连接数据库失败")
|
||||
return
|
||||
log.Fatalln("连接数据库失败:", err)
|
||||
}
|
||||
db.AutoMigrate(&conf.ConnectionStats{})
|
||||
db.AutoMigrate(&conf.IpBan{})
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func Close() {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
log.Println("获取数据库连接失败:", err)
|
||||
return
|
||||
}
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
log.Println("关闭数据库连接失败:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取转发列表
|
||||
func GetForwardList() []conf.ConnectionStats {
|
||||
var res []conf.ConnectionStats
|
||||
@@ -66,17 +79,17 @@ func GetIpBan() []conf.IpBan {
|
||||
func UpdateForwardBytes(id int, bytes uint64) bool {
|
||||
res := db.Model(&conf.ConnectionStats{}).Where("id = ?", id).Update("total_bytes", bytes)
|
||||
if res.Error != nil {
|
||||
fmt.Println(res.Error)
|
||||
log.Println(res.Error)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 修改指定转发统计流量(byte)
|
||||
// 修改指定转发统计流量(gigabyte)
|
||||
func UpdateForwardGb(id int, gb uint64) bool {
|
||||
res := db.Model(&conf.ConnectionStats{}).Where("id = ?", id).Update("total_gigabyte", gb)
|
||||
if res.Error != nil {
|
||||
fmt.Println(res.Error)
|
||||
log.Println(res.Error)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -86,7 +99,7 @@ func UpdateForwardGb(id int, gb uint64) bool {
|
||||
func UpdateForwardStatus(id int, status int) bool {
|
||||
res := db.Model(&conf.ConnectionStats{}).Where("id = ?", id).Update("status", status)
|
||||
if res.Error != nil {
|
||||
fmt.Println(res.Error)
|
||||
log.Println(res.Error)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -118,10 +131,20 @@ func rmSpaces(input string) string {
|
||||
return strings.ReplaceAll(input, " ", "")
|
||||
}
|
||||
|
||||
// NormalizeAddr 去除空格,并去掉 IPv6 地址外层的方括号,统一以裸地址形式存储。
|
||||
// 连接时配合 net.JoinHostPort 使用,后者会自动为 IPv6 补回括号。
|
||||
func NormalizeAddr(addr string) string {
|
||||
addr = rmSpaces(addr)
|
||||
if strings.HasPrefix(addr, "[") && strings.HasSuffix(addr, "]") {
|
||||
addr = addr[1 : len(addr)-1]
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// 增加转发
|
||||
func AddForward(newForward conf.ConnectionStats) int {
|
||||
//预处理
|
||||
newForward.RemoteAddr = rmSpaces(newForward.RemoteAddr)
|
||||
// 预处理:去空格,并标准化 RemoteAddr(去掉 IPv6 方括号)
|
||||
newForward.RemoteAddr = NormalizeAddr(newForward.RemoteAddr)
|
||||
newForward.RemotePort = rmSpaces(newForward.RemotePort)
|
||||
newForward.LocalPort = rmSpaces(newForward.LocalPort)
|
||||
if newForward.Protocol != "udp" {
|
||||
@@ -130,7 +153,7 @@ func AddForward(newForward conf.ConnectionStats) int {
|
||||
if !FreeForward(newForward.LocalPort, newForward.Protocol) {
|
||||
return 0
|
||||
}
|
||||
//开启事务
|
||||
// 开启事务
|
||||
tx := db.Begin()
|
||||
if tx.Error != nil {
|
||||
log.Println("开启事务失败")
|
||||
@@ -138,9 +161,8 @@ func AddForward(newForward conf.ConnectionStats) int {
|
||||
}
|
||||
// 在事务中执行插入操作
|
||||
if err := tx.Create(&newForward).Error; err != nil {
|
||||
log.Println("插入新转发失败")
|
||||
log.Println(err)
|
||||
tx.Rollback() // 回滚事务
|
||||
log.Println("插入新转发失败:", err)
|
||||
tx.Rollback()
|
||||
return 0
|
||||
}
|
||||
// 提交事务
|
||||
@@ -159,7 +181,7 @@ func DelForward(id int) bool {
|
||||
|
||||
// 增加错误登录
|
||||
func AddBan(ip conf.IpBan) bool {
|
||||
//开启事务
|
||||
// 开启事务
|
||||
tx := db.Begin()
|
||||
if tx.Error != nil {
|
||||
return false
|
||||
@@ -167,7 +189,7 @@ func AddBan(ip conf.IpBan) bool {
|
||||
// 在事务中执行插入操作
|
||||
if err := tx.Create(&ip).Error; err != nil {
|
||||
log.Println(err)
|
||||
tx.Rollback() // 回滚事务
|
||||
tx.Rollback()
|
||||
return false
|
||||
}
|
||||
// 提交事务
|
||||
@@ -177,16 +199,11 @@ func AddBan(ip conf.IpBan) bool {
|
||||
|
||||
// 检查过去一天内指定IP地址的记录条数是否超过三条
|
||||
func IpFree(ip string) bool {
|
||||
// 获取过去一天的时间戳
|
||||
oneDayAgo := time.Now().Add(-24 * time.Hour).Unix()
|
||||
|
||||
// 查询过去一天内指定IP地址的记录条数
|
||||
var count int64
|
||||
if err := db.Model(&conf.IpBan{}).Where("ip = ? AND time_stamp > ?", ip, oneDayAgo).Count(&count).Error; err != nil {
|
||||
log.Println(err)
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果记录条数超过三条,则返回false;否则返回true
|
||||
return count < 3
|
||||
}
|
||||
|
||||
+12
-5
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"csz.net/goForward/conf"
|
||||
@@ -20,7 +21,7 @@ func AddForward(newF conf.ConnectionStats) bool {
|
||||
Id: id,
|
||||
LocalPort: newF.LocalPort,
|
||||
RemotePort: newF.RemotePort,
|
||||
RemoteAddr: newF.RemoteAddr,
|
||||
RemoteAddr: sql.NormalizeAddr(newF.RemoteAddr), // 确保裸 IPv6(无方括号)
|
||||
Protocol: newF.Protocol,
|
||||
TotalBytes: 0,
|
||||
},
|
||||
@@ -40,7 +41,10 @@ func AddForward(newF conf.ConnectionStats) bool {
|
||||
// 删除并关闭指定转发
|
||||
func DelForward(f conf.ConnectionStats) bool {
|
||||
sql.DelForward(f.Id)
|
||||
conf.Ch <- f.LocalPort + f.Protocol
|
||||
key := f.LocalPort + f.Protocol
|
||||
if cancelFn, ok := conf.StopFuncs.Load(key); ok {
|
||||
cancelFn.(context.CancelFunc)()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -57,7 +61,7 @@ func ExStatus(f conf.ConnectionStats) bool {
|
||||
Id: f.Id,
|
||||
LocalPort: f.LocalPort,
|
||||
RemotePort: f.RemotePort,
|
||||
RemoteAddr: f.RemoteAddr,
|
||||
RemoteAddr: sql.NormalizeAddr(f.RemoteAddr), // 确保裸 IPv6(无方括号)
|
||||
Protocol: f.Protocol,
|
||||
TotalBytes: f.TotalBytes,
|
||||
},
|
||||
@@ -71,10 +75,13 @@ func ExStatus(f conf.ConnectionStats) bool {
|
||||
}()
|
||||
return true
|
||||
} else {
|
||||
conf.Ch <- f.LocalPort + f.Protocol
|
||||
// 停止转发
|
||||
key := f.LocalPort + f.Protocol
|
||||
if cancelFn, ok := conf.StopFuncs.Load(key); ok {
|
||||
cancelFn.(context.CancelFunc)()
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
+12
-3
@@ -1,7 +1,7 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"crypto/rand"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -17,10 +17,19 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// generateSecretKey 生成随机 session 密钥,避免使用硬编码密钥
|
||||
func generateSecretKey() []byte {
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
log.Panicln("生成session密钥失败:", err)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func Run() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.Default()
|
||||
store := cookie.NewStore([]byte("secret"))
|
||||
store := cookie.NewStore(generateSecretKey())
|
||||
r.Use(sessions.Sessions("goForward", store))
|
||||
r.Use(checkCookieMiddleware)
|
||||
r.SetHTMLTemplate(template.Must(template.New("").Funcs(r.FuncMap).ParseFS(assets.Templates, "templates/*")))
|
||||
@@ -148,7 +157,7 @@ func Run() {
|
||||
}
|
||||
c.Redirect(302, "/")
|
||||
})
|
||||
fmt.Println("Web管理面板端口:" + conf.WebPort)
|
||||
log.Printf("Web管理面板端口: %s", conf.WebPort)
|
||||
err := r.Run(conf.WebIP + ":" + conf.WebPort)
|
||||
if err != nil {
|
||||
log.Panicln(err)
|
||||
|
||||
Reference in New Issue
Block a user