重构转发管理,优化停止函数存储与信号处理,改进日志输出

This commit is contained in:
ym
2026-02-24 02:10:08 +08:00
parent c18b270a01
commit 619602b280
6 changed files with 249 additions and 183 deletions
+2 -2
View File
@@ -27,8 +27,8 @@ type IpBan struct {
// 全局转发协程等待组
var Wg sync.WaitGroup
// 全局协程通道 未初始化默认为nil
var Ch chan string
// StopFuncs 存储每个转发的停止函数,key 为 localPort+protocol
var StopFuncs sync.Map
// Web管理面板端口
var WebPort string
+160 -146
View File
@@ -3,7 +3,6 @@ package forward
import (
"context"
"encoding/binary"
"fmt"
"io"
"log"
"net"
@@ -15,12 +14,19 @@ import (
"csz.net/goForward/sql"
)
const (
bufSize = 4096
statsInterval = 10 * time.Second
gbSize = uint64(1 << 30) // 1 GiB
tcpTimeStep = 5 // TcpTime 每个统计周期增加的值
)
type ConnectionStats struct {
conf.ConnectionStats
TotalBytesOld uint64 `gorm:"-"`
TotalBytesLock sync.Mutex `gorm:"-"`
TCPConnections []net.Conn `gorm:"-"` // 用于存储 TCP 连接
TcpTime int `gorm:"-"` // TCP无传输时间
TotalBytesLock sync.Mutex `gorm:"-"` // 保护 TotalBytes 相关字段及 TCPConnections
TCPConnections []net.Conn `gorm:"-"`
TcpTime int `gorm:"-"`
}
// 保存多个连接信息
@@ -31,144 +37,144 @@ type LargeConnectionStats struct {
// 复用缓冲区
var bufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 4096)
return make([]byte, bufSize)
},
}
// 开启转发,负责分发具体转发
func Run(stats *ConnectionStats) {
defer releaseResources(stats) // 在函数返回时释放资源
var ctx, cancel = context.WithCancel(context.Background())
var innerWg sync.WaitGroup
defer releaseResources(stats)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 注册停止函数,外部通过 conf.StopFuncs 触发停止
key := stats.LocalPort + stats.Protocol
conf.StopFuncs.Store(key, cancel)
defer conf.StopFuncs.Delete(key)
var innerWg sync.WaitGroup
innerWg.Add(1)
go func() {
defer innerWg.Done()
stats.printStats(ctx)
innerWg.Done()
}()
fmt.Printf("【%s】监听端口 %s 转发至 %s:%s\n", stats.Protocol, stats.LocalPort, stats.RemoteAddr, stats.RemotePort)
log.Printf("【%s】监听端口 %s 转发至 %s:%s", stats.Protocol, stats.LocalPort, stats.RemoteAddr, stats.RemotePort)
if stats.Protocol == "udp" {
// UDP转发
localAddr, err := net.ResolveUDPAddr("udp", ":"+stats.LocalPort)
if err != nil {
log.Fatalln("解析本地地址时发生错误:", err)
log.Println("解析本地地址时发生错误:", err)
return
}
remoteAddr, err := net.ResolveUDPAddr("udp", stats.RemoteAddr+":"+stats.RemotePort)
remoteAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(stats.RemoteAddr, stats.RemotePort))
if err != nil {
log.Fatalln("解析远程地址时发生错误:", err)
log.Println("解析远程地址时发生错误:", err)
return
}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
log.Fatalln("监听时发生错误:", err)
log.Println("监听时发生错误:", err)
return
}
defer conn.Close()
// context 取消时关闭 UDP 连接
go func() {
for {
select {
case stopPort := <-conf.Ch:
if stopPort == stats.LocalPort+stats.Protocol {
fmt.Printf("【%s】停止监听端口 %s\n", stats.Protocol, stats.LocalPort)
conn.Close()
cancel()
return
} else {
conf.Ch <- stopPort
time.Sleep(3 * time.Second)
}
default:
time.Sleep(1 * time.Second)
}
}
<-ctx.Done()
conn.Close()
}()
innerWg.Add(1)
go func() {
defer innerWg.Done()
stats.handleUDPConnection(conn, remoteAddr, ctx)
innerWg.Done()
}()
} else {
// TCP转发
listener, err := net.Listen("tcp", ":"+stats.LocalPort)
if err != nil {
log.Fatalln("监听时发生错误:", err)
log.Println("监听时发生错误:", err)
return
}
defer listener.Close()
// context 取消时关闭监听器及所有活跃连接
go func() {
for {
select {
case stopPort := <-conf.Ch:
fmt.Println("通道信息:" + stopPort)
fmt.Println("当前端口:" + stats.LocalPort)
if stopPort == stats.LocalPort+stats.Protocol {
fmt.Printf("【%s】停止监听端口 %s\n", stats.Protocol, stats.LocalPort)
listener.Close()
cancel()
// 遍历并关闭所有 TCP 连接
for _, conn := range stats.TCPConnections {
conn.Close()
}
return
} else {
conf.Ch <- stopPort
time.Sleep(3 * time.Second)
}
default:
time.Sleep(1 * time.Second)
}
<-ctx.Done()
listener.Close()
// 加锁复制连接列表,释放锁后再关闭,避免长时间持锁
stats.TotalBytesLock.Lock()
conns := make([]net.Conn, len(stats.TCPConnections))
copy(conns, stats.TCPConnections)
stats.TotalBytesLock.Unlock()
for _, c := range conns {
c.Close()
}
}()
for {
clientConn, err := listener.Accept()
if err != nil {
log.Println("【"+stats.LocalPort+"】接受连接时发生错误:", err)
cancel()
log.Printf("【%s】停止接受连接", stats.LocalPort)
break
}
innerWg.Add(1)
go func() {
stats.handleTCPConnection(clientConn, ctx, cancel)
innerWg.Done()
}()
go func(conn net.Conn) {
defer innerWg.Done()
stats.handleTCPConnection(conn)
}(clientConn)
}
}
innerWg.Wait()
}
// TCP转发
func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn, ctx context.Context, cancel context.CancelFunc) {
func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn) {
defer clientConn.Close()
remoteConn, err := net.Dial("tcp", cs.RemoteAddr+":"+cs.RemotePort)
remoteConn, err := net.Dial("tcp", net.JoinHostPort(cs.RemoteAddr, cs.RemotePort))
if err != nil {
log.Println("【"+cs.LocalPort+"】连接远程地址时发生错误:", err)
log.Printf("【%s】连接远程地址时发生错误: %v", cs.LocalPort, err)
return
}
defer remoteConn.Close()
cs.TCPConnections = append(cs.TCPConnections, clientConn, remoteConn) // 添加连接到列表
// 加锁后添加连接,避免竞态
cs.TotalBytesLock.Lock()
cs.TCPConnections = append(cs.TCPConnections, clientConn, remoteConn)
cs.TotalBytesLock.Unlock()
var copyWG sync.WaitGroup
copyWG.Add(2)
go func() {
defer copyWG.Done()
if err := cs.copyBytes(clientConn, remoteConn); err != nil {
log.Println("复制字节时发生错误:", err)
cancel() // Assuming `cancel` is the cancel function from the context
}
cs.copyBytes(clientConn, remoteConn)
}()
go func() {
defer copyWG.Done()
if err := cs.copyBytes(remoteConn, clientConn); err != nil {
log.Println("复制字节时发生错误:", err)
cancel() // Assuming `cancel` is the cancel function from the context
}
cs.copyBytes(remoteConn, clientConn)
}()
for {
select {
case <-ctx.Done():
// 如果上级 context 被取消,停止接收新连接
return
default:
copyWG.Wait()
return
copyWG.Wait()
// 连接结束后从跟踪列表中移除
cs.TotalBytesLock.Lock()
cs.removeConns(clientConn, remoteConn)
cs.TotalBytesLock.Unlock()
}
// removeConns 从 TCPConnections 中移除指定连接。调用时须持有 TotalBytesLock。
func (cs *ConnectionStats) removeConns(conns ...net.Conn) {
set := make(map[net.Conn]struct{}, len(conns))
for _, c := range conns {
set[c] = struct{}{}
}
result := cs.TCPConnections[:0]
for _, c := range cs.TCPConnections {
if _, found := set[c]; !found {
result = append(result, c)
}
}
cs.TCPConnections = result
}
// UDP转发
@@ -178,41 +184,46 @@ func (cs *ConnectionStats) handleUDPConnection(localConn *net.UDPConn, remoteAdd
case <-ctx.Done():
return
default:
buf := bufPool.Get().([]byte)
n, _, err := localConn.ReadFromUDP(buf)
if err != nil {
log.Println("【"+cs.LocalPort+"】从源读取时发生错误:", err)
buf, ok := bufPool.Get().([]byte)
if !ok {
log.Println("缓冲区类型断言失败")
return
}
n, _, err := localConn.ReadFromUDP(buf)
if err != nil {
bufPool.Put(buf)
log.Printf("【%s】从源读取时发生错误: %v", cs.LocalPort, err)
return
}
fmt.Printf("收到长度为 %d 的UDP数据包\n", n)
cs.TotalBytesLock.Lock()
cs.TotalBytes += uint64(n)
cs.TotalBytesLock.Unlock()
// 处理消息的边界和错误情况
go func() {
cs.forwardUDPMessage(localConn, remoteAddr, buf[:n])
bufPool.Put(buf)
}()
// WriteToUDP 写内核缓冲区后立即返回,直接调用无需额外 goroutine,
// 避免高流量时频繁 goroutine 创建和 GC 带来的 CPU 压力
data := make([]byte, n+2)
binary.BigEndian.PutUint16(data[:2], uint16(n))
copy(data[2:], buf[:n])
bufPool.Put(buf)
cs.forwardUDPMessage(localConn, remoteAddr, data)
}
}
}
func (cs *ConnectionStats) forwardUDPMessage(localConn *net.UDPConn, remoteAddr *net.UDPAddr, message []byte) {
// 在消息前面添加消息长度信息
length := make([]byte, 2)
binary.BigEndian.PutUint16(length, uint16(len(message)))
// 组合消息长度和实际消息
data := append(length, message...)
_, err := localConn.WriteToUDP(data, remoteAddr)
if err != nil {
log.Println("【"+cs.LocalPort+"】写入目标时发生错误:", err)
func (cs *ConnectionStats) forwardUDPMessage(localConn *net.UDPConn, remoteAddr *net.UDPAddr, data []byte) {
if _, err := localConn.WriteToUDP(data, remoteAddr); err != nil {
log.Printf("【%s】写入目标时发生错误: %v", cs.LocalPort, err)
}
}
func (cs *ConnectionStats) copyBytes(dst, src net.Conn) error {
buf := bufPool.Get().([]byte)
func (cs *ConnectionStats) copyBytes(dst, src net.Conn) {
buf, ok := bufPool.Get().([]byte)
if !ok {
log.Println("缓冲区类型断言失败")
dst.Close()
src.Close()
return
}
defer bufPool.Put(buf)
for {
n, err := src.Read(buf)
@@ -220,88 +231,91 @@ func (cs *ConnectionStats) copyBytes(dst, src net.Conn) error {
cs.TotalBytesLock.Lock()
cs.TotalBytes += uint64(n)
cs.TotalBytesLock.Unlock()
_, err := dst.Write(buf[:n])
if err != nil {
log.Println("【"+cs.LocalPort+"】写入目标时发生错误:", err)
return err
if _, werr := dst.Write(buf[:n]); werr != nil {
log.Printf("【%s】写入目标时发生错误: %v", cs.LocalPort, werr)
break
}
}
if err == io.EOF {
break
}
if err != nil {
log.Println("【"+cs.LocalPort+"】从源读取时发生错误:", err)
if err != io.EOF {
log.Printf("【%s】从源读取时发生错误: %v", cs.LocalPort, err)
}
break
}
}
// 关闭连接
// 关闭双向连接,通知对端 goroutine 退出
dst.Close()
src.Close()
return nil
}
// 定时打印和处理流量变化
func (cs *ConnectionStats) printStats(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop() // 在函数结束时停止定时器
ticker := time.NewTicker(statsInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// 在锁内只做内存状态更新,捕获需要的值后立即释放
cs.TotalBytesLock.Lock()
if cs.TotalBytes > cs.TotalBytesOld {
hasTraffic := cs.TotalBytes > cs.TotalBytesOld
var totalStr string
var bytesToSave, gbToSave uint64
var needUpdateGb bool
var connCount int
if hasTraffic {
if cs.Protocol == "tcp" {
cs.TcpTime = 0
}
var total string
if cs.TotalBytes > 0 && float64(cs.TotalBytes)/(1024*1024) < 0.5 {
total = strconv.FormatFloat(float64(cs.TotalBytes)/(1024), 'f', 2, 64) + "KB"
if float64(cs.TotalBytes)/(1024*1024) < 0.5 {
totalStr = strconv.FormatFloat(float64(cs.TotalBytes)/1024, 'f', 2, 64) + "KB"
} else {
total = strconv.FormatFloat(float64(cs.TotalBytes)/(1024*1024), 'f', 2, 64) + "MB"
totalStr = strconv.FormatFloat(float64(cs.TotalBytes)/(1024*1024), 'f', 2, 64) + "MB"
}
fmt.Printf("【%s】端口 %s 统计流量: %s\n", cs.Protocol, cs.LocalPort, total)
//统计更换单位
var gb uint64 = 1073741824
if cs.TotalBytes >= gb {
cs.TotalGigabyte = cs.TotalGigabyte + 1
sql.UpdateForwardGb(cs.Id, cs.TotalGigabyte)
cs.TotalBytes = cs.TotalBytes - gb
if cs.TotalBytes >= gbSize {
cs.TotalGigabyte++
cs.TotalBytes -= gbSize
needUpdateGb = true
gbToSave = cs.TotalGigabyte
}
cs.TotalBytesOld = cs.TotalBytes
sql.UpdateForwardBytes(cs.Id, cs.TotalBytes)
fmt.Printf("【%s】端口 %s 当前连接数: %d\n", cs.Protocol, cs.LocalPort, len(cs.TCPConnections))
} else {
if cs.Protocol == "tcp" {
// fmt.Printf("【%s】端口 %s 当前超时秒: %d\n", cs.Protocol, cs.LocalPort, cs.TcpTime)
if cs.TcpTime >= conf.TcpTimeout {
// fmt.Printf("【%s】端口 %s 超时关闭\n", cs.Protocol, cs.LocalPort)
for i := len(cs.TCPConnections) - 1; i >= 0; i-- {
conn := cs.TCPConnections[i]
conn.Close()
// 从连接列表中移除关闭的连接
cs.TCPConnections = append(cs.TCPConnections[:i], cs.TCPConnections[i+1:]...)
}
} else {
cs.TcpTime = cs.TcpTime + 5
bytesToSave = cs.TotalBytes
connCount = len(cs.TCPConnections)
} else if cs.Protocol == "tcp" {
if cs.TcpTime >= conf.TcpTimeout {
for i := len(cs.TCPConnections) - 1; i >= 0; i-- {
cs.TCPConnections[i].Close()
}
cs.TCPConnections = nil
} else {
cs.TcpTime += tcpTimeStep
}
}
cs.TotalBytesLock.Unlock()
//当协程退出时执行
// 锁外执行日志输出和 SQL 写入,避免阻塞 copyBytes 的流量统计
if hasTraffic {
log.Printf("【%s】端口 %s 统计流量: %s", cs.Protocol, cs.LocalPort, totalStr)
if needUpdateGb {
sql.UpdateForwardGb(cs.Id, gbToSave)
}
sql.UpdateForwardBytes(cs.Id, bytesToSave)
log.Printf("【%s】端口 %s 当前连接数: %d", cs.Protocol, cs.LocalPort, connCount)
}
case <-ctx.Done():
return
}
}
}
// 关闭 TCP 连接并从切片中移除
// 关闭所有 TCP 连接并清空切片
func closeTCPConnections(stats *ConnectionStats) {
stats.TotalBytesLock.Lock()
defer stats.TotalBytesLock.Unlock()
for i, conn := range stats.TCPConnections {
for _, conn := range stats.TCPConnections {
conn.Close()
stats.TCPConnections[i] = nil
}
stats.TCPConnections = nil // 清空切片
stats.TCPConnections = nil
}
// 释放资源
+25 -6
View File
@@ -1,9 +1,14 @@
package main
import (
"context"
"flag"
"log"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"csz.net/goForward/conf"
"csz.net/goForward/forward"
@@ -16,11 +21,10 @@ func main() {
if conf.TcpTimeout < 5 {
conf.TcpTimeout = 5
}
// 初始化通道
conf.Ch = make(chan string)
forwardList := sql.GetAction()
if len(forwardList) == 0 {
//添加测试数据
// 添加测试数据
testData := conf.ConnectionStats{
LocalPort: conf.WebPort,
RemotePort: conf.WebPort,
@@ -39,15 +43,29 @@ func main() {
Protocol: forwardList[i].Protocol,
LocalPort: forwardList[i].LocalPort,
RemotePort: forwardList[i].RemotePort,
RemoteAddr: forwardList[i].RemoteAddr,
RemoteAddr: sql.NormalizeAddr(forwardList[i].RemoteAddr), // 兼容旧数据中可能存在的方括号
TotalBytes: forwardList[i].TotalBytes,
},
TotalBytesOld: forwardList[i].TotalBytes,
TotalBytesLock: sync.Mutex{},
}
largeStats.Connections[i] = connectionStats
}
// 监听系统退出信号,优雅关闭所有转发
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
<-sigCh
log.Println("收到终止信号,正在关闭所有转发...")
conf.StopFuncs.Range(func(key, value interface{}) bool {
if cancelFn, ok := value.(context.CancelFunc); ok {
cancelFn()
}
return true
})
}()
// 设置 WaitGroup 计数为连接数
conf.Wg.Add(len(largeStats.Connections))
// 并发执行多个转发
@@ -58,8 +76,9 @@ func main() {
}(stats)
}
conf.Wg.Wait()
defer close(conf.Ch)
sql.Close()
}
func init() {
flag.StringVar(&conf.WebPort, "port", "8889", "Web Port")
flag.StringVar(&conf.Db, "db", "goForward.db", "Db Path")
+38 -21
View File
@@ -1,7 +1,6 @@
package sql
import (
"fmt"
"log"
"os"
"path/filepath"
@@ -31,16 +30,30 @@ func Once() {
} else {
dbPath = conf.Db
}
fmt.Println("Data:", dbPath)
log.Printf("Data: %s", dbPath)
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
log.Println("连接数据库失败")
return
log.Fatalln("连接数据库失败:", err)
}
db.AutoMigrate(&conf.ConnectionStats{})
db.AutoMigrate(&conf.IpBan{})
}
// Close 关闭数据库连接
func Close() {
if db == nil {
return
}
sqlDB, err := db.DB()
if err != nil {
log.Println("获取数据库连接失败:", err)
return
}
if err := sqlDB.Close(); err != nil {
log.Println("关闭数据库连接失败:", err)
}
}
// 获取转发列表
func GetForwardList() []conf.ConnectionStats {
var res []conf.ConnectionStats
@@ -66,17 +79,17 @@ func GetIpBan() []conf.IpBan {
func UpdateForwardBytes(id int, bytes uint64) bool {
res := db.Model(&conf.ConnectionStats{}).Where("id = ?", id).Update("total_bytes", bytes)
if res.Error != nil {
fmt.Println(res.Error)
log.Println(res.Error)
return false
}
return true
}
// 修改指定转发统计流量(byte)
// 修改指定转发统计流量(gigabyte)
func UpdateForwardGb(id int, gb uint64) bool {
res := db.Model(&conf.ConnectionStats{}).Where("id = ?", id).Update("total_gigabyte", gb)
if res.Error != nil {
fmt.Println(res.Error)
log.Println(res.Error)
return false
}
return true
@@ -86,7 +99,7 @@ func UpdateForwardGb(id int, gb uint64) bool {
func UpdateForwardStatus(id int, status int) bool {
res := db.Model(&conf.ConnectionStats{}).Where("id = ?", id).Update("status", status)
if res.Error != nil {
fmt.Println(res.Error)
log.Println(res.Error)
return false
}
return true
@@ -118,10 +131,20 @@ func rmSpaces(input string) string {
return strings.ReplaceAll(input, " ", "")
}
// NormalizeAddr 去除空格,并去掉 IPv6 地址外层的方括号,统一以裸地址形式存储。
// 连接时配合 net.JoinHostPort 使用,后者会自动为 IPv6 补回括号。
func NormalizeAddr(addr string) string {
addr = rmSpaces(addr)
if strings.HasPrefix(addr, "[") && strings.HasSuffix(addr, "]") {
addr = addr[1 : len(addr)-1]
}
return addr
}
// 增加转发
func AddForward(newForward conf.ConnectionStats) int {
//预处理
newForward.RemoteAddr = rmSpaces(newForward.RemoteAddr)
// 预处理:去空格,并标准化 RemoteAddr(去掉 IPv6 方括号)
newForward.RemoteAddr = NormalizeAddr(newForward.RemoteAddr)
newForward.RemotePort = rmSpaces(newForward.RemotePort)
newForward.LocalPort = rmSpaces(newForward.LocalPort)
if newForward.Protocol != "udp" {
@@ -130,7 +153,7 @@ func AddForward(newForward conf.ConnectionStats) int {
if !FreeForward(newForward.LocalPort, newForward.Protocol) {
return 0
}
//开启事务
// 开启事务
tx := db.Begin()
if tx.Error != nil {
log.Println("开启事务失败")
@@ -138,9 +161,8 @@ func AddForward(newForward conf.ConnectionStats) int {
}
// 在事务中执行插入操作
if err := tx.Create(&newForward).Error; err != nil {
log.Println("插入新转发失败")
log.Println(err)
tx.Rollback() // 回滚事务
log.Println("插入新转发失败:", err)
tx.Rollback()
return 0
}
// 提交事务
@@ -159,7 +181,7 @@ func DelForward(id int) bool {
// 增加错误登录
func AddBan(ip conf.IpBan) bool {
//开启事务
// 开启事务
tx := db.Begin()
if tx.Error != nil {
return false
@@ -167,7 +189,7 @@ func AddBan(ip conf.IpBan) bool {
// 在事务中执行插入操作
if err := tx.Create(&ip).Error; err != nil {
log.Println(err)
tx.Rollback() // 回滚事务
tx.Rollback()
return false
}
// 提交事务
@@ -177,16 +199,11 @@ func AddBan(ip conf.IpBan) bool {
// 检查过去一天内指定IP地址的记录条数是否超过三条
func IpFree(ip string) bool {
// 获取过去一天的时间戳
oneDayAgo := time.Now().Add(-24 * time.Hour).Unix()
// 查询过去一天内指定IP地址的记录条数
var count int64
if err := db.Model(&conf.IpBan{}).Where("ip = ? AND time_stamp > ?", ip, oneDayAgo).Count(&count).Error; err != nil {
log.Println(err)
return false
}
// 如果记录条数超过三条,则返回false;否则返回true
return count < 3
}
+12 -5
View File
@@ -1,6 +1,7 @@
package utils
import (
"context"
"sync"
"csz.net/goForward/conf"
@@ -20,7 +21,7 @@ func AddForward(newF conf.ConnectionStats) bool {
Id: id,
LocalPort: newF.LocalPort,
RemotePort: newF.RemotePort,
RemoteAddr: newF.RemoteAddr,
RemoteAddr: sql.NormalizeAddr(newF.RemoteAddr), // 确保裸 IPv6(无方括号)
Protocol: newF.Protocol,
TotalBytes: 0,
},
@@ -40,7 +41,10 @@ func AddForward(newF conf.ConnectionStats) bool {
// 删除并关闭指定转发
func DelForward(f conf.ConnectionStats) bool {
sql.DelForward(f.Id)
conf.Ch <- f.LocalPort + f.Protocol
key := f.LocalPort + f.Protocol
if cancelFn, ok := conf.StopFuncs.Load(key); ok {
cancelFn.(context.CancelFunc)()
}
return true
}
@@ -57,7 +61,7 @@ func ExStatus(f conf.ConnectionStats) bool {
Id: f.Id,
LocalPort: f.LocalPort,
RemotePort: f.RemotePort,
RemoteAddr: f.RemoteAddr,
RemoteAddr: sql.NormalizeAddr(f.RemoteAddr), // 确保裸 IPv6(无方括号)
Protocol: f.Protocol,
TotalBytes: f.TotalBytes,
},
@@ -71,10 +75,13 @@ func ExStatus(f conf.ConnectionStats) bool {
}()
return true
} else {
conf.Ch <- f.LocalPort + f.Protocol
// 停止转发
key := f.LocalPort + f.Protocol
if cancelFn, ok := conf.StopFuncs.Load(key); ok {
cancelFn.(context.CancelFunc)()
}
return true
}
}
return false
}
+12 -3
View File
@@ -1,7 +1,7 @@
package web
import (
"fmt"
"crypto/rand"
"html/template"
"log"
"net/http"
@@ -17,10 +17,19 @@ import (
"github.com/gin-gonic/gin"
)
// generateSecretKey 生成随机 session 密钥,避免使用硬编码密钥
func generateSecretKey() []byte {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
log.Panicln("生成session密钥失败:", err)
}
return key
}
func Run() {
gin.SetMode(gin.ReleaseMode)
r := gin.Default()
store := cookie.NewStore([]byte("secret"))
store := cookie.NewStore(generateSecretKey())
r.Use(sessions.Sessions("goForward", store))
r.Use(checkCookieMiddleware)
r.SetHTMLTemplate(template.Must(template.New("").Funcs(r.FuncMap).ParseFS(assets.Templates, "templates/*")))
@@ -148,7 +157,7 @@ func Run() {
}
c.Redirect(302, "/")
})
fmt.Println("Web管理面板端口:" + conf.WebPort)
log.Printf("Web管理面板端口: %s", conf.WebPort)
err := r.Run(conf.WebIP + ":" + conf.WebPort)
if err != nil {
log.Panicln(err)