mirror of
https://github.com/rkonfj/peerguard.git
synced 2024-08-11 11:00:25 +08:00
vpn: decouple p2p from vpn
This commit is contained in:
parent
3a9c580c4a
commit
1e3908ad94
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
@ -21,7 +23,6 @@ import (
|
||||
"github.com/rkonfj/peerguard/peermap/network"
|
||||
"github.com/rkonfj/peerguard/peermap/oidc"
|
||||
"github.com/rkonfj/peerguard/vpn"
|
||||
"github.com/rkonfj/peerguard/vpn/link"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@ -55,28 +56,34 @@ func init() {
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, args []string) (err error) {
|
||||
discoPortScanCount, err := cmd.Flags().GetInt("disco-port-scan-count")
|
||||
cfg, err := createConfig(cmd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
discoChallengesRetry, err := cmd.Flags().GetInt("disco-challenges-retry")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
discoChallengesInitialInterval, err := cmd.Flags().GetDuration("disco-challenges-initial-interval")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
discoChallengesBackoffRate, err := cmd.Flags().GetFloat64("disco-challenges-backoff-rate")
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
return (&P2PVPN{
|
||||
Config: cfg,
|
||||
RoutingTable: vpn.NewRoutingTable(),
|
||||
}).Run(ctx)
|
||||
}
|
||||
|
||||
cfg := vpn.Config{
|
||||
OnRoute: onRoute,
|
||||
ModifyDiscoConfig: func(cfg *disco.DiscoConfig) {
|
||||
cfg.PortScanCount = discoPortScanCount
|
||||
cfg.ChallengesRetry = discoChallengesRetry
|
||||
cfg.ChallengesInitialInterval = discoChallengesInitialInterval
|
||||
cfg.ChallengesBackoffRate = discoChallengesBackoffRate
|
||||
},
|
||||
func createConfig(cmd *cobra.Command) (cfg Config, err error) {
|
||||
cfg.DiscoPortScanCount, err = cmd.Flags().GetInt("disco-port-scan-count")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cfg.DiscoChallengesRetry, err = cmd.Flags().GetInt("disco-challenges-retry")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cfg.DiscoChallengesInitialInterval, err = cmd.Flags().GetDuration("disco-challenges-initial-interval")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cfg.DiscoChallengesBackoffRate, err = cmd.Flags().GetFloat64("disco-challenges-backoff-rate")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cfg.IPv4, err = cmd.Flags().GetString("ipv4")
|
||||
if err != nil {
|
||||
@ -86,99 +93,179 @@ func run(cmd *cobra.Command, args []string) (err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg.AllowedIPs, err = cmd.Flags().GetStringSlice("allowed-ip")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg.Peers, err = cmd.Flags().GetStringSlice("peer")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg.MTU, err = cmd.Flags().GetInt("mtu")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg.TunName, err = cmd.Flags().GetString("tun")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cfg.AllowedIPs, err = cmd.Flags().GetStringSlice("allowed-ip")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cfg.Peers, err = cmd.Flags().GetStringSlice("peer")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cfg.PrivateKey, err = cmd.Flags().GetString("key")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
server, err := cmd.Flags().GetString("server")
|
||||
cfg.SecretFile, err = cmd.Flags().GetString("secret-file")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tunName, err := cmd.Flags().GetString("tun")
|
||||
cfg.Server, err = cmd.Flags().GetString("server")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
secretFile, err := cmd.Flags().GetString("secret-file")
|
||||
type Config struct {
|
||||
vpn.Config
|
||||
DiscoPortScanCount int
|
||||
DiscoChallengesRetry int
|
||||
DiscoChallengesInitialInterval time.Duration
|
||||
DiscoChallengesBackoffRate float64
|
||||
TunName string
|
||||
AllowedIPs []string
|
||||
Peers []string
|
||||
PrivateKey string
|
||||
SecretFile string
|
||||
Server string
|
||||
}
|
||||
|
||||
type P2PVPN struct {
|
||||
Config Config
|
||||
RoutingTable *vpn.SimpleRoutingTable
|
||||
}
|
||||
|
||||
func (v *P2PVPN) Run(ctx context.Context) error {
|
||||
c, err := v.listenPacketConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(secretFile) == 0 {
|
||||
return vpn.New(v.RoutingTable, c, v.Config.Config).RunTun(ctx, v.Config.TunName)
|
||||
}
|
||||
|
||||
func (v *P2PVPN) listenPacketConn() (c net.PacketConn, err error) {
|
||||
disco.SetModifyDiscoConfig(func(cfg *disco.DiscoConfig) {
|
||||
cfg.PortScanCount = v.Config.DiscoPortScanCount
|
||||
cfg.ChallengesRetry = v.Config.DiscoChallengesRetry
|
||||
cfg.ChallengesInitialInterval = v.Config.DiscoChallengesInitialInterval
|
||||
cfg.ChallengesBackoffRate = v.Config.DiscoChallengesBackoffRate
|
||||
})
|
||||
disco.SetIgnoredLocalInterfaceNamePrefixs("pg", "wg", "veth", "docker", "nerdctl", "tailscale")
|
||||
disco.AddIgnoredLocalCIDRs(v.Config.AllowedIPs...)
|
||||
p2pOptions := []p2p.Option{
|
||||
p2p.PeerMeta("allowedIPs", v.Config.AllowedIPs),
|
||||
p2p.ListenPeerUp(v.addPeer),
|
||||
}
|
||||
|
||||
if len(v.Config.Peers) > 0 {
|
||||
p2pOptions = append(p2pOptions, p2p.PeerSilenceMode())
|
||||
}
|
||||
|
||||
for _, peerURL := range v.Config.Peers {
|
||||
pgPeer, err := url.Parse(peerURL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if pgPeer.Scheme != "pg" {
|
||||
return nil, fmt.Errorf("unsupport scheme %s", pgPeer.Scheme)
|
||||
}
|
||||
extra := make(map[string]any)
|
||||
for k, v := range pgPeer.Query() {
|
||||
extra[k] = v[0]
|
||||
}
|
||||
v.addPeer(peer.ID(pgPeer.Host), peer.Metadata{
|
||||
Alias1: pgPeer.Query().Get("alias1"),
|
||||
Alias2: pgPeer.Query().Get("alias2"),
|
||||
Extra: extra,
|
||||
})
|
||||
}
|
||||
|
||||
if v.Config.IPv4 != "" {
|
||||
ipv4, err := netip.ParsePrefix(v.Config.IPv4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
disco.AddIgnoredLocalCIDRs(v.Config.IPv4)
|
||||
p2pOptions = append(p2pOptions, p2p.PeerAlias1(ipv4.Addr().String()))
|
||||
}
|
||||
|
||||
if v.Config.IPv6 != "" {
|
||||
ipv6, err := netip.ParsePrefix(v.Config.IPv6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
disco.AddIgnoredLocalCIDRs(v.Config.IPv6)
|
||||
p2pOptions = append(p2pOptions, p2p.PeerAlias2(ipv6.Addr().String()))
|
||||
}
|
||||
|
||||
if v.Config.PrivateKey != "" {
|
||||
p2pOptions = append(p2pOptions, p2p.ListenPeerCurve25519(v.Config.PrivateKey))
|
||||
} else {
|
||||
p2pOptions = append(p2pOptions, p2p.ListenPeerSecure())
|
||||
}
|
||||
|
||||
secretStore, err := v.loginIfNecessary()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
peermapURL, err := url.Parse(v.Config.Server)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
peermap, err := peermap.New(peermapURL, secretStore)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return p2p.ListenPacket(peermap, p2pOptions...)
|
||||
}
|
||||
|
||||
func (v *P2PVPN) addPeer(pi peer.ID, m peer.Metadata) {
|
||||
v.RoutingTable.AddPeer(m.Alias1, m.Alias2, pi)
|
||||
allowedIPs := m.Extra["allowedIPs"]
|
||||
if allowedIPs == nil {
|
||||
return
|
||||
}
|
||||
for _, allowIP := range allowedIPs.([]any) {
|
||||
_, cidr, err := net.ParseCIDR(allowIP.(string))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if cidr.IP.To4() != nil {
|
||||
v.RoutingTable.AddRoute4(cidr, m.Alias1, v.Config.TunName)
|
||||
} else {
|
||||
v.RoutingTable.AddRoute6(cidr, m.Alias2, v.Config.TunName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (v *P2PVPN) loginIfNecessary() (peer.SecretStore, error) {
|
||||
if len(v.Config.SecretFile) == 0 {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
secretFile = filepath.Join(currentUser.HomeDir, ".peerguard_network_secret.json")
|
||||
v.Config.SecretFile = filepath.Join(currentUser.HomeDir, ".peerguard_network_secret.json")
|
||||
}
|
||||
|
||||
secretStore, err := loginIfNecessary(server, secretFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peermapURL, _ := url.Parse(server)
|
||||
cfg.Peermap, err = peermap.New(peermapURL, secretStore)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
return vpn.New(cfg).RunTun(ctx, tunName)
|
||||
}
|
||||
|
||||
func onRoute(route vpn.Route) {
|
||||
if len(route.OldDst) > 0 {
|
||||
for _, cidr := range route.OldDst {
|
||||
err := link.DelRoute(route.Device, cidr, route.Via)
|
||||
if err != nil {
|
||||
slog.Error("DelRoute error", "detail", err, "to", cidr, "via", route.Via)
|
||||
} else {
|
||||
slog.Info("DelRoute", "to", cidr, "via", route.Via)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(route.NewDst) > 0 {
|
||||
for _, cidr := range route.NewDst {
|
||||
err := link.AddRoute(route.Device, cidr, route.Via)
|
||||
if err != nil {
|
||||
slog.Error("AddRoute error", "detail", err, "to", cidr, "via", route.Via)
|
||||
} else {
|
||||
slog.Info("AddRoute", "to", cidr, "via", route.Via)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func loginIfNecessary(peermap, secretFile string) (peer.SecretStore, error) {
|
||||
store := p2p.FileSecretStore(secretFile)
|
||||
store := p2p.FileSecretStore(v.Config.SecretFile)
|
||||
newFileStore := func() (peer.SecretStore, error) {
|
||||
joined, err := requestNetworkSecret(peermap)
|
||||
joined, err := v.requestNetworkSecret()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request network secret failed: %w", err)
|
||||
}
|
||||
return store, store.UpdateNetworkSecret(joined)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(secretFile); os.IsNotExist(err) {
|
||||
if _, err := os.Stat(v.Config.SecretFile); os.IsNotExist(err) {
|
||||
return newFileStore()
|
||||
}
|
||||
secret, err := store.NetworkSecret()
|
||||
@ -191,7 +278,7 @@ func loginIfNecessary(peermap, secretFile string) (peer.SecretStore, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func requestNetworkSecret(peermap string) (peer.NetworkSecret, error) {
|
||||
func (v *P2PVPN) requestNetworkSecret() (peer.NetworkSecret, error) {
|
||||
prompt := promptui.Select{
|
||||
Label: "Select OpenID Connect Provider",
|
||||
Items: []string{oidc.ProviderGoogle, oidc.ProviderGithub},
|
||||
@ -206,7 +293,7 @@ func requestNetworkSecret(peermap string) (peer.NetworkSecret, error) {
|
||||
if err != nil {
|
||||
return peer.NetworkSecret{}, err
|
||||
}
|
||||
join, err := network.JoinOIDC(provider, peermap)
|
||||
join, err := network.JoinOIDC(provider, v.Config.Server)
|
||||
if err != nil {
|
||||
slog.Error("JoinNetwork failed", "err", err)
|
||||
return peer.NetworkSecret{}, err
|
||||
|
@ -3,15 +3,9 @@ package link
|
||||
import (
|
||||
"net"
|
||||
"os/exec"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func SetupLink(device tun.Device, cidr string) error {
|
||||
ifName, err := device.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func SetupLink(ifName, cidr string) error {
|
||||
ip, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -27,21 +21,17 @@ func SetupLink(device tun.Device, cidr string) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return AddRoute(device, ipnet, nil)
|
||||
return AddRoute(ifName, ipnet, nil)
|
||||
}
|
||||
|
||||
func AddRoute(device tun.Device, to *net.IPNet, _ net.IP) error {
|
||||
ifName, err := device.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func AddRoute(ifName string, to *net.IPNet, _ net.IP) error {
|
||||
if to.IP.To4() == nil { // ipv6
|
||||
return exec.Command("route", "-qn", "add", "-inet6", to.String(), "-iface", ifName).Run()
|
||||
}
|
||||
return exec.Command("route", "-qn", "add", "-inet", to.String(), "-iface", ifName).Run()
|
||||
}
|
||||
|
||||
func DelRoute(_ tun.Device, to *net.IPNet, _ net.IP) error {
|
||||
func DelRoute(_ string, to *net.IPNet, _ net.IP) error {
|
||||
if to.IP.To4() == nil { // ipv6
|
||||
return exec.Command("route", "-qn", "delete", "-inet6", to.String()).Run()
|
||||
}
|
||||
|
@ -4,21 +4,19 @@ package link
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func SetupLink(tun.Device, string) error {
|
||||
func SetupLink(string, string) error {
|
||||
// noop
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddRoute(tun.Device, *net.IPNet, net.IP) error {
|
||||
func AddRoute(string, *net.IPNet, net.IP) error {
|
||||
// noop
|
||||
return nil
|
||||
}
|
||||
|
||||
func DelRoute(tun.Device, *net.IPNet, net.IP) error {
|
||||
func DelRoute(string, *net.IPNet, net.IP) error {
|
||||
// noop
|
||||
return nil
|
||||
}
|
||||
|
@ -6,14 +6,9 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func SetupLink(device tun.Device, cidr string) error {
|
||||
ifName, err := device.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func SetupLink(ifName, cidr string) error {
|
||||
link, err := netlink.LinkByName(ifName)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -39,14 +34,14 @@ func SetupLink(device tun.Device, cidr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddRoute(device tun.Device, to *net.IPNet, via net.IP) error {
|
||||
func AddRoute(_ string, to *net.IPNet, via net.IP) error {
|
||||
return netlink.RouteAdd(&netlink.Route{
|
||||
Dst: to,
|
||||
Gw: via,
|
||||
})
|
||||
}
|
||||
|
||||
func DelRoute(device tun.Device, to *net.IPNet, via net.IP) error {
|
||||
func DelRoute(_ string, to *net.IPNet, via net.IP) error {
|
||||
return netlink.RouteDel(&netlink.Route{
|
||||
Dst: to,
|
||||
Gw: via,
|
||||
|
@ -6,19 +6,13 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func SetupLink(device tun.Device, cidr string) error {
|
||||
func SetupLink(ifName, cidr string) error {
|
||||
ip, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ifName, err := device.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ip.To4() == nil { // ipv6
|
||||
info.IPv6 = ip.String()
|
||||
return exec.Command("netsh", "interface", "ipv6", "add", "address", ifName, cidr).Run()
|
||||
@ -29,11 +23,7 @@ func SetupLink(device tun.Device, cidr string) error {
|
||||
return exec.Command("netsh", "interface", "ipv4", "set", "address", ifName, "static", ip.String(), addrMask).Run()
|
||||
}
|
||||
|
||||
func AddRoute(device tun.Device, to *net.IPNet, via net.IP) error {
|
||||
ifName, err := device.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func AddRoute(ifName string, to *net.IPNet, via net.IP) error {
|
||||
if via.To4() == nil { // ipv6
|
||||
return exec.Command("netsh", "interface", "ipv6", "add", "route", to.String(), ifName, via.String()).Run()
|
||||
}
|
||||
@ -42,11 +32,7 @@ func AddRoute(device tun.Device, to *net.IPNet, via net.IP) error {
|
||||
return exec.Command("route", "add", to.IP.String(), "mask", addrMask, via.String()).Run()
|
||||
}
|
||||
|
||||
func DelRoute(device tun.Device, to *net.IPNet, via net.IP) error {
|
||||
ifName, err := device.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func DelRoute(ifName string, to *net.IPNet, via net.IP) error {
|
||||
if via.To4() == nil { // ipv6
|
||||
return exec.Command("netsh", "interface", "ipv6", "delete", "route", to.String(), ifName, via.String()).Run()
|
||||
}
|
||||
|
109
vpn/router.go
Normal file
109
vpn/router.go
Normal file
@ -0,0 +1,109 @@
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rkonfj/peerguard/lru"
|
||||
"github.com/rkonfj/peerguard/vpn/link"
|
||||
)
|
||||
|
||||
type RoutingTable interface {
|
||||
GetPeer(ip string) (net.Addr, bool)
|
||||
}
|
||||
|
||||
var _ RoutingTable = (*SimpleRoutingTable)(nil)
|
||||
|
||||
type SimpleRoutingTable struct {
|
||||
ipv6 *lru.Cache[string, []*net.IPNet]
|
||||
ipv4 *lru.Cache[string, []*net.IPNet]
|
||||
peers *lru.Cache[string, net.Addr]
|
||||
peersMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRoutingTable() *SimpleRoutingTable {
|
||||
return &SimpleRoutingTable{
|
||||
ipv6: lru.New[string, []*net.IPNet](256),
|
||||
ipv4: lru.New[string, []*net.IPNet](256),
|
||||
peers: lru.New[string, net.Addr](1024),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SimpleRoutingTable) GetPeer(ip string) (net.Addr, bool) {
|
||||
r.peersMutex.RLock()
|
||||
defer r.peersMutex.RUnlock()
|
||||
peerID, ok := r.peers.Get(ip)
|
||||
if ok {
|
||||
return peerID, true
|
||||
}
|
||||
dstIP := net.ParseIP(ip)
|
||||
if dstIP.To4() != nil {
|
||||
k, _, _ := r.ipv4.Find(func(k string, v []*net.IPNet) bool {
|
||||
for _, cidr := range v {
|
||||
if cidr.Contains(dstIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return r.peers.Get(k)
|
||||
}
|
||||
k, _, _ := r.ipv6.Find(func(k string, v []*net.IPNet) bool {
|
||||
for _, cidr := range v {
|
||||
if cidr.Contains(dstIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return r.peers.Get(k)
|
||||
}
|
||||
|
||||
func (r *SimpleRoutingTable) AddPeer(ipv4, ipv6 string, peer net.Addr) {
|
||||
r.peersMutex.Lock()
|
||||
defer r.peersMutex.Unlock()
|
||||
r.peers.Put(ipv4, peer)
|
||||
r.peers.Put(ipv6, peer)
|
||||
}
|
||||
|
||||
func (r *SimpleRoutingTable) AddRoute4(cidr *net.IPNet, viaIP, ifName string) {
|
||||
r.peersMutex.Lock()
|
||||
defer r.peersMutex.Unlock()
|
||||
cidrs, _ := r.ipv4.Get(viaIP)
|
||||
cidrs = append(cidrs, cidr)
|
||||
r.ipv4.Put(viaIP, cidrs)
|
||||
r.updateRoute(ifName, net.ParseIP(viaIP), cidrs[:len(cidrs)-1], cidrs)
|
||||
}
|
||||
|
||||
func (r *SimpleRoutingTable) AddRoute6(cidr *net.IPNet, viaIP, ifName string) {
|
||||
r.peersMutex.Lock()
|
||||
defer r.peersMutex.Unlock()
|
||||
cidrs, _ := r.ipv6.Get(viaIP)
|
||||
cidrs = append(cidrs, cidr)
|
||||
r.ipv6.Put(viaIP, cidrs)
|
||||
r.updateRoute(ifName, net.ParseIP(viaIP), cidrs[:len(cidrs)-1], cidrs)
|
||||
}
|
||||
|
||||
func (r *SimpleRoutingTable) DelRoute(cidr *net.IPNet, viaIP, ifName string) {
|
||||
link.DelRoute(ifName, cidr, net.ParseIP(viaIP))
|
||||
}
|
||||
|
||||
func (r *SimpleRoutingTable) updateRoute(ifName string, viaIP net.IP, oldTo []*net.IPNet, cidrs []*net.IPNet) {
|
||||
for _, cidr := range oldTo {
|
||||
err := link.DelRoute(ifName, cidr, viaIP)
|
||||
if err != nil {
|
||||
slog.Error("DelRoute error", "detail", err, "to", cidr, "via", viaIP)
|
||||
} else {
|
||||
slog.Info("DelRoute", "to", cidr, "via", viaIP)
|
||||
}
|
||||
}
|
||||
for _, cidr := range cidrs {
|
||||
err := link.AddRoute(ifName, cidr, viaIP)
|
||||
if err != nil {
|
||||
slog.Error("AddRoute error", "detail", err, "to", cidr, "via", viaIP)
|
||||
} else {
|
||||
slog.Info("AddRoute", "to", cidr, "via", viaIP)
|
||||
}
|
||||
}
|
||||
}
|
206
vpn/vpn.go
206
vpn/vpn.go
@ -7,17 +7,11 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/rkonfj/peerguard/disco"
|
||||
"github.com/rkonfj/peerguard/lru"
|
||||
"github.com/rkonfj/peerguard/p2p"
|
||||
"github.com/rkonfj/peerguard/peer"
|
||||
"github.com/rkonfj/peerguard/peer/peermap"
|
||||
"github.com/rkonfj/peerguard/vpn/link"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
@ -28,49 +22,31 @@ const (
|
||||
IPPacketOffset = 16
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
Device tun.Device
|
||||
OldDst []*net.IPNet
|
||||
NewDst []*net.IPNet
|
||||
Via net.IP
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
MTU int
|
||||
IPv4 string
|
||||
IPv6 string
|
||||
AllowedIPs []string
|
||||
Peers []string
|
||||
Peermap *peermap.Peermap
|
||||
PrivateKey string
|
||||
OnRoute func(route Route)
|
||||
ModifyDiscoConfig func(cfg *disco.DiscoConfig)
|
||||
InboundHandlers []InboundHandler
|
||||
OutboundHandlers []OutboundHandler
|
||||
MTU int
|
||||
IPv4 string
|
||||
IPv6 string
|
||||
InboundHandlers []InboundHandler
|
||||
OutboundHandlers []OutboundHandler
|
||||
}
|
||||
|
||||
type VPN struct {
|
||||
cfg Config
|
||||
outbound chan []byte
|
||||
inbound chan []byte
|
||||
newBuf func() []byte
|
||||
|
||||
ipv6Routes *lru.Cache[string, []*net.IPNet]
|
||||
ipv4Routes *lru.Cache[string, []*net.IPNet]
|
||||
peers *lru.Cache[string, peer.ID]
|
||||
peersMutex sync.RWMutex
|
||||
routingTable RoutingTable
|
||||
packetConn net.PacketConn
|
||||
cfg Config
|
||||
outbound chan []byte
|
||||
inbound chan []byte
|
||||
newBuf func() []byte
|
||||
}
|
||||
|
||||
func New(cfg Config) *VPN {
|
||||
disco.SetModifyDiscoConfig(cfg.ModifyDiscoConfig)
|
||||
func New(routingTable RoutingTable, packetConn net.PacketConn, cfg Config) *VPN {
|
||||
return &VPN{
|
||||
cfg: cfg,
|
||||
outbound: make(chan []byte, 512),
|
||||
inbound: make(chan []byte, 512),
|
||||
newBuf: func() []byte { return make([]byte, cfg.MTU+IPPacketOffset+40) },
|
||||
ipv6Routes: lru.New[string, []*net.IPNet](256),
|
||||
ipv4Routes: lru.New[string, []*net.IPNet](256),
|
||||
peers: lru.New[string, peer.ID](1024),
|
||||
routingTable: routingTable,
|
||||
packetConn: packetConn,
|
||||
cfg: cfg,
|
||||
outbound: make(chan []byte, 512),
|
||||
inbound: make(chan []byte, 512),
|
||||
newBuf: func() []byte { return make([]byte, cfg.MTU+IPPacketOffset+40) },
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,165 +56,31 @@ func (vpn *VPN) RunTun(ctx context.Context, tunName string) error {
|
||||
return fmt.Errorf("create tun device (%s) failed: %w", tunName, err)
|
||||
}
|
||||
if vpn.cfg.IPv4 != "" {
|
||||
link.SetupLink(device, vpn.cfg.IPv4)
|
||||
link.SetupLink(tunName, vpn.cfg.IPv4)
|
||||
}
|
||||
if vpn.cfg.IPv6 != "" {
|
||||
link.SetupLink(device, vpn.cfg.IPv6)
|
||||
link.SetupLink(tunName, vpn.cfg.IPv6)
|
||||
}
|
||||
return vpn.run(ctx, device)
|
||||
}
|
||||
|
||||
func (vpn *VPN) run(ctx context.Context, device tun.Device) error {
|
||||
disco.SetIgnoredLocalInterfaceNamePrefixs("pg", "wg", "veth", "docker", "nerdctl", "tailscale")
|
||||
disco.AddIgnoredLocalCIDRs(vpn.cfg.AllowedIPs...)
|
||||
p2pOptions := []p2p.Option{
|
||||
p2p.PeerMeta("allowedIPs", vpn.cfg.AllowedIPs),
|
||||
p2p.ListenPeerUp(func(pi peer.ID, m peer.Metadata) { vpn.setPeer(device, pi, m) }),
|
||||
}
|
||||
|
||||
if len(vpn.cfg.Peers) > 0 {
|
||||
p2pOptions = append(p2pOptions, p2p.PeerSilenceMode())
|
||||
}
|
||||
|
||||
for _, peerURL := range vpn.cfg.Peers {
|
||||
pgPeer, err := url.Parse(peerURL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if pgPeer.Scheme != "pg" {
|
||||
return fmt.Errorf("unsupport scheme %s", pgPeer.Scheme)
|
||||
}
|
||||
extra := make(map[string]any)
|
||||
for k, v := range pgPeer.Query() {
|
||||
extra[k] = v[0]
|
||||
}
|
||||
vpn.setPeer(device, peer.ID(pgPeer.Host), peer.Metadata{
|
||||
Alias1: pgPeer.Query().Get("alias1"),
|
||||
Alias2: pgPeer.Query().Get("alias2"),
|
||||
Extra: extra,
|
||||
})
|
||||
}
|
||||
|
||||
if vpn.cfg.IPv4 != "" {
|
||||
ipv4, err := netip.ParsePrefix(vpn.cfg.IPv4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
disco.AddIgnoredLocalCIDRs(vpn.cfg.IPv4)
|
||||
p2pOptions = append(p2pOptions, p2p.PeerAlias1(ipv4.Addr().String()))
|
||||
}
|
||||
|
||||
if vpn.cfg.IPv6 != "" {
|
||||
ipv6, err := netip.ParsePrefix(vpn.cfg.IPv6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
disco.AddIgnoredLocalCIDRs(vpn.cfg.IPv6)
|
||||
p2pOptions = append(p2pOptions, p2p.PeerAlias2(ipv6.Addr().String()))
|
||||
}
|
||||
|
||||
if vpn.cfg.PrivateKey != "" {
|
||||
p2pOptions = append(p2pOptions, p2p.ListenPeerCurve25519(vpn.cfg.PrivateKey))
|
||||
} else {
|
||||
p2pOptions = append(p2pOptions, p2p.ListenPeerSecure())
|
||||
}
|
||||
|
||||
packetConn, err := p2p.ListenPacket(vpn.cfg.Peermap, p2pOptions...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(4)
|
||||
go vpn.runTunReadEventLoop(&wg, device)
|
||||
go vpn.runTunWriteEventLoop(&wg, device)
|
||||
go vpn.runPacketConnReadEventLoop(&wg, packetConn)
|
||||
go vpn.runPacketConnWriteEventLoop(&wg, packetConn)
|
||||
go vpn.runPacketConnReadEventLoop(&wg, vpn.packetConn)
|
||||
go vpn.runPacketConnWriteEventLoop(&wg, vpn.packetConn)
|
||||
|
||||
<-ctx.Done()
|
||||
close(vpn.inbound)
|
||||
close(vpn.outbound)
|
||||
device.Close()
|
||||
packetConn.Close()
|
||||
vpn.packetConn.Close()
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vpn *VPN) setPeer(device tun.Device, peer peer.ID, metadata peer.Metadata) {
|
||||
vpn.peersMutex.Lock()
|
||||
defer vpn.peersMutex.Unlock()
|
||||
vpn.peers.Put(metadata.Alias1, peer)
|
||||
vpn.peers.Put(metadata.Alias2, peer)
|
||||
var allowedIPv4s, allowedIPv6s []*net.IPNet
|
||||
if allowedIPs := metadata.Extra["allowedIPs"]; allowedIPs != nil {
|
||||
for _, allowIP := range allowedIPs.([]any) {
|
||||
_, cidr, err := net.ParseCIDR(allowIP.(string))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if cidr.IP.To4() != nil {
|
||||
allowedIPv4s = append(allowedIPv4s, cidr)
|
||||
} else {
|
||||
allowedIPv6s = append(allowedIPv6s, cidr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(allowedIPv4s) > 0 {
|
||||
oldTo, _ := vpn.ipv4Routes.Get(metadata.Alias1)
|
||||
vpn.ipv4Routes.Put(metadata.Alias1, allowedIPv4s)
|
||||
if onRoute := vpn.cfg.OnRoute; onRoute != nil {
|
||||
onRoute(Route{
|
||||
Device: device,
|
||||
OldDst: oldTo,
|
||||
NewDst: allowedIPv4s,
|
||||
Via: net.ParseIP(metadata.Alias1),
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(allowedIPv6s) > 0 {
|
||||
oldTo, _ := vpn.ipv6Routes.Get(metadata.Alias2)
|
||||
vpn.ipv6Routes.Put(metadata.Alias2, allowedIPv6s)
|
||||
if onRoute := vpn.cfg.OnRoute; onRoute != nil {
|
||||
onRoute(Route{
|
||||
Device: device,
|
||||
OldDst: oldTo,
|
||||
NewDst: allowedIPv6s,
|
||||
Via: net.ParseIP(metadata.Alias2),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (vpn *VPN) getPeer(ip string) (peer.ID, bool) {
|
||||
vpn.peersMutex.RLock()
|
||||
defer vpn.peersMutex.RUnlock()
|
||||
peerID, ok := vpn.peers.Get(ip)
|
||||
if ok {
|
||||
return peerID, true
|
||||
}
|
||||
dstIP := net.ParseIP(ip)
|
||||
if dstIP.To4() != nil {
|
||||
k, _, _ := vpn.ipv4Routes.Find(func(k string, v []*net.IPNet) bool {
|
||||
for _, cidr := range v {
|
||||
if cidr.Contains(dstIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return vpn.peers.Get(k)
|
||||
}
|
||||
k, _, _ := vpn.ipv6Routes.Find(func(k string, v []*net.IPNet) bool {
|
||||
for _, cidr := range v {
|
||||
if cidr.Contains(dstIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return vpn.peers.Get(k)
|
||||
}
|
||||
|
||||
func (vpn *VPN) runTunReadEventLoop(wg *sync.WaitGroup, device tun.Device) {
|
||||
defer wg.Done()
|
||||
|
||||
@ -315,7 +157,7 @@ func (vpn *VPN) runPacketConnWriteEventLoop(wg *sync.WaitGroup, packetConn net.P
|
||||
slog.Log(context.Background(), -10, "DropMulticastIP", "dst", dstIP)
|
||||
return
|
||||
}
|
||||
if peer, ok := vpn.getPeer(dstIP.String()); ok {
|
||||
if peer, ok := vpn.routingTable.GetPeer(dstIP.String()); ok {
|
||||
_, err := packetConn.WriteTo(packet[IPPacketOffset:], peer)
|
||||
if err != nil {
|
||||
slog.Error("WriteTo peer failed", "peer", peer, "detail", err)
|
||||
|
Loading…
Reference in New Issue
Block a user