vpn: decouple p2p from vpn

This commit is contained in:
rkonfj 2024-05-25 15:07:22 +08:00
parent 3a9c580c4a
commit 1e3908ad94
7 changed files with 317 additions and 310 deletions

View File

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"net/url"
"os"
"os/signal"
@ -21,7 +23,6 @@ import (
"github.com/rkonfj/peerguard/peermap/network"
"github.com/rkonfj/peerguard/peermap/oidc"
"github.com/rkonfj/peerguard/vpn"
"github.com/rkonfj/peerguard/vpn/link"
"github.com/spf13/cobra"
)
@ -55,28 +56,34 @@ func init() {
}
func run(cmd *cobra.Command, args []string) (err error) {
discoPortScanCount, err := cmd.Flags().GetInt("disco-port-scan-count")
cfg, err := createConfig(cmd)
if err != nil {
return
}
discoChallengesRetry, err := cmd.Flags().GetInt("disco-challenges-retry")
if err != nil {
return
}
discoChallengesInitialInterval, err := cmd.Flags().GetDuration("disco-challenges-initial-interval")
if err != nil {
return
}
discoChallengesBackoffRate, err := cmd.Flags().GetFloat64("disco-challenges-backoff-rate")
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel()
return (&P2PVPN{
Config: cfg,
RoutingTable: vpn.NewRoutingTable(),
}).Run(ctx)
}
cfg := vpn.Config{
OnRoute: onRoute,
ModifyDiscoConfig: func(cfg *disco.DiscoConfig) {
cfg.PortScanCount = discoPortScanCount
cfg.ChallengesRetry = discoChallengesRetry
cfg.ChallengesInitialInterval = discoChallengesInitialInterval
cfg.ChallengesBackoffRate = discoChallengesBackoffRate
},
func createConfig(cmd *cobra.Command) (cfg Config, err error) {
cfg.DiscoPortScanCount, err = cmd.Flags().GetInt("disco-port-scan-count")
if err != nil {
return
}
cfg.DiscoChallengesRetry, err = cmd.Flags().GetInt("disco-challenges-retry")
if err != nil {
return
}
cfg.DiscoChallengesInitialInterval, err = cmd.Flags().GetDuration("disco-challenges-initial-interval")
if err != nil {
return
}
cfg.DiscoChallengesBackoffRate, err = cmd.Flags().GetFloat64("disco-challenges-backoff-rate")
if err != nil {
return
}
cfg.IPv4, err = cmd.Flags().GetString("ipv4")
if err != nil {
@ -86,99 +93,179 @@ func run(cmd *cobra.Command, args []string) (err error) {
if err != nil {
return
}
cfg.AllowedIPs, err = cmd.Flags().GetStringSlice("allowed-ip")
if err != nil {
return
}
cfg.Peers, err = cmd.Flags().GetStringSlice("peer")
if err != nil {
return
}
cfg.MTU, err = cmd.Flags().GetInt("mtu")
if err != nil {
return
}
cfg.TunName, err = cmd.Flags().GetString("tun")
if err != nil {
return
}
cfg.AllowedIPs, err = cmd.Flags().GetStringSlice("allowed-ip")
if err != nil {
return
}
cfg.Peers, err = cmd.Flags().GetStringSlice("peer")
if err != nil {
return
}
cfg.PrivateKey, err = cmd.Flags().GetString("key")
if err != nil {
return
}
server, err := cmd.Flags().GetString("server")
cfg.SecretFile, err = cmd.Flags().GetString("secret-file")
if err != nil {
return
}
tunName, err := cmd.Flags().GetString("tun")
cfg.Server, err = cmd.Flags().GetString("server")
if err != nil {
return
}
return
}
secretFile, err := cmd.Flags().GetString("secret-file")
type Config struct {
vpn.Config
DiscoPortScanCount int
DiscoChallengesRetry int
DiscoChallengesInitialInterval time.Duration
DiscoChallengesBackoffRate float64
TunName string
AllowedIPs []string
Peers []string
PrivateKey string
SecretFile string
Server string
}
type P2PVPN struct {
Config Config
RoutingTable *vpn.SimpleRoutingTable
}
func (v *P2PVPN) Run(ctx context.Context) error {
c, err := v.listenPacketConn()
if err != nil {
return err
}
if len(secretFile) == 0 {
return vpn.New(v.RoutingTable, c, v.Config.Config).RunTun(ctx, v.Config.TunName)
}
func (v *P2PVPN) listenPacketConn() (c net.PacketConn, err error) {
disco.SetModifyDiscoConfig(func(cfg *disco.DiscoConfig) {
cfg.PortScanCount = v.Config.DiscoPortScanCount
cfg.ChallengesRetry = v.Config.DiscoChallengesRetry
cfg.ChallengesInitialInterval = v.Config.DiscoChallengesInitialInterval
cfg.ChallengesBackoffRate = v.Config.DiscoChallengesBackoffRate
})
disco.SetIgnoredLocalInterfaceNamePrefixs("pg", "wg", "veth", "docker", "nerdctl", "tailscale")
disco.AddIgnoredLocalCIDRs(v.Config.AllowedIPs...)
p2pOptions := []p2p.Option{
p2p.PeerMeta("allowedIPs", v.Config.AllowedIPs),
p2p.ListenPeerUp(v.addPeer),
}
if len(v.Config.Peers) > 0 {
p2pOptions = append(p2pOptions, p2p.PeerSilenceMode())
}
for _, peerURL := range v.Config.Peers {
pgPeer, err := url.Parse(peerURL)
if err != nil {
continue
}
if pgPeer.Scheme != "pg" {
return nil, fmt.Errorf("unsupport scheme %s", pgPeer.Scheme)
}
extra := make(map[string]any)
for k, v := range pgPeer.Query() {
extra[k] = v[0]
}
v.addPeer(peer.ID(pgPeer.Host), peer.Metadata{
Alias1: pgPeer.Query().Get("alias1"),
Alias2: pgPeer.Query().Get("alias2"),
Extra: extra,
})
}
if v.Config.IPv4 != "" {
ipv4, err := netip.ParsePrefix(v.Config.IPv4)
if err != nil {
return nil, err
}
disco.AddIgnoredLocalCIDRs(v.Config.IPv4)
p2pOptions = append(p2pOptions, p2p.PeerAlias1(ipv4.Addr().String()))
}
if v.Config.IPv6 != "" {
ipv6, err := netip.ParsePrefix(v.Config.IPv6)
if err != nil {
return nil, err
}
disco.AddIgnoredLocalCIDRs(v.Config.IPv6)
p2pOptions = append(p2pOptions, p2p.PeerAlias2(ipv6.Addr().String()))
}
if v.Config.PrivateKey != "" {
p2pOptions = append(p2pOptions, p2p.ListenPeerCurve25519(v.Config.PrivateKey))
} else {
p2pOptions = append(p2pOptions, p2p.ListenPeerSecure())
}
secretStore, err := v.loginIfNecessary()
if err != nil {
return
}
peermapURL, err := url.Parse(v.Config.Server)
if err != nil {
return
}
peermap, err := peermap.New(peermapURL, secretStore)
if err != nil {
return
}
return p2p.ListenPacket(peermap, p2pOptions...)
}
func (v *P2PVPN) addPeer(pi peer.ID, m peer.Metadata) {
v.RoutingTable.AddPeer(m.Alias1, m.Alias2, pi)
allowedIPs := m.Extra["allowedIPs"]
if allowedIPs == nil {
return
}
for _, allowIP := range allowedIPs.([]any) {
_, cidr, err := net.ParseCIDR(allowIP.(string))
if err != nil {
continue
}
if cidr.IP.To4() != nil {
v.RoutingTable.AddRoute4(cidr, m.Alias1, v.Config.TunName)
} else {
v.RoutingTable.AddRoute6(cidr, m.Alias2, v.Config.TunName)
}
}
}
func (v *P2PVPN) loginIfNecessary() (peer.SecretStore, error) {
if len(v.Config.SecretFile) == 0 {
currentUser, err := user.Current()
if err != nil {
return err
return nil, err
}
secretFile = filepath.Join(currentUser.HomeDir, ".peerguard_network_secret.json")
v.Config.SecretFile = filepath.Join(currentUser.HomeDir, ".peerguard_network_secret.json")
}
secretStore, err := loginIfNecessary(server, secretFile)
if err != nil {
return err
}
peermapURL, _ := url.Parse(server)
cfg.Peermap, err = peermap.New(peermapURL, secretStore)
if err != nil {
return err
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel()
return vpn.New(cfg).RunTun(ctx, tunName)
}
func onRoute(route vpn.Route) {
if len(route.OldDst) > 0 {
for _, cidr := range route.OldDst {
err := link.DelRoute(route.Device, cidr, route.Via)
if err != nil {
slog.Error("DelRoute error", "detail", err, "to", cidr, "via", route.Via)
} else {
slog.Info("DelRoute", "to", cidr, "via", route.Via)
}
}
}
if len(route.NewDst) > 0 {
for _, cidr := range route.NewDst {
err := link.AddRoute(route.Device, cidr, route.Via)
if err != nil {
slog.Error("AddRoute error", "detail", err, "to", cidr, "via", route.Via)
} else {
slog.Info("AddRoute", "to", cidr, "via", route.Via)
}
}
}
}
func loginIfNecessary(peermap, secretFile string) (peer.SecretStore, error) {
store := p2p.FileSecretStore(secretFile)
store := p2p.FileSecretStore(v.Config.SecretFile)
newFileStore := func() (peer.SecretStore, error) {
joined, err := requestNetworkSecret(peermap)
joined, err := v.requestNetworkSecret()
if err != nil {
return nil, fmt.Errorf("request network secret failed: %w", err)
}
return store, store.UpdateNetworkSecret(joined)
}
if _, err := os.Stat(secretFile); os.IsNotExist(err) {
if _, err := os.Stat(v.Config.SecretFile); os.IsNotExist(err) {
return newFileStore()
}
secret, err := store.NetworkSecret()
@ -191,7 +278,7 @@ func loginIfNecessary(peermap, secretFile string) (peer.SecretStore, error) {
return store, nil
}
func requestNetworkSecret(peermap string) (peer.NetworkSecret, error) {
func (v *P2PVPN) requestNetworkSecret() (peer.NetworkSecret, error) {
prompt := promptui.Select{
Label: "Select OpenID Connect Provider",
Items: []string{oidc.ProviderGoogle, oidc.ProviderGithub},
@ -206,7 +293,7 @@ func requestNetworkSecret(peermap string) (peer.NetworkSecret, error) {
if err != nil {
return peer.NetworkSecret{}, err
}
join, err := network.JoinOIDC(provider, peermap)
join, err := network.JoinOIDC(provider, v.Config.Server)
if err != nil {
slog.Error("JoinNetwork failed", "err", err)
return peer.NetworkSecret{}, err

View File

@ -3,15 +3,9 @@ package link
import (
"net"
"os/exec"
"golang.zx2c4.com/wireguard/tun"
)
func SetupLink(device tun.Device, cidr string) error {
ifName, err := device.Name()
if err != nil {
return err
}
func SetupLink(ifName, cidr string) error {
ip, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
return err
@ -27,21 +21,17 @@ func SetupLink(device tun.Device, cidr string) error {
return err
}
}
return AddRoute(device, ipnet, nil)
return AddRoute(ifName, ipnet, nil)
}
func AddRoute(device tun.Device, to *net.IPNet, _ net.IP) error {
ifName, err := device.Name()
if err != nil {
return err
}
func AddRoute(ifName string, to *net.IPNet, _ net.IP) error {
if to.IP.To4() == nil { // ipv6
return exec.Command("route", "-qn", "add", "-inet6", to.String(), "-iface", ifName).Run()
}
return exec.Command("route", "-qn", "add", "-inet", to.String(), "-iface", ifName).Run()
}
func DelRoute(_ tun.Device, to *net.IPNet, _ net.IP) error {
func DelRoute(_ string, to *net.IPNet, _ net.IP) error {
if to.IP.To4() == nil { // ipv6
return exec.Command("route", "-qn", "delete", "-inet6", to.String()).Run()
}

View File

@ -4,21 +4,19 @@ package link
import (
"net"
"golang.zx2c4.com/wireguard/tun"
)
func SetupLink(tun.Device, string) error {
func SetupLink(string, string) error {
// noop
return nil
}
func AddRoute(tun.Device, *net.IPNet, net.IP) error {
func AddRoute(string, *net.IPNet, net.IP) error {
// noop
return nil
}
func DelRoute(tun.Device, *net.IPNet, net.IP) error {
func DelRoute(string, *net.IPNet, net.IP) error {
// noop
return nil
}

View File

@ -6,14 +6,9 @@ import (
"net"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/tun"
)
func SetupLink(device tun.Device, cidr string) error {
ifName, err := device.Name()
if err != nil {
return err
}
func SetupLink(ifName, cidr string) error {
link, err := netlink.LinkByName(ifName)
if err != nil {
return err
@ -39,14 +34,14 @@ func SetupLink(device tun.Device, cidr string) error {
return nil
}
func AddRoute(device tun.Device, to *net.IPNet, via net.IP) error {
func AddRoute(_ string, to *net.IPNet, via net.IP) error {
return netlink.RouteAdd(&netlink.Route{
Dst: to,
Gw: via,
})
}
func DelRoute(device tun.Device, to *net.IPNet, via net.IP) error {
func DelRoute(_ string, to *net.IPNet, via net.IP) error {
return netlink.RouteDel(&netlink.Route{
Dst: to,
Gw: via,

View File

@ -6,19 +6,13 @@ import (
"fmt"
"net"
"os/exec"
"golang.zx2c4.com/wireguard/tun"
)
func SetupLink(device tun.Device, cidr string) error {
func SetupLink(ifName, cidr string) error {
ip, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
return err
}
ifName, err := device.Name()
if err != nil {
return err
}
if ip.To4() == nil { // ipv6
info.IPv6 = ip.String()
return exec.Command("netsh", "interface", "ipv6", "add", "address", ifName, cidr).Run()
@ -29,11 +23,7 @@ func SetupLink(device tun.Device, cidr string) error {
return exec.Command("netsh", "interface", "ipv4", "set", "address", ifName, "static", ip.String(), addrMask).Run()
}
func AddRoute(device tun.Device, to *net.IPNet, via net.IP) error {
ifName, err := device.Name()
if err != nil {
return err
}
func AddRoute(ifName string, to *net.IPNet, via net.IP) error {
if via.To4() == nil { // ipv6
return exec.Command("netsh", "interface", "ipv6", "add", "route", to.String(), ifName, via.String()).Run()
}
@ -42,11 +32,7 @@ func AddRoute(device tun.Device, to *net.IPNet, via net.IP) error {
return exec.Command("route", "add", to.IP.String(), "mask", addrMask, via.String()).Run()
}
func DelRoute(device tun.Device, to *net.IPNet, via net.IP) error {
ifName, err := device.Name()
if err != nil {
return err
}
func DelRoute(ifName string, to *net.IPNet, via net.IP) error {
if via.To4() == nil { // ipv6
return exec.Command("netsh", "interface", "ipv6", "delete", "route", to.String(), ifName, via.String()).Run()
}

109
vpn/router.go Normal file
View File

@ -0,0 +1,109 @@
package vpn
import (
"log/slog"
"net"
"sync"
"github.com/rkonfj/peerguard/lru"
"github.com/rkonfj/peerguard/vpn/link"
)
type RoutingTable interface {
GetPeer(ip string) (net.Addr, bool)
}
var _ RoutingTable = (*SimpleRoutingTable)(nil)
type SimpleRoutingTable struct {
ipv6 *lru.Cache[string, []*net.IPNet]
ipv4 *lru.Cache[string, []*net.IPNet]
peers *lru.Cache[string, net.Addr]
peersMutex sync.RWMutex
}
func NewRoutingTable() *SimpleRoutingTable {
return &SimpleRoutingTable{
ipv6: lru.New[string, []*net.IPNet](256),
ipv4: lru.New[string, []*net.IPNet](256),
peers: lru.New[string, net.Addr](1024),
}
}
func (r *SimpleRoutingTable) GetPeer(ip string) (net.Addr, bool) {
r.peersMutex.RLock()
defer r.peersMutex.RUnlock()
peerID, ok := r.peers.Get(ip)
if ok {
return peerID, true
}
dstIP := net.ParseIP(ip)
if dstIP.To4() != nil {
k, _, _ := r.ipv4.Find(func(k string, v []*net.IPNet) bool {
for _, cidr := range v {
if cidr.Contains(dstIP) {
return true
}
}
return false
})
return r.peers.Get(k)
}
k, _, _ := r.ipv6.Find(func(k string, v []*net.IPNet) bool {
for _, cidr := range v {
if cidr.Contains(dstIP) {
return true
}
}
return false
})
return r.peers.Get(k)
}
func (r *SimpleRoutingTable) AddPeer(ipv4, ipv6 string, peer net.Addr) {
r.peersMutex.Lock()
defer r.peersMutex.Unlock()
r.peers.Put(ipv4, peer)
r.peers.Put(ipv6, peer)
}
func (r *SimpleRoutingTable) AddRoute4(cidr *net.IPNet, viaIP, ifName string) {
r.peersMutex.Lock()
defer r.peersMutex.Unlock()
cidrs, _ := r.ipv4.Get(viaIP)
cidrs = append(cidrs, cidr)
r.ipv4.Put(viaIP, cidrs)
r.updateRoute(ifName, net.ParseIP(viaIP), cidrs[:len(cidrs)-1], cidrs)
}
func (r *SimpleRoutingTable) AddRoute6(cidr *net.IPNet, viaIP, ifName string) {
r.peersMutex.Lock()
defer r.peersMutex.Unlock()
cidrs, _ := r.ipv6.Get(viaIP)
cidrs = append(cidrs, cidr)
r.ipv6.Put(viaIP, cidrs)
r.updateRoute(ifName, net.ParseIP(viaIP), cidrs[:len(cidrs)-1], cidrs)
}
func (r *SimpleRoutingTable) DelRoute(cidr *net.IPNet, viaIP, ifName string) {
link.DelRoute(ifName, cidr, net.ParseIP(viaIP))
}
func (r *SimpleRoutingTable) updateRoute(ifName string, viaIP net.IP, oldTo []*net.IPNet, cidrs []*net.IPNet) {
for _, cidr := range oldTo {
err := link.DelRoute(ifName, cidr, viaIP)
if err != nil {
slog.Error("DelRoute error", "detail", err, "to", cidr, "via", viaIP)
} else {
slog.Info("DelRoute", "to", cidr, "via", viaIP)
}
}
for _, cidr := range cidrs {
err := link.AddRoute(ifName, cidr, viaIP)
if err != nil {
slog.Error("AddRoute error", "detail", err, "to", cidr, "via", viaIP)
} else {
slog.Info("AddRoute", "to", cidr, "via", viaIP)
}
}
}

View File

@ -7,17 +7,11 @@ import (
"fmt"
"log/slog"
"net"
"net/netip"
"net/url"
"os"
"strings"
"sync"
"github.com/rkonfj/peerguard/disco"
"github.com/rkonfj/peerguard/lru"
"github.com/rkonfj/peerguard/p2p"
"github.com/rkonfj/peerguard/peer"
"github.com/rkonfj/peerguard/peer/peermap"
"github.com/rkonfj/peerguard/vpn/link"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@ -28,49 +22,31 @@ const (
IPPacketOffset = 16
)
type Route struct {
Device tun.Device
OldDst []*net.IPNet
NewDst []*net.IPNet
Via net.IP
}
type Config struct {
MTU int
IPv4 string
IPv6 string
AllowedIPs []string
Peers []string
Peermap *peermap.Peermap
PrivateKey string
OnRoute func(route Route)
ModifyDiscoConfig func(cfg *disco.DiscoConfig)
InboundHandlers []InboundHandler
OutboundHandlers []OutboundHandler
MTU int
IPv4 string
IPv6 string
InboundHandlers []InboundHandler
OutboundHandlers []OutboundHandler
}
type VPN struct {
cfg Config
outbound chan []byte
inbound chan []byte
newBuf func() []byte
ipv6Routes *lru.Cache[string, []*net.IPNet]
ipv4Routes *lru.Cache[string, []*net.IPNet]
peers *lru.Cache[string, peer.ID]
peersMutex sync.RWMutex
routingTable RoutingTable
packetConn net.PacketConn
cfg Config
outbound chan []byte
inbound chan []byte
newBuf func() []byte
}
func New(cfg Config) *VPN {
disco.SetModifyDiscoConfig(cfg.ModifyDiscoConfig)
func New(routingTable RoutingTable, packetConn net.PacketConn, cfg Config) *VPN {
return &VPN{
cfg: cfg,
outbound: make(chan []byte, 512),
inbound: make(chan []byte, 512),
newBuf: func() []byte { return make([]byte, cfg.MTU+IPPacketOffset+40) },
ipv6Routes: lru.New[string, []*net.IPNet](256),
ipv4Routes: lru.New[string, []*net.IPNet](256),
peers: lru.New[string, peer.ID](1024),
routingTable: routingTable,
packetConn: packetConn,
cfg: cfg,
outbound: make(chan []byte, 512),
inbound: make(chan []byte, 512),
newBuf: func() []byte { return make([]byte, cfg.MTU+IPPacketOffset+40) },
}
}
@ -80,165 +56,31 @@ func (vpn *VPN) RunTun(ctx context.Context, tunName string) error {
return fmt.Errorf("create tun device (%s) failed: %w", tunName, err)
}
if vpn.cfg.IPv4 != "" {
link.SetupLink(device, vpn.cfg.IPv4)
link.SetupLink(tunName, vpn.cfg.IPv4)
}
if vpn.cfg.IPv6 != "" {
link.SetupLink(device, vpn.cfg.IPv6)
link.SetupLink(tunName, vpn.cfg.IPv6)
}
return vpn.run(ctx, device)
}
func (vpn *VPN) run(ctx context.Context, device tun.Device) error {
disco.SetIgnoredLocalInterfaceNamePrefixs("pg", "wg", "veth", "docker", "nerdctl", "tailscale")
disco.AddIgnoredLocalCIDRs(vpn.cfg.AllowedIPs...)
p2pOptions := []p2p.Option{
p2p.PeerMeta("allowedIPs", vpn.cfg.AllowedIPs),
p2p.ListenPeerUp(func(pi peer.ID, m peer.Metadata) { vpn.setPeer(device, pi, m) }),
}
if len(vpn.cfg.Peers) > 0 {
p2pOptions = append(p2pOptions, p2p.PeerSilenceMode())
}
for _, peerURL := range vpn.cfg.Peers {
pgPeer, err := url.Parse(peerURL)
if err != nil {
continue
}
if pgPeer.Scheme != "pg" {
return fmt.Errorf("unsupport scheme %s", pgPeer.Scheme)
}
extra := make(map[string]any)
for k, v := range pgPeer.Query() {
extra[k] = v[0]
}
vpn.setPeer(device, peer.ID(pgPeer.Host), peer.Metadata{
Alias1: pgPeer.Query().Get("alias1"),
Alias2: pgPeer.Query().Get("alias2"),
Extra: extra,
})
}
if vpn.cfg.IPv4 != "" {
ipv4, err := netip.ParsePrefix(vpn.cfg.IPv4)
if err != nil {
return err
}
disco.AddIgnoredLocalCIDRs(vpn.cfg.IPv4)
p2pOptions = append(p2pOptions, p2p.PeerAlias1(ipv4.Addr().String()))
}
if vpn.cfg.IPv6 != "" {
ipv6, err := netip.ParsePrefix(vpn.cfg.IPv6)
if err != nil {
return err
}
disco.AddIgnoredLocalCIDRs(vpn.cfg.IPv6)
p2pOptions = append(p2pOptions, p2p.PeerAlias2(ipv6.Addr().String()))
}
if vpn.cfg.PrivateKey != "" {
p2pOptions = append(p2pOptions, p2p.ListenPeerCurve25519(vpn.cfg.PrivateKey))
} else {
p2pOptions = append(p2pOptions, p2p.ListenPeerSecure())
}
packetConn, err := p2p.ListenPacket(vpn.cfg.Peermap, p2pOptions...)
if err != nil {
return err
}
var wg sync.WaitGroup
wg.Add(4)
go vpn.runTunReadEventLoop(&wg, device)
go vpn.runTunWriteEventLoop(&wg, device)
go vpn.runPacketConnReadEventLoop(&wg, packetConn)
go vpn.runPacketConnWriteEventLoop(&wg, packetConn)
go vpn.runPacketConnReadEventLoop(&wg, vpn.packetConn)
go vpn.runPacketConnWriteEventLoop(&wg, vpn.packetConn)
<-ctx.Done()
close(vpn.inbound)
close(vpn.outbound)
device.Close()
packetConn.Close()
vpn.packetConn.Close()
wg.Wait()
return nil
}
func (vpn *VPN) setPeer(device tun.Device, peer peer.ID, metadata peer.Metadata) {
vpn.peersMutex.Lock()
defer vpn.peersMutex.Unlock()
vpn.peers.Put(metadata.Alias1, peer)
vpn.peers.Put(metadata.Alias2, peer)
var allowedIPv4s, allowedIPv6s []*net.IPNet
if allowedIPs := metadata.Extra["allowedIPs"]; allowedIPs != nil {
for _, allowIP := range allowedIPs.([]any) {
_, cidr, err := net.ParseCIDR(allowIP.(string))
if err != nil {
continue
}
if cidr.IP.To4() != nil {
allowedIPv4s = append(allowedIPv4s, cidr)
} else {
allowedIPv6s = append(allowedIPv6s, cidr)
}
}
}
if len(allowedIPv4s) > 0 {
oldTo, _ := vpn.ipv4Routes.Get(metadata.Alias1)
vpn.ipv4Routes.Put(metadata.Alias1, allowedIPv4s)
if onRoute := vpn.cfg.OnRoute; onRoute != nil {
onRoute(Route{
Device: device,
OldDst: oldTo,
NewDst: allowedIPv4s,
Via: net.ParseIP(metadata.Alias1),
})
}
}
if len(allowedIPv6s) > 0 {
oldTo, _ := vpn.ipv6Routes.Get(metadata.Alias2)
vpn.ipv6Routes.Put(metadata.Alias2, allowedIPv6s)
if onRoute := vpn.cfg.OnRoute; onRoute != nil {
onRoute(Route{
Device: device,
OldDst: oldTo,
NewDst: allowedIPv6s,
Via: net.ParseIP(metadata.Alias2),
})
}
}
}
func (vpn *VPN) getPeer(ip string) (peer.ID, bool) {
vpn.peersMutex.RLock()
defer vpn.peersMutex.RUnlock()
peerID, ok := vpn.peers.Get(ip)
if ok {
return peerID, true
}
dstIP := net.ParseIP(ip)
if dstIP.To4() != nil {
k, _, _ := vpn.ipv4Routes.Find(func(k string, v []*net.IPNet) bool {
for _, cidr := range v {
if cidr.Contains(dstIP) {
return true
}
}
return false
})
return vpn.peers.Get(k)
}
k, _, _ := vpn.ipv6Routes.Find(func(k string, v []*net.IPNet) bool {
for _, cidr := range v {
if cidr.Contains(dstIP) {
return true
}
}
return false
})
return vpn.peers.Get(k)
}
func (vpn *VPN) runTunReadEventLoop(wg *sync.WaitGroup, device tun.Device) {
defer wg.Done()
@ -315,7 +157,7 @@ func (vpn *VPN) runPacketConnWriteEventLoop(wg *sync.WaitGroup, packetConn net.P
slog.Log(context.Background(), -10, "DropMulticastIP", "dst", dstIP)
return
}
if peer, ok := vpn.getPeer(dstIP.String()); ok {
if peer, ok := vpn.routingTable.GetPeer(dstIP.String()); ok {
_, err := packetConn.WriteTo(packet[IPPacketOffset:], peer)
if err != nil {
slog.Error("WriteTo peer failed", "peer", peer, "detail", err)