mirror of
https://github.com/libp2p/go-libp2p.git
synced 2026-04-22 16:17:19 +08:00
830 lines
26 KiB
Go
830 lines
26 KiB
Go
package libp2pwebtransport_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"runtime"
|
|
"sync/atomic"
|
|
"testing"
|
|
"testing/quick"
|
|
"time"
|
|
|
|
ic "github.com/libp2p/go-libp2p/core/crypto"
|
|
"github.com/libp2p/go-libp2p/core/network"
|
|
mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks"
|
|
"github.com/libp2p/go-libp2p/core/peer"
|
|
"github.com/libp2p/go-libp2p/core/test"
|
|
tpt "github.com/libp2p/go-libp2p/core/transport"
|
|
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
|
|
libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
|
|
|
|
"github.com/benbjohnson/clock"
|
|
ma "github.com/multiformats/go-multiaddr"
|
|
manet "github.com/multiformats/go-multiaddr/net"
|
|
"github.com/multiformats/go-multibase"
|
|
"github.com/multiformats/go-multihash"
|
|
"github.com/quic-go/quic-go"
|
|
"github.com/quic-go/quic-go/http3"
|
|
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
const clockSkewAllowance = time.Hour
|
|
const certValidity = 14 * 24 * time.Hour
|
|
|
|
func newIdentity(t *testing.T) (peer.ID, ic.PrivKey) {
|
|
key, _, err := ic.GenerateEd25519Key(rand.Reader)
|
|
require.NoError(t, err)
|
|
id, err := peer.IDFromPrivateKey(key)
|
|
require.NoError(t, err)
|
|
return id, key
|
|
}
|
|
|
|
func randomMultihash(t *testing.T) string {
|
|
t.Helper()
|
|
b := make([]byte, 16)
|
|
rand.Read(b)
|
|
h, err := multihash.Encode(b, multihash.KECCAK_224)
|
|
require.NoError(t, err)
|
|
s, err := multibase.Encode(multibase.Base32hex, h)
|
|
require.NoError(t, err)
|
|
return s
|
|
}
|
|
|
|
func extractCertHashes(addr ma.Multiaddr) []string {
|
|
var certHashesStr []string
|
|
ma.ForEach(addr, func(c ma.Component) bool {
|
|
if c.Protocol().Code == ma.P_CERTHASH {
|
|
certHashesStr = append(certHashesStr, c.Value())
|
|
}
|
|
return true
|
|
})
|
|
return certHashesStr
|
|
}
|
|
|
|
func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr {
|
|
for {
|
|
_, err := addr.ValueForProtocol(ma.P_CERTHASH)
|
|
if err != nil {
|
|
return addr
|
|
}
|
|
addr, _ = ma.SplitLast(addr)
|
|
}
|
|
}
|
|
|
|
// create a /certhash multiaddr component using the SHA256 of foobar
|
|
func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr {
|
|
t.Helper()
|
|
h := sha256.Sum256(b)
|
|
mh, err := multihash.Encode(h[:], multihash.SHA2_256)
|
|
require.NoError(t, err)
|
|
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
|
|
require.NoError(t, err)
|
|
ha, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
|
|
require.NoError(t, err)
|
|
return ha
|
|
}
|
|
|
|
func newConnManager(t *testing.T, opts ...quicreuse.Option) *quicreuse.ConnManager {
|
|
t.Helper()
|
|
cm, err := quicreuse.NewConnManager([32]byte{}, opts...)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { cm.Close() })
|
|
return cm
|
|
}
|
|
|
|
func TestTransport(t *testing.T) {
|
|
serverID, serverKey := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, nil)
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
addrChan := make(chan ma.Multiaddr)
|
|
go func() {
|
|
_, clientKey := newIdentity(t)
|
|
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, nil)
|
|
require.NoError(t, err)
|
|
defer tr2.(io.Closer).Close()
|
|
|
|
conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID)
|
|
require.NoError(t, err)
|
|
str, err := conn.OpenStream(context.Background())
|
|
require.NoError(t, err)
|
|
_, err = str.Write([]byte("foobar"))
|
|
require.NoError(t, err)
|
|
require.NoError(t, str.Close())
|
|
|
|
// check RemoteMultiaddr
|
|
_, addr, err := manet.DialArgs(ln.Multiaddr())
|
|
require.NoError(t, err)
|
|
_, port, err := net.SplitHostPort(addr)
|
|
require.NoError(t, err)
|
|
require.Equal(t, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic-v1/webtransport", port)), conn.RemoteMultiaddr())
|
|
addrChan <- conn.RemoteMultiaddr()
|
|
}()
|
|
|
|
conn, err := ln.Accept()
|
|
require.NoError(t, err)
|
|
require.False(t, conn.IsClosed())
|
|
str, err := conn.AcceptStream()
|
|
require.NoError(t, err)
|
|
data, err := io.ReadAll(str)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "foobar", string(data))
|
|
require.Equal(t, <-addrChan, conn.LocalMultiaddr())
|
|
require.NoError(t, conn.Close())
|
|
require.True(t, conn.IsClosed())
|
|
}
|
|
|
|
func TestHashVerification(t *testing.T) {
|
|
serverID, serverKey := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
_, err := ln.Accept()
|
|
require.Error(t, err)
|
|
}()
|
|
|
|
_, clientKey := newIdentity(t)
|
|
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer tr2.(io.Closer).Close()
|
|
|
|
foobarHash := getCerthashComponent(t, []byte("foobar"))
|
|
|
|
t.Run("fails using only a wrong hash", func(t *testing.T) {
|
|
// replace the certificate hash in the multiaddr with a fake hash
|
|
addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash)
|
|
_, err := tr2.Dial(context.Background(), addr, serverID)
|
|
require.Error(t, err)
|
|
var trErr *quic.TransportError
|
|
require.ErrorAs(t, err, &trErr)
|
|
require.Equal(t, quic.TransportErrorCode(0x12a), trErr.ErrorCode)
|
|
require.Contains(t, errors.Unwrap(trErr).Error(), "cert hash not found")
|
|
})
|
|
|
|
t.Run("fails when adding a wrong hash", func(t *testing.T) {
|
|
_, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
|
|
require.Error(t, err)
|
|
})
|
|
|
|
require.NoError(t, ln.Close())
|
|
<-done
|
|
}
|
|
|
|
func TestCanDial(t *testing.T) {
|
|
valid := []ma.Multiaddr{
|
|
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1/webtransport/certhash/" + randomMultihash(t)),
|
|
ma.StringCast("/ip6/b16b:8255:efc6:9cd5:1a54:ee86:2d7a:c2e6/udp/1234/quic-v1/webtransport/certhash/" + randomMultihash(t)),
|
|
ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic-v1/webtransport/certhash/%s/certhash/%s/certhash/%s", randomMultihash(t), randomMultihash(t), randomMultihash(t))),
|
|
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1/webtransport"), // no certificate hash
|
|
}
|
|
|
|
invalid := []ma.Multiaddr{
|
|
ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport
|
|
ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic
|
|
ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke?
|
|
}
|
|
|
|
_, key := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
|
|
for _, addr := range valid {
|
|
require.Truef(t, tr.CanDial(addr), "expected to be able to dial %s", addr)
|
|
}
|
|
for _, addr := range invalid {
|
|
require.Falsef(t, tr.CanDial(addr), "expected to not be able to dial %s", addr)
|
|
}
|
|
}
|
|
|
|
func TestListenAddrValidity(t *testing.T) {
|
|
valid := []ma.Multiaddr{
|
|
ma.StringCast("/ip6/::/udp/0/quic-v1/webtransport/"),
|
|
}
|
|
|
|
invalid := []ma.Multiaddr{
|
|
ma.StringCast("/ip4/127.0.0.1/udp/0"), // missing webtransport
|
|
ma.StringCast("/ip4/127.0.0.1/udp/0/webtransport"), // missing quic
|
|
ma.StringCast("/ip4/127.0.0.1/tcp/0/webtransport"), // WebTransport over TCP? Is this a joke?
|
|
ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport/certhash/" + randomMultihash(t)), // We can't listen on a specific certhash
|
|
}
|
|
|
|
_, key := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
|
|
for _, addr := range valid {
|
|
ln, err := tr.Listen(addr)
|
|
require.NoErrorf(t, err, "expected to be able to listen on %s", addr)
|
|
ln.Close()
|
|
}
|
|
for _, addr := range invalid {
|
|
_, err := tr.Listen(addr)
|
|
require.Errorf(t, err, "expected to not be able to listen on %s", addr)
|
|
}
|
|
}
|
|
|
|
func TestListenerAddrs(t *testing.T) {
|
|
_, key := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
|
|
ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
hashes1 := extractCertHashes(ln1.Multiaddr())
|
|
require.Len(t, hashes1, 2)
|
|
hashes2 := extractCertHashes(ln2.Multiaddr())
|
|
require.Equal(t, hashes1, hashes2)
|
|
}
|
|
|
|
func TestResourceManagerDialing(t *testing.T) {
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
|
|
|
addr := ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")
|
|
p := peer.ID("foobar")
|
|
|
|
_, key := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
l, err := tr.Listen(addr)
|
|
require.NoError(t, err)
|
|
|
|
addr = l.Multiaddr()
|
|
|
|
scope := mocknetwork.NewMockConnManagementScope(ctrl)
|
|
rcmgr.EXPECT().OpenConnection(network.DirOutbound, false, addr).Return(scope, nil)
|
|
scope.EXPECT().SetPeer(p).Return(errors.New("denied"))
|
|
scope.EXPECT().Done()
|
|
|
|
_, err = tr.Dial(context.Background(), addr, p)
|
|
require.EqualError(t, err, "denied")
|
|
}
|
|
|
|
func TestResourceManagerListening(t *testing.T) {
|
|
clientID, key := newIdentity(t)
|
|
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer cl.(io.Closer).Close()
|
|
|
|
t.Run("blocking the connection", func(t *testing.T) {
|
|
serverID, key := newIdentity(t)
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
|
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
|
|
require.NoError(t, err)
|
|
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).DoAndReturn(func(_ network.Direction, _ bool, addr ma.Multiaddr) (network.ConnManagementScope, error) {
|
|
_, err := addr.ValueForProtocol(ma.P_WEBTRANSPORT)
|
|
require.NoError(t, err, "expected a WebTransport multiaddr")
|
|
_, addrStr, err := manet.DialArgs(addr)
|
|
require.NoError(t, err)
|
|
host, _, err := net.SplitHostPort(addrStr)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "127.0.0.1", host)
|
|
return nil, errors.New("denied")
|
|
})
|
|
|
|
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
|
|
require.EqualError(t, err, "received status 503")
|
|
})
|
|
|
|
t.Run("blocking the peer", func(t *testing.T) {
|
|
serverID, key := newIdentity(t)
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
|
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
|
|
require.NoError(t, err)
|
|
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverDone := make(chan struct{})
|
|
scope := mocknetwork.NewMockConnManagementScope(ctrl)
|
|
rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).Return(scope, nil)
|
|
scope.EXPECT().SetPeer(clientID).Return(errors.New("denied"))
|
|
scope.EXPECT().Done().Do(func() { close(serverDone) })
|
|
|
|
// The handshake will complete, but the server will immediately close the connection.
|
|
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
clientDone := make(chan struct{})
|
|
go func() {
|
|
defer close(clientDone)
|
|
_, err = conn.AcceptStream()
|
|
require.Error(t, err)
|
|
}()
|
|
select {
|
|
case <-clientDone:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
select {
|
|
case <-serverDone:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
})
|
|
}
|
|
|
|
// TODO: unify somehow. We do the same in libp2pquic.
|
|
//go:generate sh -c "go run go.uber.org/mock/mockgen -package libp2pwebtransport_test -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater && go run golang.org/x/tools/cmd/goimports -w mock_connection_gater_test.go"
|
|
|
|
func TestConnectionGaterDialing(t *testing.T) {
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
connGater := NewMockConnectionGater(ctrl)
|
|
|
|
serverID, serverKey := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
|
|
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
|
|
})
|
|
_, key := newIdentity(t)
|
|
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), connGater, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer cl.(io.Closer).Close()
|
|
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
|
|
require.EqualError(t, err, "secured connection gated")
|
|
}
|
|
|
|
func TestConnectionGaterInterceptAccept(t *testing.T) {
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
connGater := NewMockConnectionGater(ctrl)
|
|
|
|
serverID, serverKey := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
|
|
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr())
|
|
require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
|
|
})
|
|
|
|
_, key := newIdentity(t)
|
|
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer cl.(io.Closer).Close()
|
|
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
|
|
require.EqualError(t, err, "received status 403")
|
|
}
|
|
|
|
func TestConnectionGaterInterceptSecured(t *testing.T) {
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
connGater := NewMockConnectionGater(ctrl)
|
|
|
|
serverID, serverKey := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
clientID, key := newIdentity(t)
|
|
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer cl.(io.Closer).Close()
|
|
|
|
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true)
|
|
connGater.EXPECT().InterceptSecured(network.DirInbound, clientID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
|
|
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr())
|
|
require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
|
|
})
|
|
// The handshake will complete, but the server will immediately close the connection.
|
|
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
_, err = conn.AcceptStream()
|
|
require.Error(t, err)
|
|
}()
|
|
select {
|
|
case <-done:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
func TestAcceptQueueFilledUp(t *testing.T) {
|
|
serverID, serverKey := newIdentity(t)
|
|
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
newConn := func() (tpt.CapableConn, error) {
|
|
t.Helper()
|
|
_, key := newIdentity(t)
|
|
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
|
|
require.NoError(t, err)
|
|
defer cl.(io.Closer).Close()
|
|
return cl.Dial(context.Background(), ln.Multiaddr(), serverID)
|
|
}
|
|
|
|
const num = 16 + 1 // one more than the accept queue capacity
|
|
// Dial one more connection than the accept queue can hold.
|
|
errChan := make(chan error, num)
|
|
for i := 0; i < num; i++ {
|
|
go func() {
|
|
conn, err := newConn()
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
_, err = conn.AcceptStream()
|
|
errChan <- err
|
|
}()
|
|
}
|
|
|
|
// Since the handshakes complete asynchronously, we won't know _which_ one is rejected,
|
|
// so the only thing we can test for is that exactly one connection attempt is rejected.
|
|
select {
|
|
case <-errChan:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected one connection to be rejected")
|
|
}
|
|
select {
|
|
case <-errChan:
|
|
t.Fatal("only expected one connection to be rejected")
|
|
case <-time.After(100 * time.Millisecond):
|
|
}
|
|
|
|
// test shutdown
|
|
require.NoError(t, ln.Close())
|
|
var count int
|
|
timer := time.NewTimer(time.Second)
|
|
defer timer.Stop()
|
|
for i := 0; i < 16; i++ {
|
|
select {
|
|
case <-errChan:
|
|
count++
|
|
if count == 16 {
|
|
return
|
|
}
|
|
case <-timer.C:
|
|
t.Fatal("shutdown failed")
|
|
}
|
|
}
|
|
}
|
|
|
|
type reportingRcmgr struct {
|
|
network.NullResourceManager
|
|
report chan<- int
|
|
}
|
|
|
|
func (m *reportingRcmgr) OpenConnection(dir network.Direction, usefd bool, endpoint ma.Multiaddr) (network.ConnManagementScope, error) {
|
|
return &reportingScope{report: m.report}, nil
|
|
}
|
|
|
|
type reportingScope struct {
|
|
network.NullScope
|
|
report chan<- int
|
|
}
|
|
|
|
func (s *reportingScope) ReserveMemory(size int, _ uint8) error {
|
|
s.report <- size
|
|
return nil
|
|
}
|
|
|
|
func TestFlowControlWindowIncrease(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
t.Skip("this test is flaky on Windows")
|
|
}
|
|
|
|
rtt := 10 * time.Millisecond
|
|
timeout := 5 * time.Second
|
|
|
|
if os.Getenv("CI") != "" {
|
|
rtt = 40 * time.Millisecond
|
|
timeout = 15 * time.Second
|
|
}
|
|
|
|
serverID, serverKey := newIdentity(t)
|
|
serverWindowIncreases := make(chan int, 100)
|
|
serverRcmgr := &reportingRcmgr{report: serverWindowIncreases}
|
|
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, serverRcmgr)
|
|
require.NoError(t, err)
|
|
defer tr.(io.Closer).Close()
|
|
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
go func() {
|
|
conn, err := ln.Accept()
|
|
require.NoError(t, err)
|
|
str, err := conn.AcceptStream()
|
|
require.NoError(t, err)
|
|
_, err = io.CopyBuffer(str, str, make([]byte, 2<<10))
|
|
require.NoError(t, err)
|
|
str.CloseWrite()
|
|
}()
|
|
|
|
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
|
RemoteAddr: ln.Addr().String(),
|
|
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
|
|
})
|
|
require.NoError(t, err)
|
|
defer proxy.Close()
|
|
|
|
_, clientKey := newIdentity(t)
|
|
clientWindowIncreases := make(chan int, 100)
|
|
clientRcmgr := &reportingRcmgr{report: clientWindowIncreases}
|
|
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, clientRcmgr)
|
|
require.NoError(t, err)
|
|
defer tr2.(io.Closer).Close()
|
|
|
|
var addr ma.Multiaddr
|
|
for _, comp := range ma.Split(ln.Multiaddr()) {
|
|
if _, err := comp.ValueForProtocol(ma.P_UDP); err == nil {
|
|
addr = addr.Encapsulate(ma.StringCast(fmt.Sprintf("/udp/%d", proxy.LocalPort())))
|
|
continue
|
|
}
|
|
if addr == nil {
|
|
addr = comp
|
|
continue
|
|
}
|
|
addr = addr.Encapsulate(comp)
|
|
}
|
|
|
|
conn, err := tr2.Dial(context.Background(), addr, serverID)
|
|
require.NoError(t, err)
|
|
str, err := conn.OpenStream(context.Background())
|
|
require.NoError(t, err)
|
|
var increasesDone atomic.Bool
|
|
go func() {
|
|
for {
|
|
_, err := str.Write(bytes.Repeat([]byte{0x42}, 1<<10))
|
|
require.NoError(t, err)
|
|
if increasesDone.Load() {
|
|
str.CloseWrite()
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
_, err := io.ReadAll(str)
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
var numServerIncreases, numClientIncreases int
|
|
timer := time.NewTimer(timeout)
|
|
defer timer.Stop()
|
|
for {
|
|
select {
|
|
case <-serverWindowIncreases:
|
|
numServerIncreases++
|
|
case <-clientWindowIncreases:
|
|
numClientIncreases++
|
|
case <-timer.C:
|
|
t.Fatalf("didn't receive enough window increases (client: %d, server: %d)", numClientIncreases, numServerIncreases)
|
|
}
|
|
if numClientIncreases >= 1 && numServerIncreases >= 1 {
|
|
increasesDone.Store(true)
|
|
break
|
|
}
|
|
}
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(timeout):
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
var errTimeout = errors.New("timeout")
|
|
|
|
func serverSendsBackValidCert(t *testing.T, timeSinceUnixEpoch time.Duration, keySeed int64, randomClientSkew time.Duration) error {
|
|
if timeSinceUnixEpoch < 0 {
|
|
timeSinceUnixEpoch = -timeSinceUnixEpoch
|
|
}
|
|
|
|
// Bound this to 100 years
|
|
timeSinceUnixEpoch = time.Duration(timeSinceUnixEpoch % (time.Hour * 24 * 365 * 100))
|
|
// Start a bit further in the future to avoid edge cases around epoch
|
|
timeSinceUnixEpoch += time.Hour * 24 * 365
|
|
start := time.UnixMilli(timeSinceUnixEpoch.Milliseconds())
|
|
|
|
randomClientSkew = randomClientSkew % clockSkewAllowance
|
|
|
|
cl := clock.NewMock()
|
|
cl.Set(start)
|
|
|
|
priv, _, err := test.SeededTestKeyPair(ic.Ed25519, 256, keySeed)
|
|
require.NoError(t, err)
|
|
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
|
|
require.NoError(t, err)
|
|
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
defer l.Close()
|
|
|
|
conn, err := quic.DialAddr(context.Background(), l.Addr().String(), &tls.Config{
|
|
NextProtos: []string{http3.NextProtoH3},
|
|
InsecureSkipVerify: true,
|
|
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
for _, c := range rawCerts {
|
|
cert, err := x509.ParseCertificate(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, clientSkew := range []time.Duration{randomClientSkew, -clockSkewAllowance, clockSkewAllowance} {
|
|
clientTime := cl.Now().Add(clientSkew)
|
|
if clientTime.After(cert.NotAfter) || clientTime.Before(cert.NotBefore) {
|
|
return fmt.Errorf("Times are not valid: server_now=%v client_now=%v certstart=%v certend=%v", cl.Now().UTC(), clientTime.UTC(), cert.NotBefore.UTC(), cert.NotAfter.UTC())
|
|
}
|
|
}
|
|
|
|
}
|
|
return nil
|
|
},
|
|
}, &quic.Config{MaxIdleTimeout: time.Second})
|
|
|
|
if err != nil {
|
|
if _, ok := err.(*quic.IdleTimeoutError); ok {
|
|
return errTimeout
|
|
}
|
|
return err
|
|
}
|
|
defer conn.CloseWithError(0, "")
|
|
|
|
return nil
|
|
}
|
|
|
|
func TestServerSendsBackValidCert(t *testing.T) {
|
|
var maxTimeoutErrors = 10
|
|
require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, keySeed int64, randomClientSkew time.Duration) bool {
|
|
err := serverSendsBackValidCert(t, timeSinceUnixEpoch, keySeed, randomClientSkew)
|
|
if err == errTimeout {
|
|
maxTimeoutErrors -= 1
|
|
if maxTimeoutErrors <= 0 {
|
|
fmt.Println("Too many timeout errors")
|
|
return false
|
|
}
|
|
// Sporadic timeout errors on macOS
|
|
return true
|
|
} else if err != nil {
|
|
fmt.Println("Err:", err)
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}, nil))
|
|
}
|
|
|
|
func TestServerRotatesCertCorrectly(t *testing.T) {
|
|
require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, keySeed int64) bool {
|
|
if timeSinceUnixEpoch < 0 {
|
|
timeSinceUnixEpoch = -timeSinceUnixEpoch
|
|
}
|
|
|
|
// Bound this to 100 years
|
|
timeSinceUnixEpoch = time.Duration(timeSinceUnixEpoch % (time.Hour * 24 * 365 * 100))
|
|
// Start a bit further in the future to avoid edge cases around epoch
|
|
timeSinceUnixEpoch += time.Hour * 24 * 365
|
|
start := time.UnixMilli(timeSinceUnixEpoch.Milliseconds())
|
|
|
|
cl := clock.NewMock()
|
|
cl.Set(start)
|
|
|
|
priv, _, err := test.SeededTestKeyPair(ic.Ed25519, 256, keySeed)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
certhashes := extractCertHashes(l.Multiaddr())
|
|
l.Close()
|
|
|
|
// These two certificates together are valid for at most certValidity - (4*clockSkewAllowance)
|
|
cl.Add(certValidity - (4 * clockSkewAllowance) - time.Second)
|
|
tr, err = libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
l, err = tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
defer l.Close()
|
|
|
|
var found bool
|
|
ma.ForEach(l.Multiaddr(), func(c ma.Component) bool {
|
|
if c.Protocol().Code == ma.P_CERTHASH {
|
|
for _, prevCerthash := range certhashes {
|
|
if c.Value() == prevCerthash {
|
|
found = true
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
return found
|
|
|
|
}, nil))
|
|
}
|
|
|
|
func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) {
|
|
cl := clock.NewMock()
|
|
// Move one year ahead to avoid edge cases around epoch
|
|
cl.Add(time.Hour * 24 * 365)
|
|
|
|
priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256)
|
|
require.NoError(t, err)
|
|
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
|
|
require.NoError(t, err)
|
|
|
|
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
|
|
certhashes := extractCertHashes(l.Multiaddr())
|
|
l.Close()
|
|
|
|
// Traverse various time boundaries and make sure we always keep a common certhash.
|
|
// e.g. certhash/A/certhash/B ... -> ... certhash/B/certhash/C ... -> ... certhash/C/certhash/D
|
|
for i := 0; i < 200; i++ {
|
|
cl.Add(24 * time.Hour)
|
|
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
|
|
require.NoError(t, err)
|
|
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
|
|
require.NoError(t, err)
|
|
|
|
var found bool
|
|
ma.ForEach(l.Multiaddr(), func(c ma.Component) bool {
|
|
if c.Protocol().Code == ma.P_CERTHASH {
|
|
for _, prevCerthash := range certhashes {
|
|
if prevCerthash == c.Value() {
|
|
found = true
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
certhashes = extractCertHashes(l.Multiaddr())
|
|
l.Close()
|
|
|
|
require.True(t, found, "Failed after hour: %v", i)
|
|
}
|
|
}
|