diff --git a/config/args.go b/config/args.go index bae798f..763c197 100644 --- a/config/args.go +++ b/config/args.go @@ -31,6 +31,7 @@ var ( Arg_stun_svr_port int Arg_local_proxy_addr string Arg_local_forward_tcp_addrs string + Arg_local_forward_udp_addrs string ) func Help() { @@ -52,7 +53,8 @@ func Help() { Arg_tun_remote = flag.Bool("remote", false, "启动Remote端") flag.StringVar(&Arg_local_proxy_addr, "proxy", "", "Local端代理转发监听地址, 例如: 0.0.0.0:1080") - flag.StringVar(&Arg_local_forward_tcp_addrs, "forward", "", "Local端端口转发地址, 多个用逗号间隔, 例如: 0.0.0.0:22@127.0.0.1:22,0.0.0.0:80@127.0.0.1:80") + flag.StringVar(&Arg_local_forward_tcp_addrs, "forward_tcp", "", "Local端TCP转发地址, 多个用逗号间隔, 例如: 0.0.0.0:22@127.0.0.1:22,0.0.0.0:80@127.0.0.1:80") + flag.StringVar(&Arg_local_forward_udp_addrs, "forward_udp", "", "Local端UDP转发地址, 多个用逗号间隔, 例如: 0.0.0.0:5353@127.0.0.1:53") flag.StringVar(&Arg_tun_key, "key", "", "自定义, 必须客户端和服务端一致。建议: {name}_{YYYYMMDDHHMM}, 例如: kony_202412140928") flag.IntVar(&Arg_p2p_timeout, "time_out", 15, "最大连接超时, 单位: 秒") diff --git a/pro/local.go b/pro/local.go index 38c3ba6..642e0fe 100644 --- a/pro/local.go +++ b/pro/local.go @@ -20,7 +20,7 @@ var ( m_tun_active *tun.TunActive m_tun_passive *tun.TunPassive g_netstack_started = false - m_forward_clients []*proxy.ForwardClient + m_forward_clients []proxy.ForwardRunner ) // handleState0_RegisterSession 处理 State 0: 注册会话并等待 Remote 端认领 @@ -243,7 +243,7 @@ func RunLocal() error { useForward := proxy.CheckForwardArgs() if useForward { - m_forward_clients, err = proxy.NewForwardClients() + m_forward_clients, err = proxy.NewForwardRunners() if err != nil { return err } diff --git a/proxy/proxy_l.go b/proxy/proxy_l.go index 7744429..03c5e49 100644 --- a/proxy/proxy_l.go +++ b/proxy/proxy_l.go @@ -5,6 +5,7 @@ import ( "encoding/binary" go2pool "go2/pool" "goodlink/config" + "io" "log" "net" "strconv" @@ -19,10 +20,62 @@ type ForwardRule struct { ListenAddr string RemoteIP net.IP RemotePort uint16 + Proto byte // 0x00 TCP, 0x01 UDP(与 Remote process_stream 一致) } var ForwardRules []ForwardRule +func appendForwardRuleEntries(csv string, proto byte) bool { + if csv == "" { + return true + } + entries := strings.Split(csv, ",") + for _, entry := range entries { + entry = strings.TrimSpace(entry) + if entry == "" { + continue + } + parts := strings.SplitN(entry, "@", 2) + if len(parts) != 2 { + log.Printf("[proxy] 转发地址格式错误(需要 listenHost:listenPort@remoteHost:remotePort): %s", entry) + ForwardRules = nil + return false + } + listenHost, listenPort, err := net.SplitHostPort(parts[0]) + if err != nil { + log.Printf("[proxy] 转发监听地址解析失败: %s, %v", parts[0], err) + ForwardRules = nil + return false + } + listenAddr := net.JoinHostPort(listenHost, listenPort) + remoteHost, remotePortStr, err := net.SplitHostPort(parts[1]) + if err != nil { + log.Printf("[proxy] 转发目标地址解析失败: %s, %v", parts[1], err) + ForwardRules = nil + return false + } + remoteIP := net.ParseIP(remoteHost) + if remoteIP == nil { + log.Printf("[proxy] 转发目标IP解析失败: %s", remoteHost) + ForwardRules = nil + return false + } + remotePort, err := strconv.Atoi(remotePortStr) + if err != nil || remotePort <= 0 || remotePort > 65535 { + log.Printf("[proxy] 转发目标端口无效: %s", remotePortStr) + ForwardRules = nil + return false + } + ForwardRules = append(ForwardRules, ForwardRule{ + ListenAddr: listenAddr, + RemoteIP: remoteIP.To4(), + RemotePort: uint16(remotePort), + Proto: proto, + }) + } + return true +} + func CheckForwardArgs() bool { ForwardRules = nil @@ -31,59 +84,28 @@ func CheckForwardArgs() bool { ListenAddr: config.Arg_local_proxy_addr, RemoteIP: net.IPv4(127, 0, 0, 1), RemotePort: PROXY_PORT, + Proto: 0x00, }) } - if config.Arg_local_forward_tcp_addrs != "" { - entries := strings.Split(config.Arg_local_forward_tcp_addrs, ",") - for _, entry := range entries { - entry = strings.TrimSpace(entry) - if entry == "" { - continue - } - // 格式: listenHost:listenPort@remoteHost:remotePort - parts := strings.SplitN(entry, "@", 2) - if len(parts) != 2 { - log.Printf("[proxy] 转发地址格式错误(需要 listenHost:listenPort@remoteHost:remotePort): %s", entry) - ForwardRules = nil - return false - } - listenHost, listenPort, err := net.SplitHostPort(parts[0]) - if err != nil { - log.Printf("[proxy] 转发监听地址解析失败: %s, %v", parts[0], err) - ForwardRules = nil - return false - } - listenAddr := net.JoinHostPort(listenHost, listenPort) - remoteHost, remotePortStr, err := net.SplitHostPort(parts[1]) - if err != nil { - log.Printf("[proxy] 转发目标地址解析失败: %s, %v", parts[1], err) - ForwardRules = nil - return false - } - remoteIP := net.ParseIP(remoteHost) - if remoteIP == nil { - log.Printf("[proxy] 转发目标IP解析失败: %s", remoteHost) - ForwardRules = nil - return false - } - remotePort, err := strconv.Atoi(remotePortStr) - if err != nil || remotePort <= 0 || remotePort > 65535 { - log.Printf("[proxy] 转发目标端口无效: %s", remotePortStr) - ForwardRules = nil - return false - } - ForwardRules = append(ForwardRules, ForwardRule{ - ListenAddr: listenAddr, - RemoteIP: remoteIP.To4(), - RemotePort: uint16(remotePort), - }) - } + if !appendForwardRuleEntries(config.Arg_local_forward_tcp_addrs, 0x00) { + return false + } + if !appendForwardRuleEntries(config.Arg_local_forward_udp_addrs, 0x01) { + return false } return len(ForwardRules) > 0 } +// ForwardRunner 本地转发监听器(TCP 或 UDP),隧道重连时通过 SetQuicConn/ClearQuicConn 热替换 QUIC。 +type ForwardRunner interface { + SetQuicConn(conn *quic.Conn) + ClearQuicConn() + Serve() + Close() +} + // ForwardClient 管理 TCP 监听和 QUIC 隧道转发。 // listener 只创建一次,隧道重连时通过 SetQuicConn/ClearQuicConn 热替换 QUIC 连接。 type ForwardClient struct { @@ -94,24 +116,66 @@ type ForwardClient struct { remotePort uint16 } -func NewForwardClients() ([]*ForwardClient, error) { - clients := make([]*ForwardClient, 0, len(ForwardRules)) +// ForwardUDPClient 管理 UDP 监听:每个入站数据报对应一条 QUIC 流。 +type ForwardUDPClient struct { + pc *net.UDPConn + mu sync.RWMutex + quicConn *quic.Conn + remoteIP net.IP + remotePort uint16 +} + +type udpWriteBack struct { + pc *net.UDPConn + addr net.Addr +} + +func (u *udpWriteBack) Write(b []byte) (int, error) { + return u.pc.WriteTo(b, u.addr) +} + +func NewForwardRunners() ([]ForwardRunner, error) { + runners := make([]ForwardRunner, 0, len(ForwardRules)) for _, rule := range ForwardRules { - ln, err := net.Listen("tcp", rule.ListenAddr) - if err != nil { - for _, c := range clients { - c.Close() + switch rule.Proto { + case 0x01: + udpAddr, err := net.ResolveUDPAddr("udp4", rule.ListenAddr) + if err != nil { + for _, r := range runners { + r.Close() + } + return nil, err } - return nil, err + pc, err := net.ListenUDP("udp4", udpAddr) + if err != nil { + for _, r := range runners { + r.Close() + } + return nil, err + } + log.Printf("[proxy] UDP转发监听: %s -> %s:%d", rule.ListenAddr, rule.RemoteIP, rule.RemotePort) + runners = append(runners, &ForwardUDPClient{ + pc: pc, + remoteIP: rule.RemoteIP, + remotePort: rule.RemotePort, + }) + default: + ln, err := net.Listen("tcp", rule.ListenAddr) + if err != nil { + for _, r := range runners { + r.Close() + } + return nil, err + } + log.Printf("[proxy] TCP转发监听: %s -> %s:%d", rule.ListenAddr, rule.RemoteIP, rule.RemotePort) + runners = append(runners, &ForwardClient{ + listener: ln, + remoteIP: rule.RemoteIP, + remotePort: rule.RemotePort, + }) } - log.Printf("[proxy] TCP转发监听: %s -> %s:%d", rule.ListenAddr, rule.RemoteIP, rule.RemotePort) - clients = append(clients, &ForwardClient{ - listener: ln, - remoteIP: rule.RemoteIP, - remotePort: rule.RemotePort, - }) } - return clients, nil + return runners, nil } func (p *ForwardClient) SetQuicConn(conn *quic.Conn) { @@ -187,3 +251,90 @@ func (p *ForwardClient) Close() { p.listener.Close() } } + +func (p *ForwardUDPClient) SetQuicConn(conn *quic.Conn) { + p.mu.Lock() + defer p.mu.Unlock() + p.quicConn = conn +} + +func (p *ForwardUDPClient) ClearQuicConn() { + p.mu.Lock() + defer p.mu.Unlock() + p.quicConn = nil +} + +func (p *ForwardUDPClient) getQuicConn() *quic.Conn { + p.mu.RLock() + defer p.mu.RUnlock() + return p.quicConn +} + +func (p *ForwardUDPClient) Serve() { + buf := make([]byte, 65535) + for { + n, clientAddr, err := p.pc.ReadFrom(buf) + if err != nil { + log.Printf("[proxy] UDP转发监听异常: %v", err) + return + } + if n == 0 { + continue + } + payload := make([]byte, n) + copy(payload, buf[:n]) + quicConn := p.getQuicConn() + if quicConn == nil { + log.Println("[proxy] 隧道未就绪,丢弃UDP数据报") + continue + } + go p.handleDatagram(payload, clientAddr, quicConn) + } +} + +func (p *ForwardUDPClient) handleDatagram(payload []byte, clientAddr net.Addr, quicConn *quic.Conn) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + stream, err := quicConn.OpenStreamSync(ctx) + if err != nil { + log.Println("[proxy] UDP转发打开QUIC流失败:", err) + return + } + + ioBuf := go2pool.Malloc(HEAD_LEN) + ioBuf[0] = 0x01 // UDP协议标识 + copy(ioBuf[1:5], p.remoteIP.To4()) + binary.BigEndian.PutUint16(ioBuf[5:HEAD_LEN], p.remotePort) + if _, err := stream.Write(ioBuf[:HEAD_LEN]); err != nil { + go2pool.Free(ioBuf) + log.Println("[proxy] UDP转发写入头部失败:", err) + stream.CancelRead(0) + stream.Close() + return + } + go2pool.Free(ioBuf) + + if _, err := stream.Write(payload); err != nil { + log.Println("[proxy] UDP转发写入载荷失败:", err) + stream.CancelRead(0) + stream.Close() + return + } + + wb := &udpWriteBack{pc: p.pc, addr: clientAddr} + cpBuf := go2pool.Malloc(32 * 1024) + defer go2pool.Free(cpBuf) + _, err = io.CopyBuffer(wb, stream, cpBuf) + if err != nil && err != io.EOF { + log.Println("[proxy] UDP转发读流失败:", err) + } + stream.CancelRead(0) + stream.Close() +} + +func (p *ForwardUDPClient) Close() { + if p.pc != nil { + p.pc.Close() + } +}