Noise crypto (#4)

* integrate noise handshake

it works! communication occurs over a secure channel now. ephemeral keys
for confidentiality and static keys  for authenticity. this means
that the server has a public key in order to guarantee authenticity.

* remove check for udp packet

all packets recieved are UDP, no need to check
the socket type is "ip4:udp"

* add timestamp to handshakes to prevent replays

without a timestamp, the handshake initiation packet can be replayed,
discarding the current session keys and triggereing a response. this has
terrible dos attack potential. by using an AEAD secured timestamp,
a timestamp that is not new will cause the handshake to be discarded

* add client-initiated key rotation

after a fixed period, the client now initiates a new handshake.
communication can only continue if the server responds to the handshake,
so keys are rotated on both sides. if the handshake initiation or
response packet is dropped, the connection will be broken until the next
handshake

* add sliding window package

i borrowed wireguard-go's anti-replay algorithm implementation
(see RFC 6479). i didn't just copy paste, i read through it line by line
an littered it with comments, and i did reposition and rename some
parts. i'm using this implementation because i couldn't find anything to
improve on or optimize.

* move go.mod to project root

i was having difficulty sharing antireplay between client and server
but using one go.mod in the project root did the trick.

* integrate sliding window for nonces

this migitates replay attacks and prevents attackers from holding
packets indefinitely. nice.

* update README.md to reflect command usage

* slightly better error handling
This commit is contained in:
Malcolm Seyd
2020-08-30 01:52:17 -07:00
committed by GitHub
parent 2cdfdf002f
commit 9966b4c613
14 changed files with 792 additions and 94 deletions
+3 -3
View File
@@ -8,17 +8,17 @@ This tools allows you to connect to other Wireguard peers from behind a NAT usin
The client cycles through each peer on the interface until they are all resolved. Requires root to run due to raw socket usage.
```
Usage: ./client [OPTION]... SERVER_HOSTNAME:PORT WIREGUARD_INTERFACE
Usage: ./client [OPTION]... WIREGUARD_INTERFACE SERVER_HOSTNAME:PORT SERVER_PUBKEY
Flags:
-c, --continuous=false: continuously resolve peers after they've already been resolved
-d, --delay=2: time to wait between retries (in seconds)
Example:
./client demo.wireguard.com:12345 wg0
./client wg0 demo.wireguard.com:12345 1rwvlEQkF6vL4jA1gRzlTM7I3tuZHtdq8qkLMwBs8Uw=
```
The server associates each pubkey to an ip and a port. Doesn't require root to run.
```
Usage: ./server PORT
Usage: ./server PORT [PRIVATE_KEY]
```
## Why
+102
View File
@@ -0,0 +1,102 @@
package antireplay
// thank you again to Wireguard-Go for helping me understand this
// most credit to https://git.zx2c4.com/wireguard-go/tree/replay/replay.go
// We use uintptr as blocks because pointers' size are optimized for the
// local CPU architecture.
const (
// a word filled with 1's
blockMask = ^uintptr(0)
// each word is 2**blockSizeLog bytes long
// 1 if > 8 bit 1 if > 16 bit 1 if > 32 bit
blockSizeLog = blockMask>>8&1 + blockMask>>16&1 + blockMask>>32&1
// size of word in bytes
blockSize = 1 << blockSizeLog
)
const (
// total number of bits in the array
// must be power of 2
blocksTotalBits = 1024
// bits in a block
blockBits = blockSize * 8
// log of bits in a block
blockBitsLog = blockSizeLog + 3
// WindowSize is the size of the range in which indicies are stored
// W = M-1*blockSize
// uint64 to avoid casting in comparisons
WindowSize = uint64(blocksTotalBits - blockBits)
numBlocks = blocksTotalBits / blockSize
)
// Window is a sliding window that records which sequence numbers have been seen.
// It implements the anti-replay algorithm described in RFC 6479
type Window struct {
highest uint64
blocks [numBlocks]uintptr
}
// Reset resets the window to its initial state
func (w *Window) Reset() {
w.highest = 0
// this is fine because higher blocks are cleared during Check()
w.blocks[0] = 0
}
// Check records seeing index and returns true if the index is within the
// window and has not been seen before. If it returns false, the index is
// considered invalid.
func (w *Window) Check(index uint64) bool {
// check if too old
if index+WindowSize < w.highest {
return false
}
// bits outside the block size represent which block the index is in
indexBlock := index >> blockBitsLog
// move window if new index is higher
if index > w.highest {
currTopBlock := w.highest >> blockBitsLog
// how many blocks ahead is indexBlock?
// cap it at a full circle around the array, at that point we clear the
// whole thing
newBlocks := min(indexBlock-currTopBlock, numBlocks)
// clear each new block
for i := uint64(1); i <= newBlocks; i++ {
// mod index so it wraps around
w.blocks[(currTopBlock+i)%numBlocks] = 0
}
w.highest = index
}
// we didn't mod until now because we needed to know the difference between
// a lower index and wrapped higher index
// we need to keep the index inside the array now
indexBlock %= numBlocks
// bits inside the block represent where in the block the bit is
// mask it with the block size
indexBit := index & uint64(blockBits-1)
// finally check the index
// save existing block to see if it changes
oldBlock := w.blocks[indexBlock]
// create updated block
newBlock := oldBlock | (1 << indexBit)
// set block to new value
w.blocks[indexBlock] = newBlock
// if the bit wasn't already 1, the values should be different and this should return true
return oldBlock != newBlock
}
func min(a, b uint64) uint64 {
if a < b {
return a
}
return b
}
+51
View File
@@ -0,0 +1,51 @@
package antireplay
import (
"testing"
)
func TestWindow(t *testing.T) {
w := Window{}
w.testCheck(t, 0, true)
w.testCheck(t, 0, false)
w.testCheck(t, 1, true)
w.testCheck(t, 1, false)
w.testCheck(t, 0, false)
w.testCheck(t, 3, true)
w.testCheck(t, 2, true)
w.testCheck(t, 2, false)
w.testCheck(t, 3, false)
w.testCheck(t, 30, true)
w.testCheck(t, 29, true)
w.testCheck(t, 28, true)
w.testCheck(t, 30, false)
w.testCheck(t, 28, false)
w.testCheck(t, WindowSize, true)
w.testCheck(t, WindowSize, false)
w.testCheck(t, WindowSize+1, true)
w.Reset()
w.testCheck(t, 0, true)
w.testCheck(t, 1, true)
w.testCheck(t, WindowSize, true)
w.Reset()
w.testCheck(t, WindowSize+1, true)
w.testCheck(t, 0, false)
w.testCheck(t, 1, true)
w.testCheck(t, WindowSize+3, true)
w.testCheck(t, 1, false)
w.testCheck(t, 2, false)
w.testCheck(t, WindowSize*3, true)
w.testCheck(t, WindowSize*2-1, false)
w.testCheck(t, WindowSize*2, true)
w.testCheck(t, WindowSize*3, false)
}
func (w *Window) testCheck(t *testing.T, index uint64, expected bool) {
result := w.Check(index)
t.Log(index, "->", result)
if result != expected {
t.FailNow()
}
}
+74
View File
@@ -0,0 +1,74 @@
package auth
import (
"crypto/rand"
"github.com/flynn/noise"
"github.com/malcolmseyd/natpunch-go/antireplay"
"golang.org/x/crypto/curve25519"
)
var noiseConfig = noise.Config{
CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s),
Random: rand.Reader,
Pattern: noise.HandshakeIK,
Initiator: true,
Prologue: []byte("natpunch-go is the best :)"),
}
// CipherState is an alternate implementation of noise.CipherState
// that allows manual control over the nonce
type CipherState struct {
c noise.Cipher
n uint64
w antireplay.Window
}
// NewCipherState initializes a new CipherState
func NewCipherState(c noise.Cipher) *CipherState {
return &CipherState{c: c}
}
// Encrypt is the same as noise.HandshakeState
func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte {
out = s.c.Encrypt(out, s.n, ad, plaintext)
s.n++
return out
}
// Decrypt is the same as noise.HandshakeState
func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
s.n++
return out, err
}
// Nonce returns the nonce value inside CipherState
func (s *CipherState) Nonce() uint64 {
return s.n
}
// SetNonce sets the nonce value inside CipherState
func (s *CipherState) SetNonce(n uint64) {
s.n = n
}
// CheckNonce returns true if the nonce is valid, and false if the nonce is
// reused or outside of the sliding window
func (s *CipherState) CheckNonce(n uint64) bool {
return s.w.Check(n)
}
// NewConfig initializes a new noise.Config with the provided data
func NewConfig(privkey, theirPubkey [32]byte) (config noise.Config, err error) {
config = noiseConfig
config.StaticKeypair = noise.DHKey{
Private: privkey[:],
}
config.StaticKeypair.Public, err = curve25519.X25519(config.StaticKeypair.Private, curve25519.Basepoint)
if err != nil {
return config, err
}
config.PeerStatic = theirPubkey[:]
return
}
+47 -16
View File
@@ -13,12 +13,12 @@ import (
"github.com/ogier/pflag"
"github.com/malcolmseyd/natpunch-go/client/auth"
"github.com/malcolmseyd/natpunch-go/client/cmd"
"github.com/malcolmseyd/natpunch-go/client/network"
"github.com/malcolmseyd/natpunch-go/client/util"
)
const timeout = time.Second * 10
const persistentKeepalive = 25
func main() {
@@ -30,7 +30,7 @@ func main() {
pflag.Parse()
args := pflag.Args()
if len(args) < 2 {
if len(args) < 3 {
printUsage()
os.Exit(1)
}
@@ -40,7 +40,9 @@ func main() {
os.Exit(1)
}
serverSplit := strings.Split(args[0], ":")
ifaceName := args[0]
serverSplit := strings.Split(args[1], ":")
serverHostname := serverSplit[0]
if len(serverSplit) < 2 {
fmt.Fprintln(os.Stderr, "Please include a port like this:", serverHostname+":PORT")
@@ -53,12 +55,20 @@ func main() {
if err != nil {
log.Fatalln("Error parsing server port:", err)
}
serverKey, err := base64.StdEncoding.DecodeString(args[2])
if err != nil || len(serverKey) != 32 {
log.Fatalln("Server key has improper formatting")
}
var serverKeyArr network.Key
copy(serverKeyArr[:], serverKey)
server := network.Server{
Hostname: serverHostname,
Addr: serverAddr,
Port: uint16(serverPort),
Pubkey: serverKeyArr,
}
ifaceName := args[1]
run(ifaceName, server, *continuous, *delay)
}
@@ -72,6 +82,7 @@ func run(ifaceName string, server network.Server, continuous bool, delay float32
// get info about the Wireguard config
clientPort := cmd.GetClientPort(ifaceName)
clientPubkey := cmd.GetClientPubkey(ifaceName)
clientPrivkey := cmd.GetClientPrivkey(ifaceName)
client := network.Peer{
IP: clientIP,
@@ -96,6 +107,10 @@ func run(ifaceName string, server network.Server, continuous bool, delay float32
fmt.Println("Resolving", totalPeers, "peers")
var sendCipher, recvCipher *auth.CipherState
var index uint32
var err error
// we keep requesting if the server doesn't have one of our peers.
// this keeps running until all connections are established.
tryAgain := true
@@ -105,16 +120,32 @@ func run(ifaceName string, server network.Server, continuous bool, delay float32
if peer.Resolved && !continuous {
continue
}
// Noise handshake w/ key rotation
if time.Since(server.LastHandshake) > network.RekeyDuration {
sendCipher, recvCipher, index, err = network.Handshake(rawConn, clientPrivkey, &server, &client)
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
fmt.Println("Connection to", server.Hostname, "timed out.")
tryAgain = true
break
}
fmt.Fprintln(os.Stderr, "Key rotation failed:", err)
tryAgain = true
break
}
}
fmt.Printf("(%d/%d) %s: ", resolvedPeers, totalPeers, base64.RawStdEncoding.EncodeToString(peer.Pubkey[:])[:16])
copy(payload[32:64], peer.Pubkey[:])
err := network.SendPacket(payload, rawConn, &server, &client)
err := network.SendDataPacket(sendCipher, index, payload, rawConn, &server, &client)
if err != nil {
log.Println("\nError sending packet:", err)
continue
}
response, n, err := network.RecvPacket(rawConn, timeout, &server, &client)
// throw away udp header, we have no use for it right now
body, _, packetType, n, err := network.RecvDataPacket(recvCipher, rawConn, &server, &client)
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
fmt.Println("\nConnection to", server.Hostname, "timed out.")
@@ -124,24 +155,24 @@ func run(ifaceName string, server network.Server, continuous bool, delay float32
fmt.Println("\nError receiving packet:", err)
continue
}
if packetType != network.PacketData {
fmt.Println("\nExpected data packet, got", packetType)
}
if n == network.EmptyUDPSize {
if len(body) == 0 {
fmt.Println("not found")
tryAgain = true
continue
} else if n < network.EmptyUDPSize {
log.Println("\nError: response is not a valid udp packet")
continue
} else if n != network.EmptyUDPSize+4+2 {
} else if len(body) != 4+2 {
// expected packet size, 4 bytes for ip, 2 for port
log.Println("\nError: invalid response of length", n)
log.Println("\nError: invalid response of length", len(body))
// For debugging
fmt.Println(hex.Dump(response[:n]))
fmt.Println(hex.Dump(body[:n]))
tryAgain = true
continue
}
peer.IP, peer.Port = network.ParseResponse(response)
peer.IP, peer.Port = network.ParseResponse(body)
if peer.IP == nil {
log.Println("Error parsing packet: not a valid UDP packet")
}
@@ -173,13 +204,13 @@ func run(ifaceName string, server network.Server, continuous bool, delay float32
func printUsage() {
fmt.Fprintf(os.Stderr,
"Usage: %s [OPTION]... SERVER_HOSTNAME:PORT WIREGUARD_INTERFACE\n"+
"Usage: %s [OPTION]... WIREGUARD_INTERFACE SERVER_HOSTNAME:PORT SERVER_PUBKEY\n"+
"Flags:\n", os.Args[0],
)
pflag.PrintDefaults()
fmt.Fprintf(os.Stderr,
"Example:\n"+
" %s demo.wireguard.com:12345 wg0\n",
" %s wg0 demo.wireguard.com:12345 1rwvlEQkF6vL4jA1gRzlTM7I3tuZHtdq8qkLMwBs8Uw=\n",
os.Args[0],
)
}
+18 -3
View File
@@ -44,8 +44,8 @@ func GetPeers(iface string) []string {
return strings.Split(strings.TrimSpace(output), "\n")
}
// GetClientPubkey returns the pubkey on the Wireguard interface
func GetClientPubkey(iface string) network.Pubkey {
// GetClientPubkey returns the publib key on the Wireguard interface
func GetClientPubkey(iface string) network.Key {
var keyArr [32]byte
output, err := RunCmd("wg", "show", iface, "public-key")
if err != nil {
@@ -56,7 +56,22 @@ func GetClientPubkey(iface string) network.Pubkey {
log.Fatalln("Error parsing client pubkey:", err)
}
copy(keyArr[:], keyBytes)
return network.Pubkey(keyArr)
return network.Key(keyArr)
}
// GetClientPrivkey returns the private key on the Wireguard interface
func GetClientPrivkey(iface string) network.Key {
var keyArr [32]byte
output, err := RunCmd("wg", "show", iface, "private-key")
if err != nil {
log.Fatalln("Error getting client privkey:", err)
}
keyBytes, err := base64.StdEncoding.DecodeString(strings.TrimSpace(output))
if err != nil {
log.Fatalln("Error parsing client privkey:", err)
}
copy(keyArr[:], keyBytes)
return network.Key(keyArr)
}
// SetPeer updates a peer's endpoint and keepalive with `wg`. keepalive is in seconds
-10
View File
@@ -1,10 +0,0 @@
module github.com/malcolmseyd/natpunch-go/client
go 1.14
require (
github.com/google/gopacket v1.1.18
github.com/ogier/pflag v0.0.1
github.com/vishvananda/netlink v1.1.0
golang.org/x/net v0.0.0-20200707034311-ab3426394381
)
+159 -22
View File
@@ -2,32 +2,61 @@ package network
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"log"
"net"
"time"
"github.com/flynn/noise"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/malcolmseyd/natpunch-go/client/auth"
"github.com/vishvananda/netlink"
"golang.org/x/net/bpf"
"golang.org/x/net/ipv4"
)
const udpProtocol = 17
const (
udpProtocol = 17
// EmptyUDPSize is the size of an empty UDP packet
EmptyUDPSize = 28
timeout = time.Second * 10
// PacketHandshakeInit identifies handhshake initiation packets
PacketHandshakeInit byte = 1
// PacketHandshakeResp identifies handhshake response packets
PacketHandshakeResp byte = 2
// PacketData identifies regular data packets
PacketData byte = 3
)
var (
// ErrPacketType is returned when an unexepcted packet type is enountered
ErrPacketType = errors.New("client/network: incorrect packet type")
// ErrNonce is returned when the nonce on a packet isn't valid
ErrNonce = errors.New("client/network: invalid nonce")
// RekeyDuration is the time after which keys are invalid and a new handshake is required.
RekeyDuration = 5 * time.Minute
)
// EmptyUDPSize is the size of the IPv4 and UDP headers combined.
const EmptyUDPSize = 28
// Pubkey stores a 32 byte representation of a Wireguard public key
type Pubkey [32]byte
// Key stores a 32 byte representation of a Wireguard key
type Key [32]byte
// Server stores data relating to the server and its location
// Server stores data relating to the server
type Server struct {
Hostname string
Addr *net.IPAddr
Port uint16
Pubkey Key
LastHandshake time.Time
}
// Peer stores data about a peer's key and endpoint, whether it's another peer or the client
@@ -38,7 +67,7 @@ type Peer struct {
Resolved bool
IP net.IP
Port uint16
Pubkey Pubkey
Pubkey Key
}
// GetClientIP gets source ip address that will be used when sending data to dstIP
@@ -150,6 +179,62 @@ func MakePacket(payload []byte, server *Server, client *Peer) []byte {
return buf.Bytes()
}
// Handshake performs a Noise-IK handshake with the Server
func Handshake(conn *ipv4.RawConn, privkey Key, server *Server, client *Peer) (sendCipher, recvCipher *auth.CipherState, index uint32, err error) {
// we generate index on the client side
indexBytes := make([]byte, 4)
rand.Read(indexBytes)
index = binary.BigEndian.Uint32(indexBytes)
config, err := auth.NewConfig(privkey, server.Pubkey)
if err != nil {
return
}
handshake, err := noise.NewHandshakeState(config)
if err != nil {
return
}
header := append([]byte{PacketHandshakeInit}, indexBytes...)
timestamp := make([]byte, 8)
binary.BigEndian.PutUint64(timestamp, uint64(time.Now().UnixNano()))
packet, _, _, err := handshake.WriteMessage(header, timestamp)
if err != nil {
return
}
err = SendPacket(packet, conn, server, client)
if err != nil {
return
}
response, n, err := RecvPacket(conn, server, client)
if err != nil {
return
}
response = response[EmptyUDPSize:n]
packetType := response[0]
response = response[1:]
if packetType != PacketHandshakeResp {
err = ErrPacketType
return
}
index = binary.BigEndian.Uint32(response[:4])
response = response[4:]
_, send, recv, err := handshake.ReadMessage(nil, response)
// we use our own implementation for manual nonce control
sendCipher = auth.NewCipherState(send.Cipher())
recvCipher = auth.NewCipherState(recv.Cipher())
server.LastHandshake = time.Now()
return
}
// SendPacket sends packet to the Server
func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *Peer) error {
fullPacket := MakePacket(packet, server, client)
@@ -157,8 +242,25 @@ func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *Peer)
return err
}
// SendDataPacket encrypts and sends packet to the Server
func SendDataPacket(cipher *auth.CipherState, index uint32, data []byte, conn *ipv4.RawConn, server *Server, client *Peer) error {
indexBytes := make([]byte, 4)
binary.BigEndian.PutUint32(indexBytes, index)
nonceBytes := make([]byte, 8)
binary.BigEndian.PutUint64(nonceBytes, cipher.Nonce())
// println("sending nonce:", cipher.Nonce())
header := append([]byte{PacketData}, indexBytes...)
header = append(header, nonceBytes...)
packet := cipher.Encrypt(header, nil, data)
return SendPacket(packet, conn, server, client)
}
// RecvPacket recieves a UDP packet from server
func RecvPacket(conn *ipv4.RawConn, timeout time.Duration, server *Server, client *Peer) ([]byte, int, error) {
func RecvPacket(conn *ipv4.RawConn, server *Server, client *Peer) ([]byte, int, error) {
err := conn.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
return nil, 0, err
@@ -173,28 +275,63 @@ func RecvPacket(conn *ipv4.RawConn, timeout time.Duration, server *Server, clien
return response, n, nil
}
// RecvDataPacket recieves a UDP packet from server
func RecvDataPacket(cipher *auth.CipherState, conn *ipv4.RawConn, server *Server, client *Peer) (body, header []byte, packetType byte, n int, err error) {
response, n, err := RecvPacket(conn, server, client)
if err != nil {
return
}
header = response[:EmptyUDPSize]
response = response[EmptyUDPSize:n]
// println(hex.Dump(response))
packetType = response[0]
response = response[1:]
nonce := binary.BigEndian.Uint64(response[:8])
response = response[8:]
cipher.SetNonce(nonce)
// println("recving nonce:", nonce)
body, err = cipher.Decrypt(nil, nil, response)
if err != nil {
return
}
// now that we're authenticated, see if the nonce is valid
// the sliding window contains a generous 1000 packets, that should hold up
// with plenty of peers.
if !cipher.CheckNonce(nonce) {
err = ErrNonce
body = nil
}
return
}
// ParseResponse takes a response packet and parses it into an IP and port.
// There's no error checking, we assume that data passed in is valid
func ParseResponse(response []byte) (net.IP, uint16) {
var ip net.IP
var ipv4Slice []byte = make([]byte, 4)
var port uint16
packet := gopacket.NewPacket(response, layers.LayerTypeIPv4, gopacket.DecodeOptions{
Lazy: true,
NoCopy: true,
})
if packet.TransportLayer().LayerType() != layers.LayerTypeUDP {
return nil, 0
}
payload := packet.ApplicationLayer().LayerContents()
// packet := gopacket.NewPacket(response, layers.LayerTypeIPv4, gopacket.DecodeOptions{
// Lazy: true,
// NoCopy: true,
// })
// if packet.TransportLayer().LayerType() != layers.LayerTypeUDP {
// return nil, 0
// }
// payload := packet.ApplicationLayer().LayerContents()
data := bytes.NewBuffer(payload)
// fmt.Println("Layer payload:\n", hex.Dump(data.Bytes()))
// data := bytes.NewBuffer(payload)
// // fmt.Println("Layer payload:\n", hex.Dump(data.Bytes()))
binary.Read(data, binary.BigEndian, &ipv4Slice)
ip = net.IP(ipv4Slice)
binary.Read(data, binary.BigEndian, &port)
// fmt.Println("ip:", ip.String(), "port:", port)
// binary.Read(data, binary.BigEndian, &ipv4Slice)
// ip = net.IP(ipv4Slice)
// binary.Read(data, binary.BigEndian, &port)
// // fmt.Println("ip:", ip.String(), "port:", port)
ip = net.IP(response[:4])
port = binary.BigEndian.Uint16(response[4:6])
return ip, port
}
+1 -1
View File
@@ -20,7 +20,7 @@ func MakePeerSlice(peerKeys []string) []network.Peer {
copy(keyArr[:], keyBytes)
peer := network.Peer{
Pubkey: network.Pubkey(keyArr),
Pubkey: network.Key(keyArr),
Resolved: false,
}
keys[i] = peer
+12
View File
@@ -0,0 +1,12 @@
module github.com/malcolmseyd/natpunch-go
go 1.15
require (
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6
github.com/google/gopacket v1.1.18
github.com/ogier/pflag v0.0.1
github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a
golang.org/x/net v0.0.0-20200822124328-c89045814202
)
+7 -2
View File
@@ -1,3 +1,5 @@
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 h1:u/UEqS66A5ckRmS4yNpjmVH56sVtS/RfclBAYocb4as=
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:1i71OnUq3iUe1ma7Lr6yG6/rjvM3emb6yoL7xLFzcVQ=
github.com/google/gopacket v1.1.18 h1:lum7VRA9kdlvBi7/v2p7/zcbkduHaCH/SVVyurs7OpY=
github.com/google/gopacket v1.1.18/go.mod h1:UdDNZ1OO62aGYVnPhxT1U6aI7ukYtA/kB8vaU0diBUM=
github.com/ogier/pflag v0.0.1 h1:RW6JSWSu/RkSatfcLtogGfFgpim5p7ARQ10ECk5O750=
@@ -7,10 +9,13 @@ github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYp
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+74
View File
@@ -0,0 +1,74 @@
package auth
import (
"crypto/rand"
"github.com/flynn/noise"
"github.com/malcolmseyd/natpunch-go/antireplay"
"golang.org/x/crypto/curve25519"
)
var noiseConfig = noise.Config{
CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s),
Random: rand.Reader,
Pattern: noise.HandshakeIK,
Initiator: false,
Prologue: []byte("natpunch-go is the best :)"),
}
// CipherState is an alternate implementation of noise.CipherState
// that allows manual control over the nonce
type CipherState struct {
c noise.Cipher
n uint64
w antireplay.Window
}
// NewCipherState initializes a new CipherState
func NewCipherState(c noise.Cipher) *CipherState {
return &CipherState{c: c}
}
// Encrypt is the same as noise.HandshakeState
func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte {
out = s.c.Encrypt(out, s.n, ad, plaintext)
s.n++
return out
}
// Decrypt is the same as noise.HandshakeState
func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
s.n++
return out, err
}
// Nonce returns the nonce value inside CipherState
func (s *CipherState) Nonce() uint64 {
return s.n
}
// SetNonce sets the nonce value inside CipherState
func (s *CipherState) SetNonce(n uint64) {
s.n = n
}
// CheckNonce returns true if the nonce is valid, and false if the nonce is
// reused or outside of the sliding window
func (s *CipherState) CheckNonce(n uint64) bool {
return s.w.Check(n)
}
// NewConfig initializes a new noise.Config with the provided data
func NewConfig(privkey, theirPubkey [32]byte) (config noise.Config, err error) {
config = noiseConfig
config.StaticKeypair = noise.DHKey{
Private: privkey[:],
}
config.StaticKeypair.Public, err = curve25519.X25519(config.StaticKeypair.Private, curve25519.Basepoint)
if err != nil {
return config, err
}
config.PeerStatic = theirPubkey[:]
return
}
-3
View File
@@ -1,3 +0,0 @@
module github.com/malcolmseyd/natpunch-go/server
go 1.14
+244 -34
View File
@@ -2,30 +2,113 @@ package main
import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"log"
"net"
"os"
"time"
"github.com/flynn/noise"
"github.com/malcolmseyd/natpunch-go/server/auth"
"golang.org/x/crypto/curve25519"
)
// Endpoint is the location of a peer.
type Endpoint struct {
ip net.IP
port uint16
const (
// PacketHandshakeInit identifies handhshake initiation packets
PacketHandshakeInit byte = 1
// PacketHandshakeResp identifies handhshake response packets
PacketHandshakeResp byte = 2
// PacketData identifies regular data packets.
PacketData byte = 3
)
var (
// ErrPacketType is returned when an unexepcted packet type is enountered
ErrPacketType = errors.New("server: incorrect packet type")
// ErrPeerNotFound is returned when the requested peer is not found
ErrPeerNotFound = errors.New("server: peer not found")
// ErrPubkey is returned when the public key recieved does not match the one we expect
ErrPubkey = errors.New("server: public key did not match expected one")
// ErrOldTimestamp is returned when a handshake timestamp isn't newer than the previous one
ErrOldTimestamp = errors.New("server: handshake timestamp isn't new")
// ErrNoTimestamp is returned when the handshake packet doesn't contain a timestamp
ErrNoTimestamp = errors.New("server: handshake had no timestamp")
// ErrNonce is returned when the nonce on a packet isn't valid
ErrNonce = errors.New("client/network: invalid nonce")
timeout = 5 * time.Second
noiseConfig = noise.Config{
CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s),
Random: rand.Reader,
Pattern: noise.HandshakeIK,
Initiator: false,
Prologue: []byte("natpunch-go is the best :)"),
}
)
// Key stores a Wireguard key
type Key [32]byte
// we use pointers on these maps so that two maps can link to one object
// PeerMap stores the peers by key
type PeerMap map[Key]*Peer
// IndexMap stores the Peers by index
type IndexMap map[uint32]*Peer
// Peer represents a Wireguard peer.
type Peer struct {
ip net.IP
port uint16
pubkey Key
index uint32
send, recv *auth.CipherState
// UnixNano cast to uint64
lastHandshake uint64
}
type state struct {
conn *net.UDPConn
keyMap PeerMap
indexMap IndexMap
privKey Key
}
func main() {
if len(os.Args) < 2 {
fmt.Fprintln(os.Stderr, "Usage:", os.Args[0], "PORT")
fmt.Fprintln(os.Stderr, "Usage:", os.Args[0], "PORT [PRIVATE_KEY]")
os.Exit(1)
}
port := os.Args[1]
s := state{}
var err error
port := os.Args[1]
if len(os.Args) > 2 {
priv, err := base64.StdEncoding.DecodeString(os.Args[2])
if err != nil || len(priv) != 32 {
fmt.Fprintln(os.Stderr, "Error parsing public key")
}
copy(s.privKey[:], priv)
} else {
rand.Read(s.privKey[:])
s.privKey.clamp()
}
pubkey, _ := curve25519.X25519(s.privKey[:], curve25519.Basepoint)
fmt.Println("Starting nat-punching server on port", port)
peers := make(map[[32]byte]Endpoint)
fmt.Println("Public key:", base64.StdEncoding.EncodeToString(pubkey))
s.keyMap = make(PeerMap)
s.indexMap = make(IndexMap)
// the client can only handle IPv4 addresses right now.
listenAddr, err := net.ResolveUDPAddr("udp4", ":"+port)
@@ -33,67 +116,194 @@ func main() {
log.Panicln("Error getting UDP address", err)
}
conn, err := net.ListenUDP("udp4", listenAddr)
s.conn, err = net.ListenUDP("udp4", listenAddr)
if err != nil {
log.Panicln("Error getting UDP listen connection")
}
for {
err := handleConnection(conn, peers)
err := s.handleConnection()
if err != nil {
log.Panicln("Error handling the connection", err)
fmt.Println("Error handling the connection", err)
}
}
}
func handleConnection(conn *net.UDPConn, peers map[[32]byte]Endpoint) error {
var packet [64]byte
func (s *state) handleConnection() error {
packet := make([]byte, 4096)
_, clientAddr, err := conn.ReadFromUDP(packet[:])
n, clientAddr, err := s.conn.ReadFromUDP(packet)
if err != nil {
return err
}
packet = packet[:n]
var clientPubKey [32]byte
copy(clientPubKey[:], packet[0:32])
packetType := packet[0]
packet = packet[1:]
var targetPubKey [32]byte
copy(targetPubKey[:], packet[32:64])
clientLocation := Endpoint{
ip: clientAddr.IP,
port: uint16(clientAddr.Port),
if packetType == PacketHandshakeInit {
return s.handshake(packet, clientAddr, timeout)
} else if packetType == PacketData {
return s.dataPacket(packet, clientAddr, timeout)
} else {
fmt.Println("Unknown packet type:", packetType)
fmt.Println(hex.Dump(packet))
}
peers[clientPubKey] = clientLocation
return nil
}
targetLocation, exists := peers[targetPubKey]
func (s *state) dataPacket(packet []byte, clientAddr *net.UDPAddr, timeout time.Duration) (err error) {
index := binary.BigEndian.Uint32(packet[:4])
packet = packet[4:]
response := bytes.NewBuffer([]byte{})
if exists {
binary.Write(response, binary.BigEndian, targetLocation.ip)
binary.Write(response, binary.BigEndian, targetLocation.port)
client, ok := s.indexMap[index]
if !ok {
return
}
// otherwise send an empty response
_, err = conn.WriteToUDP(response.Bytes(), clientAddr)
nonce := binary.BigEndian.Uint64(packet[:8])
packet = packet[8:]
// println("recving nonce", nonce)
client.recv.SetNonce(nonce)
plaintext, err := client.recv.Decrypt(nil, nil, packet)
if err != nil {
return nil
return
}
if !client.recv.CheckNonce(nonce) {
// no need to throw an error, just return
return
}
clientPubKey := plaintext[:32]
plaintext = plaintext[32:]
if !bytes.Equal(clientPubKey, client.pubkey[:]) {
err = ErrPubkey
return
}
var targetPubKey Key
copy(targetPubKey[:], plaintext[:32])
// for later use
plaintext = plaintext[:6]
client.ip = clientAddr.IP
client.port = uint16(clientAddr.Port)
targetPeer, peerExists := s.keyMap[targetPubKey]
if peerExists {
// client must be ipv4 so this will never return nil
copy(plaintext[:4], targetPeer.ip.To4())
binary.BigEndian.PutUint16(plaintext[4:6], targetPeer.port)
} else {
// return nothing if peer not found
plaintext = plaintext[:0]
}
nonceBytes := make([]byte, 8)
binary.BigEndian.PutUint64(nonceBytes, client.send.Nonce())
header := append([]byte{PacketData}, nonceBytes...)
// println("sent nonce:", client.send.Nonce())
// println("sending", len(plaintext), "bytes")
response := client.send.Encrypt(header, nil, plaintext)
_, err = s.conn.WriteToUDP(response, clientAddr)
if err != nil {
return
}
fmt.Print(
base64.StdEncoding.EncodeToString(clientPubKey[:])[:16],
base64.StdEncoding.EncodeToString(client.pubkey[:])[:16],
" ==> ",
base64.StdEncoding.EncodeToString(targetPubKey[:])[:16],
": ",
)
if exists {
if peerExists {
fmt.Println("CONNECTED")
} else {
fmt.Println("NOT FOUND")
}
return nil
return
}
func (s *state) handshake(packet []byte, clientAddr *net.UDPAddr, timeout time.Duration) (err error) {
config := noiseConfig
config.StaticKeypair = noise.DHKey{
Private: s.privKey[:],
}
config.StaticKeypair.Public, err = curve25519.X25519(config.StaticKeypair.Private, curve25519.Basepoint)
if err != nil {
return
}
handshake, err := noise.NewHandshakeState(config)
if err != nil {
return
}
indexBytes := packet[:4]
index := binary.BigEndian.Uint32(indexBytes)
packet = packet[4:]
timestampBytes, _, _, err := handshake.ReadMessage(nil, packet)
if err != nil {
return
}
if len(timestampBytes) == 0 {
err = ErrNoTimestamp
}
timestamp := binary.BigEndian.Uint64(timestampBytes)
var pubkey Key
copy(pubkey[:], handshake.PeerStatic())
client, ok := s.keyMap[pubkey]
if !ok {
client = &Peer{
pubkey: pubkey,
}
s.keyMap[pubkey] = client
}
if timestamp <= client.lastHandshake {
err = ErrOldTimestamp
return
}
client.lastHandshake = timestamp
// clear old entry
s.indexMap[index] = nil
client.ip = clientAddr.IP
client.port = uint16(clientAddr.Port)
// if index is aleady taken, set a new one
for {
_, ok = s.indexMap[index]
if !ok {
break
}
index++
}
client.index = index
binary.BigEndian.PutUint32(indexBytes, index)
s.indexMap[index] = client
header := append([]byte{PacketHandshakeResp}, indexBytes...)
// recv and send are opposite order from client code
packet, recv, send, err := handshake.WriteMessage(header, nil)
if err != nil {
return
}
client.send = auth.NewCipherState(send.Cipher())
client.recv = auth.NewCipherState(recv.Cipher())
_, err = s.conn.WriteTo(packet, clientAddr)
return
}
func (k *Key) clamp() {
k[0] &= 248
k[31] = (k[31] & 127) | 64
}