refactor: optimize wireguard memory usage

This commit is contained in:
VaalaCat
2026-01-01 15:36:20 +00:00
parent fa555b17f5
commit 10325bcb7e
9 changed files with 181 additions and 36 deletions
+3 -1
View File
@@ -57,5 +57,7 @@ var (
rpc.NewClientsManager,
NewAutoJoin, // provide final config
fx.Annotate(NewPatchedConfig, fx.ResultTags(`name:"argsPatchedConfig"`)),
))
),
fx.Invoke(runProfiler),
)
)
+71
View File
@@ -0,0 +1,71 @@
package shared
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/http/pprof"
"time"
"github.com/VaalaCat/frp-panel/conf"
"github.com/VaalaCat/frp-panel/services/app"
"github.com/VaalaCat/frp-panel/utils/logger"
"go.uber.org/fx"
)
// runProfiler starts an optional pprof HTTP server managed by the fx
// lifecycle. It is a no-op unless DEBUG_PROFILER_ENABLED is set; when the
// profiler is enabled without IS_DEBUG it logs a warning but still runs.
func runProfiler(param struct {
	fx.In
	Lc  fx.Lifecycle
	Cfg conf.Config
	Ctx *app.Context
}) {
	if !param.Cfg.Debug.ProfilerEnabled {
		return
	}
	if !param.Cfg.IsDebug {
		logger.Logger(param.Ctx).Warn("profiler is enabled but IS_DEBUG=false, make sure you understand the risk")
	}

	// Bind to loopback only: pprof exposes process internals (heap,
	// goroutine, CPU profiles) and must not be reachable from other hosts.
	// This also matches the documented behavior ("listens on 127.0.0.1 only").
	addr := fmt.Sprintf("127.0.0.1:%d", param.Cfg.Debug.ProfilerPort)

	mux := http.NewServeMux()
	mux.HandleFunc("/debug/pprof/", pprof.Index)
	mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
	mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
	mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
	mux.HandleFunc("/debug/pprof/trace", pprof.Trace)

	srv := &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
	}

	param.Lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			// Listen synchronously so startup fails fast when the port
			// is already taken, instead of dying silently in a goroutine.
			lis, err := net.Listen("tcp", addr)
			if err != nil {
				return fmt.Errorf("profiler listen on %s: %w", addr, err)
			}
			logger.Logger(param.Ctx).Infof("profiler http server started: http://%s/debug/pprof/", addr)
			go func() {
				// ErrServerClosed is the expected result of Shutdown; anything
				// else means the server died unexpectedly.
				if err := srv.Serve(lis); err != nil && !errors.Is(err, http.ErrServerClosed) {
					logger.Logger(param.Ctx).WithError(err).Warn("profiler http server stopped unexpectedly")
				}
			}()
			return nil
		},
		OnStop: func(ctx context.Context) error {
			// Bound shutdown so fx teardown cannot hang on lingering requests.
			shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
			defer cancel()
			return srv.Shutdown(shutdownCtx)
		},
	})
}
+5 -1
View File
@@ -67,7 +67,11 @@ type Config struct {
} `env-prefix:"FEATURES_" env-description:"features config"`
} `env-prefix:"CLIENT_"`
IsDebug bool `env:"IS_DEBUG" env-default:"false" env-description:"is debug mode"`
Logger struct {
Debug struct {
ProfilerEnabled bool `env:"PROFILER_ENABLED" env-default:"false" env-description:"enable profiler"`
ProfilerPort int `env:"PROFILER_PORT" env-default:"6961" env-description:"profiler port"`
} `env-prefix:"DEBUG_"`
Logger struct {
DefaultLoggerLevel string `env:"DEFAULT_LOGGER_LEVEL" env-default:"info" env-description:"frp-panel internal default logger level"`
FRPLoggerLevel string `env:"FRP_LOGGER_LEVEL" env-default:"info" env-description:"frp logger level"`
} `env-prefix:"LOGGER_"`
+3
View File
@@ -42,3 +42,6 @@
| string | `DB_DSN` | `data.db` | 数据库 DSN,默认使用sqlite3,数据默认存储在可执行文件同目录下,对于 sqlite 是路径,其他数据库为 DSN,参见 [MySQL DSN](https://github.com/go-sql-driver/mysql#dsn-data-source-name) |
| string | `CLIENT_ID` | - | 客户端 ID |
| string | `CLIENT_SECRET` | - | 客户端密钥 |
| bool | `IS_DEBUG` | `false` | 是否开启调试模式(影响日志/部分组件行为) |
| bool | `DEBUG_PROFILER_ENABLED` | `false` | 是否开启 profiler(pprof) HTTP 服务(默认仅监听 127.0.0.1) |
| int | `DEBUG_PROFILER_PORT` | `6961` | profiler(pprof) HTTP 服务端口 |
+3
View File
@@ -45,3 +45,6 @@ The application loads configuration in the following order:
| string | `DB_DSN` | `data.db` | Database DSN. For `sqlite3`, this is a file path (default in working directory). For other databases, use DSN. |
| string | `CLIENT_ID` | – | Client ID |
| string | `CLIENT_SECRET` | – | Client secret |
| bool | `IS_DEBUG` | `false` | Enable debug mode (affects logging / some components behavior) |
| bool | `DEBUG_PROFILER_ENABLED` | `false` | Enable profiler (pprof) HTTP server (by default listens on 127.0.0.1 only) |
| int | `DEBUG_PROFILER_PORT` | `6961` | Profiler (pprof) HTTP port |
+10 -1
View File
@@ -20,6 +20,10 @@ type MultiBind struct {
endpointPool sync.Pool
}
// maxPooledEndpointSliceCap caps the capacity (in elements) of an endpoint
// scratch slice that may be returned to endpointPool; larger slices are
// replaced with a fresh small one so the pool does not pin big allocations.
const (
maxPooledEndpointSliceCap = 1024
)
func NewMultiBind(logger *logrus.Entry, trans ...*Transport) *MultiBind {
if logger == nil {
logger = logrus.NewEntry(logrus.New())
@@ -173,7 +177,12 @@ func (m *MultiBind) recvWrapper(trans *Transport, fns conn.ReceiveFunc) conn.Rec
for i := range tmpEps {
tmpEps[i] = nil
}
tmpEps = tmpEps[:0]
// 避免把超大切片放回 pool
if cap(tmpEps) > maxPooledEndpointSliceCap {
tmpEps = make([]conn.Endpoint, 0, 128)
} else {
tmpEps = tmpEps[:0]
}
*tmpEpsPtr = tmpEps
m.endpointPool.Put(tmpEpsPtr)
+5 -22
View File
@@ -1,41 +1,24 @@
package multibind
import (
"github.com/VaalaCat/frp-panel/utils"
"github.com/samber/lo"
"golang.zx2c4.com/wireguard/conn"
)
type Transport struct {
bind conn.Bind
name string
endpoints utils.SyncMap[conn.Endpoint, *MultiEndpoint]
bind conn.Bind
name string
}
func (t *Transport) loadOrNewEndpoint(inner conn.Endpoint) conn.Endpoint {
if lo.IsNil(inner) {
return &MultiEndpoint{
trans: t,
inner: inner,
}
}
if cached, ok := t.endpoints.Load(inner); ok {
return cached
}
newEndpoint := &MultiEndpoint{
return &MultiEndpoint{
trans: t,
inner: inner,
}
t.endpoints.Store(inner, newEndpoint)
return newEndpoint
}
func NewTransport(bind conn.Bind, name string) *Transport {
return &Transport{
bind: bind,
name: name,
endpoints: utils.SyncMap[conn.Endpoint, *MultiEndpoint]{},
bind: bind,
name: name,
}
}
+77 -10
View File
@@ -2,6 +2,7 @@ package ws
import (
"fmt"
"io"
"net"
"net/http"
"net/url"
@@ -22,9 +23,11 @@ var (
const (
defaultRegisterChanSize = 128
defaultIncomingChanSize = 2048
defaultBatchSize = 128 // 批量处理大小
wsReadBufferSize = 65536
wsWriteBufferSize = 65536
defaultBatchSize = 128 // 批量处理大小
wsReadBufferSize = 64 * 1024 // 64KiB
wsWriteBufferSize = 64 * 1024 // 64KiB
wsMaxMessageSize = 4 * 1024 * 1024 // 4MiB
maxPooledPayloadCap = 64 * 1024 // 64KiB
)
type WSBind struct {
@@ -145,35 +148,90 @@ func (w *WSBind) Send(bufs [][]byte, ep conn.Endpoint) error {
return err
}
writer, err := conn.NextWriter(websocket.BinaryMessage)
if err != nil {
// 需要限制单条 message 的大小。
// 超大 message 可能导致对端有大分配
var (
writer io.WriteCloser
msgBytes int
openWriter = func() error { // 打开 writer
wr, openErr := conn.NextWriter(websocket.BinaryMessage)
if openErr != nil {
return openErr
}
writer = wr
msgBytes = 0
return nil
}
flushWriter = func() error { // 关闭 writer
if writer == nil {
return nil
}
closeErr := writer.Close()
writer = nil
msgBytes = 0
return closeErr
}
)
if err = openWriter(); err != nil {
conn.Close()
return fmt.Errorf("ws get writer error: %w", err)
}
// 批量写包
// TLV分割
// 保证最大包大小不超过wsMaxMessageSize
// 如果超过,分多次TLV写入
for _, buf := range bufs {
if len(buf) == 0 {
continue
}
// TLV 长度字段为 2 字节,单包最大 65535
// 如果超长,给wg-go报错,他不应该传这么长的包
if len(buf) > 0xFFFF {
_ = flushWriter()
conn.Close()
return fmt.Errorf("ws packet too large: %d > 65535", len(buf))
}
// 高低位拼接
need := 2 + len(buf)
if need > wsMaxMessageSize {
_ = flushWriter()
conn.Close()
return fmt.Errorf("ws message too large for single packet: need=%d limit=%d", need, wsMaxMessageSize)
}
// 若追加后超过单条 message 上限,则先 flush,再开启新 message
if msgBytes > 0 && msgBytes+need > wsMaxMessageSize {
if err = flushWriter(); err != nil {
conn.Close()
return fmt.Errorf("ws flush error: %w", err)
}
if err = openWriter(); err != nil {
conn.Close()
return fmt.Errorf("ws get writer error: %w", err)
}
}
// 写 TLV 长度
lenBuf := [2]byte{byte(len(buf) >> 8), byte(len(buf))}
if _, err = writer.Write(lenBuf[:]); err != nil {
_ = flushWriter()
conn.Close()
return fmt.Errorf("ws write length error: %w", err)
}
// 写入包内容
// 写 TLV 内容
if _, err = writer.Write(buf); err != nil {
_ = flushWriter()
conn.Close()
return fmt.Errorf("ws write data error: %w", err)
}
msgBytes += need
}
// 一次性写入
if err = writer.Close(); err != nil {
// flush 最后一条 message
if err = flushWriter(); err != nil {
conn.Close()
return fmt.Errorf("ws flush error: %w", err)
}
@@ -205,6 +263,9 @@ func (w *WSBind) HandleHTTP(writer http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("ws upgrade error: %w", err)
}
// 限制单条消息大小,避免 ReadMessage 触发 io.ReadAll 的超大分配
conn.SetReadLimit(wsMaxMessageSize)
// 禁用写入截止时间,避免在高负载下超时
conn.SetWriteDeadline(time.Time{})
@@ -297,6 +358,12 @@ func (w *WSBind) recvFunc(packets [][]byte, sizes []int, eps []conn.Endpoint) (i
sizes[idx] = len(payload)
}
eps[idx] = p.endpoint
// 避免把大 buffer 放回 sync.Pool
if cap(p.payload) > maxPooledPayloadCap {
p.payload = make([]byte, 0, 2048)
} else {
p.payload = p.payload[:0]
}
// 归还 packet 到对象池
w.packetPool.Put(p)
}
+3
View File
@@ -78,6 +78,9 @@ func (w *WSConn) readLoop(ctx *app.Context) {
return
}
// 限制单条消息大小,避免 gorilla/websocket.ReadMessage/io.ReadAll 触发超大分配
conn.SetReadLimit(wsMaxMessageSize)
for {
// data 是 TLV 格式
msgType, data, err := conn.ReadMessage()