diff --git a/disco/udp.go b/disco/udp.go index d6a3160..5a6dfc7 100644 --- a/disco/udp.go +++ b/disco/udp.go @@ -6,12 +6,14 @@ import ( "fmt" "log/slog" "net" + "net/netip" "strings" "sync" "time" cmap "github.com/orcaman/concurrent-map/v2" "github.com/rkonfj/peerguard/peer" + "github.com/rkonfj/peerguard/upnp" "tailscale.com/net/stun" ) @@ -41,6 +43,33 @@ func (c *UDPConn) UDPAddrSends() <-chan *PeerUDPAddrEvent { } func (c *UDPConn) GenerateLocalAddrsSends(peerID peer.PeerID, stunServers []string) { + // UPnP + go func() { + nat, err := upnp.Discover() + if err != nil { + slog.Debug("UPnP is disabled", "err", err) + return + } + externalIP, err := nat.GetExternalAddress() + if err != nil { + slog.Debug("UPnP is disabled", "err", err) + return + } + udpPort := int(netip.MustParseAddrPort(c.UDPConn.LocalAddr().String()).Port()) + + for i := 0; i < 5; i++ { + mappedPort, err := nat.AddPortMapping("udp", udpPort+i, udpPort, "peerguard", 24*3600) + if err != nil { + continue + } + c.udpAddrSends <- &PeerUDPAddrEvent{ + PeerID: peerID, + Addr: &net.UDPAddr{IP: externalIP, Port: mappedPort}, + } + return + } + }() + // LAN for _, addr := range c.localAddrs { uaddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { @@ -52,6 +81,7 @@ func (c *UDPConn) GenerateLocalAddrsSends(peerID peer.PeerID, stunServers []stri Addr: uaddr, } } + // WAN time.AfterFunc(time.Second, func() { if ctx, ok := c.findPeer(peerID); !ok || !ctx.IPv4Ready() { c.requestSTUN(peerID, stunServers) diff --git a/upnp/upnp.go b/upnp/upnp.go new file mode 100644 index 0000000..0efe74e --- /dev/null +++ b/upnp/upnp.go @@ -0,0 +1,368 @@ +package upnp + +// Copyright (c) 2024 rkonfj@gmail.com +// Origin source: https://github.com/jackpal/Taipei-Torrent/blob/6808fdfe24b4db505476ef48b9e288a5e7398f77/torrent/upnp.go +// Copyright (c) 2024 Jack Palevich (https://github.com/jackpal/Taipei-Torrent/blob/6808fdfe24b4db505476ef48b9e288a5e7398f77/LICENSE) +// + +import ( + "bytes" + "encoding/xml" + "errors" + "io" + "net" + "net/http" + "strconv" + "strings" + "time" +) + +// UPnPNAT Just enough UPnP to be able to forward ports +type UPnPNAT struct { + serviceURL string + ourIP string + urnDomain string +} + +func Discover() (nat *UPnPNAT, err error) { + ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") + if err != nil { + return + } + conn, err := net.ListenPacket("udp4", ":0") + if err != nil { + return + } + socket := conn.(*net.UDPConn) + defer socket.Close() + + err = socket.SetDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + return + } + + st := "InternetGatewayDevice:1" + + buf := bytes.NewBufferString( + "M-SEARCH * HTTP/1.1\r\n" + + "HOST: 239.255.255.250:1900\r\n" + + "ST: ssdp:all\r\n" + + "MAN: \"ssdp:discover\"\r\n" + + "MX: 2\r\n\r\n") + message := buf.Bytes() + answerBytes := make([]byte, 1024) + for i := 0; i < 3; i++ { + _, err = socket.WriteToUDP(message, ssdp) + if err != nil { + return + } + var n int + for { + n, _, err = socket.ReadFromUDP(answerBytes) + if err != nil { + break + } + answer := string(answerBytes[0:n]) + if !strings.Contains(answer, st) { + continue + } + // HTTP header field names are case-insensitive. + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 + locString := "\r\nlocation:" + answer = strings.ToLower(answer) + locIndex := strings.Index(answer, locString) + if locIndex < 0 { + continue + } + loc := answer[locIndex+len(locString):] + endIndex := strings.Index(loc, "\r\n") + if endIndex < 0 { + continue + } + locURL := strings.TrimSpace(loc[0:endIndex]) + var serviceURL, urnDomain string + serviceURL, urnDomain, err = getServiceURL(locURL) + if err != nil { + return + } + var ourIP net.IP + ourIP, err = localIPv4() + if err != nil { + return + } + nat = &UPnPNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain} + return + } + } + err = errors.New("upnp port discovery failed") + return +} + +type Envelope struct { + XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"` + Soap *SoapBody +} +type SoapBody struct { + XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"` + ExternalIP *ExternalIPAddressResponse +} + +type ExternalIPAddressResponse struct { + XMLName xml.Name `xml:"GetExternalIPAddressResponse"` + IPAddress string `xml:"NewExternalIPAddress"` +} + +type ExternalIPAddress struct { + XMLName xml.Name `xml:"NewExternalIPAddress"` + IP string +} + +type Service struct { + ServiceType string `xml:"serviceType"` + ControlURL string `xml:"controlURL"` +} + +type DeviceList struct { + Device []Device `xml:"device"` +} + +type ServiceList struct { + Service []Service `xml:"service"` +} + +type Device struct { + XMLName xml.Name `xml:"device"` + DeviceType string `xml:"deviceType"` + DeviceList DeviceList `xml:"deviceList"` + ServiceList ServiceList `xml:"serviceList"` +} + +type Root struct { + Device Device +} + +func getChildDevice(d *Device, deviceType string) *Device { + dl := d.DeviceList.Device + for i := 0; i < len(dl); i++ { + if strings.Contains(dl[i].DeviceType, deviceType) { + return &dl[i] + } + } + return nil +} + +func getChildService(d *Device, serviceType string) *Service { + sl := d.ServiceList.Service + for i := 0; i < len(sl); i++ { + if strings.Contains(sl[i].ServiceType, serviceType) { + return &sl[i] + } + } + return nil +} + +func localIPv4() (net.IP, error) { + tt, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, t := range tt { + aa, err := t.Addrs() + if err != nil { + return nil, err + } + for _, a := range aa { + ipnet, ok := a.(*net.IPNet) + if !ok { + continue + } + v4 := ipnet.IP.To4() + if v4 == nil || v4[0] == 127 { // loopback address + continue + } + return v4, nil + } + } + return nil, errors.New("cannot find local IP address") +} + +func getServiceURL(rootURL string) (url, urnDomain string, err error) { + r, err := http.Get(rootURL) + if err != nil { + return + } + defer r.Body.Close() + if r.StatusCode >= 400 { + err = errors.New(r.Status) + return + } + var root Root + err = xml.NewDecoder(r.Body).Decode(&root) + if err != nil { + return + } + a := &root.Device + if !strings.Contains(a.DeviceType, "InternetGatewayDevice:1") { + err = errors.New("no InternetGatewayDevice") + return + } + b := getChildDevice(a, "WANDevice:1") + if b == nil { + err = errors.New("no WANDevice") + return + } + c := getChildDevice(b, "WANConnectionDevice:1") + if c == nil { + err = errors.New("no WANConnectionDevice") + return + } + d := getChildService(c, "WANIPConnection:1") + if d == nil { + // Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice, + // instead of under WanConnectionDevice + d = getChildService(b, "WANIPConnection:1") + + if d == nil { + err = errors.New("no WANIPConnection") + return + } + } + // Extract the domain name, which isn't always 'schemas-upnp-org' + urnDomain = strings.Split(d.ServiceType, ":")[1] + url = combineURL(rootURL, d.ControlURL) + return +} + +func combineURL(rootURL, subURL string) string { + protocolEnd := "://" + protoEndIndex := strings.Index(rootURL, protocolEnd) + a := rootURL[protoEndIndex+len(protocolEnd):] + rootIndex := strings.Index(a, "/") + return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL +} + +func soapRequest(url, function, message, domain string) (r *http.Response, err error) { + fullMessage := "" + + "\r\n" + + "" + message + "" + + req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") + req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") + //req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"") + req.Header.Set("Connection", "Close") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Pragma", "no-cache") + + // log.Stderr("soapRequest ", req) + + r, err = http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + /*if r.Body != nil { + defer r.Body.Close() + }*/ + + if r.StatusCode >= 400 { + // log.Stderr(function, r.StatusCode) + err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) + r = nil + return + } + return +} + +type statusInfo struct { + externalIpAddress string +} + +func (n *UPnPNAT) getExternalIPAddress() (info statusInfo, err error) { + + message := "\r\n" + + "" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain) + if response != nil { + defer response.Body.Close() + } + if err != nil { + return + } + var envelope Envelope + data, err := io.ReadAll(response.Body) + reader := bytes.NewReader(data) + xml.NewDecoder(reader).Decode(&envelope) + + info = statusInfo{envelope.Soap.ExternalIP.IPAddress} + + if err != nil { + return + } + + return +} + +func (n *UPnPNAT) GetExternalAddress() (addr net.IP, err error) { + info, err := n.getExternalIPAddress() + if err != nil { + return + } + addr = net.ParseIP(info.externalIpAddress) + return +} + +func (n *UPnPNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { + // A single concatenation would break ARM compilation. + message := "\r\n" + + "" + strconv.Itoa(externalPort) + message += "" + protocol + "" + message += "" + strconv.Itoa(internalPort) + "" + + "" + n.ourIP + "" + + "1" + message += description + + "" + strconv.Itoa(timeout) + + "" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain) + if response != nil { + defer response.Body.Close() + } + if err != nil { + return + } + + // TODO: check response to see if the port was forwarded + // log.Println(message, response) + mappedExternalPort = externalPort + _ = response + return +} + +func (n *UPnPNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { + + message := "\r\n" + + "" + strconv.Itoa(externalPort) + + "" + protocol + "" + + "" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain) + if response != nil { + defer response.Body.Close() + } + if err != nil { + return + } + + // TODO: check response to see if the port was deleted + // log.Println(message, response) + _ = response + return +}