diff --git a/disco/udp.go b/disco/udp.go
index d6a3160..5a6dfc7 100644
--- a/disco/udp.go
+++ b/disco/udp.go
@@ -6,12 +6,14 @@ import (
"fmt"
"log/slog"
"net"
+ "net/netip"
"strings"
"sync"
"time"
cmap "github.com/orcaman/concurrent-map/v2"
"github.com/rkonfj/peerguard/peer"
+ "github.com/rkonfj/peerguard/upnp"
"tailscale.com/net/stun"
)
@@ -41,6 +43,33 @@ func (c *UDPConn) UDPAddrSends() <-chan *PeerUDPAddrEvent {
}
func (c *UDPConn) GenerateLocalAddrsSends(peerID peer.PeerID, stunServers []string) {
+ // UPnP
+ go func() {
+ nat, err := upnp.Discover()
+ if err != nil {
+ slog.Debug("UPnP is disabled", "err", err)
+ return
+ }
+ externalIP, err := nat.GetExternalAddress()
+ if err != nil {
+ slog.Debug("UPnP is disabled", "err", err)
+ return
+ }
+ udpPort := int(netip.MustParseAddrPort(c.UDPConn.LocalAddr().String()).Port())
+
+ for i := 0; i < 5; i++ {
+ mappedPort, err := nat.AddPortMapping("udp", udpPort+i, udpPort, "peerguard", 24*3600)
+ if err != nil {
+ continue
+ }
+ c.udpAddrSends <- &PeerUDPAddrEvent{
+ PeerID: peerID,
+ Addr: &net.UDPAddr{IP: externalIP, Port: mappedPort},
+ }
+ return
+ }
+ }()
+ // LAN
for _, addr := range c.localAddrs {
uaddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
@@ -52,6 +81,7 @@ func (c *UDPConn) GenerateLocalAddrsSends(peerID peer.PeerID, stunServers []stri
Addr: uaddr,
}
}
+ // WAN
time.AfterFunc(time.Second, func() {
if ctx, ok := c.findPeer(peerID); !ok || !ctx.IPv4Ready() {
c.requestSTUN(peerID, stunServers)
diff --git a/upnp/upnp.go b/upnp/upnp.go
new file mode 100644
index 0000000..0efe74e
--- /dev/null
+++ b/upnp/upnp.go
@@ -0,0 +1,368 @@
+package upnp
+
+// Copyright (c) 2024 rkonfj@gmail.com
+// Origin source: https://github.com/jackpal/Taipei-Torrent/blob/6808fdfe24b4db505476ef48b9e288a5e7398f77/torrent/upnp.go
+// Copyright (c) 2024 Jack Palevich (https://github.com/jackpal/Taipei-Torrent/blob/6808fdfe24b4db505476ef48b9e288a5e7398f77/LICENSE)
+//
+
+import (
+ "bytes"
+ "encoding/xml"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// UPnPNAT Just enough UPnP to be able to forward ports
+type UPnPNAT struct {
+ serviceURL string
+ ourIP string
+ urnDomain string
+}
+
+func Discover() (nat *UPnPNAT, err error) {
+ ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
+ if err != nil {
+ return
+ }
+ conn, err := net.ListenPacket("udp4", ":0")
+ if err != nil {
+ return
+ }
+ socket := conn.(*net.UDPConn)
+ defer socket.Close()
+
+ err = socket.SetDeadline(time.Now().Add(3 * time.Second))
+ if err != nil {
+ return
+ }
+
+ st := "InternetGatewayDevice:1"
+
+ buf := bytes.NewBufferString(
+ "M-SEARCH * HTTP/1.1\r\n" +
+ "HOST: 239.255.255.250:1900\r\n" +
+ "ST: ssdp:all\r\n" +
+ "MAN: \"ssdp:discover\"\r\n" +
+ "MX: 2\r\n\r\n")
+ message := buf.Bytes()
+ answerBytes := make([]byte, 1024)
+ for i := 0; i < 3; i++ {
+ _, err = socket.WriteToUDP(message, ssdp)
+ if err != nil {
+ return
+ }
+ var n int
+ for {
+ n, _, err = socket.ReadFromUDP(answerBytes)
+ if err != nil {
+ break
+ }
+ answer := string(answerBytes[0:n])
+ if !strings.Contains(answer, st) {
+ continue
+ }
+ // HTTP header field names are case-insensitive.
+ // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
+ locString := "\r\nlocation:"
+ answer = strings.ToLower(answer)
+ locIndex := strings.Index(answer, locString)
+ if locIndex < 0 {
+ continue
+ }
+ loc := answer[locIndex+len(locString):]
+ endIndex := strings.Index(loc, "\r\n")
+ if endIndex < 0 {
+ continue
+ }
+ locURL := strings.TrimSpace(loc[0:endIndex])
+ var serviceURL, urnDomain string
+ serviceURL, urnDomain, err = getServiceURL(locURL)
+ if err != nil {
+ return
+ }
+ var ourIP net.IP
+ ourIP, err = localIPv4()
+ if err != nil {
+ return
+ }
+ nat = &UPnPNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain}
+ return
+ }
+ }
+ err = errors.New("upnp port discovery failed")
+ return
+}
+
+type Envelope struct {
+ XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"`
+ Soap *SoapBody
+}
+type SoapBody struct {
+ XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"`
+ ExternalIP *ExternalIPAddressResponse
+}
+
+type ExternalIPAddressResponse struct {
+ XMLName xml.Name `xml:"GetExternalIPAddressResponse"`
+ IPAddress string `xml:"NewExternalIPAddress"`
+}
+
+type ExternalIPAddress struct {
+ XMLName xml.Name `xml:"NewExternalIPAddress"`
+ IP string
+}
+
+type Service struct {
+ ServiceType string `xml:"serviceType"`
+ ControlURL string `xml:"controlURL"`
+}
+
+type DeviceList struct {
+ Device []Device `xml:"device"`
+}
+
+type ServiceList struct {
+ Service []Service `xml:"service"`
+}
+
+type Device struct {
+ XMLName xml.Name `xml:"device"`
+ DeviceType string `xml:"deviceType"`
+ DeviceList DeviceList `xml:"deviceList"`
+ ServiceList ServiceList `xml:"serviceList"`
+}
+
+type Root struct {
+ Device Device
+}
+
+func getChildDevice(d *Device, deviceType string) *Device {
+ dl := d.DeviceList.Device
+ for i := 0; i < len(dl); i++ {
+ if strings.Contains(dl[i].DeviceType, deviceType) {
+ return &dl[i]
+ }
+ }
+ return nil
+}
+
+func getChildService(d *Device, serviceType string) *Service {
+ sl := d.ServiceList.Service
+ for i := 0; i < len(sl); i++ {
+ if strings.Contains(sl[i].ServiceType, serviceType) {
+ return &sl[i]
+ }
+ }
+ return nil
+}
+
+func localIPv4() (net.IP, error) {
+ tt, err := net.Interfaces()
+ if err != nil {
+ return nil, err
+ }
+ for _, t := range tt {
+ aa, err := t.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ for _, a := range aa {
+ ipnet, ok := a.(*net.IPNet)
+ if !ok {
+ continue
+ }
+ v4 := ipnet.IP.To4()
+ if v4 == nil || v4[0] == 127 { // loopback address
+ continue
+ }
+ return v4, nil
+ }
+ }
+ return nil, errors.New("cannot find local IP address")
+}
+
+func getServiceURL(rootURL string) (url, urnDomain string, err error) {
+ r, err := http.Get(rootURL)
+ if err != nil {
+ return
+ }
+ defer r.Body.Close()
+ if r.StatusCode >= 400 {
+ err = errors.New(r.Status)
+ return
+ }
+ var root Root
+ err = xml.NewDecoder(r.Body).Decode(&root)
+ if err != nil {
+ return
+ }
+ a := &root.Device
+ if !strings.Contains(a.DeviceType, "InternetGatewayDevice:1") {
+ err = errors.New("no InternetGatewayDevice")
+ return
+ }
+ b := getChildDevice(a, "WANDevice:1")
+ if b == nil {
+ err = errors.New("no WANDevice")
+ return
+ }
+ c := getChildDevice(b, "WANConnectionDevice:1")
+ if c == nil {
+ err = errors.New("no WANConnectionDevice")
+ return
+ }
+ d := getChildService(c, "WANIPConnection:1")
+ if d == nil {
+ // Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice,
+ // instead of under WanConnectionDevice
+ d = getChildService(b, "WANIPConnection:1")
+
+ if d == nil {
+ err = errors.New("no WANIPConnection")
+ return
+ }
+ }
+ // Extract the domain name, which isn't always 'schemas-upnp-org'
+ urnDomain = strings.Split(d.ServiceType, ":")[1]
+ url = combineURL(rootURL, d.ControlURL)
+ return
+}
+
+func combineURL(rootURL, subURL string) string {
+ protocolEnd := "://"
+ protoEndIndex := strings.Index(rootURL, protocolEnd)
+ a := rootURL[protoEndIndex+len(protocolEnd):]
+ rootIndex := strings.Index(a, "/")
+ return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
+}
+
+func soapRequest(url, function, message, domain string) (r *http.Response, err error) {
+ fullMessage := "" +
+ "\r\n" +
+ "" + message + ""
+
+ req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
+ req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
+ //req.Header.Set("Transfer-Encoding", "chunked")
+ req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"")
+ req.Header.Set("Connection", "Close")
+ req.Header.Set("Cache-Control", "no-cache")
+ req.Header.Set("Pragma", "no-cache")
+
+ // log.Stderr("soapRequest ", req)
+
+ r, err = http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ /*if r.Body != nil {
+ defer r.Body.Close()
+ }*/
+
+ if r.StatusCode >= 400 {
+ // log.Stderr(function, r.StatusCode)
+ err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
+ r = nil
+ return
+ }
+ return
+}
+
+type statusInfo struct {
+ externalIpAddress string
+}
+
+func (n *UPnPNAT) getExternalIPAddress() (info statusInfo, err error) {
+
+ message := "\r\n" +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain)
+ if response != nil {
+ defer response.Body.Close()
+ }
+ if err != nil {
+ return
+ }
+ var envelope Envelope
+ data, err := io.ReadAll(response.Body)
+ reader := bytes.NewReader(data)
+ xml.NewDecoder(reader).Decode(&envelope)
+
+ info = statusInfo{envelope.Soap.ExternalIP.IPAddress}
+
+ if err != nil {
+ return
+ }
+
+ return
+}
+
+func (n *UPnPNAT) GetExternalAddress() (addr net.IP, err error) {
+ info, err := n.getExternalIPAddress()
+ if err != nil {
+ return
+ }
+ addr = net.ParseIP(info.externalIpAddress)
+ return
+}
+
+func (n *UPnPNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
+ // A single concatenation would break ARM compilation.
+ message := "\r\n" +
+ "" + strconv.Itoa(externalPort)
+ message += "" + protocol + ""
+ message += "" + strconv.Itoa(internalPort) + "" +
+ "" + n.ourIP + "" +
+ "1"
+ message += description +
+ "" + strconv.Itoa(timeout) +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain)
+ if response != nil {
+ defer response.Body.Close()
+ }
+ if err != nil {
+ return
+ }
+
+ // TODO: check response to see if the port was forwarded
+ // log.Println(message, response)
+ mappedExternalPort = externalPort
+ _ = response
+ return
+}
+
+func (n *UPnPNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
+
+ message := "\r\n" +
+ "" + strconv.Itoa(externalPort) +
+ "" + protocol + "" +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain)
+ if response != nil {
+ defer response.Body.Close()
+ }
+ if err != nil {
+ return
+ }
+
+ // TODO: check response to see if the port was deleted
+ // log.Println(message, response)
+ _ = response
+ return
+}