Files
go-libp2p/p2p/transport/webtransport/transport_test.go
T
2025-06-11 10:00:47 +01:00

872 lines
28 KiB
Go

package libp2pwebtransport_test
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"runtime"
"sync/atomic"
"testing"
"testing/quick"
"time"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/test"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
const clockSkewAllowance = time.Hour
const certValidity = 14 * 24 * time.Hour
func newIdentity(t *testing.T) (peer.ID, ic.PrivKey) {
key, _, err := ic.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)
id, err := peer.IDFromPrivateKey(key)
require.NoError(t, err)
return id, key
}
func randomMultihash(t *testing.T) string {
t.Helper()
b := make([]byte, 16)
rand.Read(b)
h, err := multihash.Encode(b, multihash.KECCAK_224)
require.NoError(t, err)
s, err := multibase.Encode(multibase.Base32hex, h)
require.NoError(t, err)
return s
}
func extractCertHashes(addr ma.Multiaddr) []string {
var certHashesStr []string
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
certHashesStr = append(certHashesStr, c.Value())
}
return true
})
return certHashesStr
}
func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr {
for {
_, err := addr.ValueForProtocol(ma.P_CERTHASH)
if err != nil {
return addr
}
addr, _ = ma.SplitLast(addr)
}
}
// create a /certhash multiaddr component using the SHA256 of foobar
func getCerthashComponent(t *testing.T, b []byte) *ma.Component {
t.Helper()
h := sha256.Sum256(b)
mh, err := multihash.Encode(h[:], multihash.SHA2_256)
require.NoError(t, err)
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
require.NoError(t, err)
ha, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
require.NoError(t, err)
return ha
}
func newConnManager(t *testing.T, opts ...quicreuse.Option) *quicreuse.ConnManager {
t.Helper()
cm, err := quicreuse.NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}, opts...)
require.NoError(t, err)
t.Cleanup(func() { cm.Close() })
return cm
}
func TestTransport(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, nil)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()
addrChan := make(chan ma.Multiaddr)
go func() {
_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, nil)
require.NoError(t, err)
defer tr2.(io.Closer).Close()
conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
str, err := conn.OpenStream(context.Background())
require.NoError(t, err)
_, err = str.Write([]byte("foobar"))
require.NoError(t, err)
require.NoError(t, str.Close())
// check RemoteMultiaddr
_, addr, err := manet.DialArgs(ln.Multiaddr())
require.NoError(t, err)
_, port, err := net.SplitHostPort(addr)
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic-v1/webtransport", port), conn.RemoteMultiaddr().String())
addrChan <- conn.RemoteMultiaddr()
}()
conn, err := ln.Accept()
require.NoError(t, err)
require.False(t, conn.IsClosed())
str, err := conn.AcceptStream()
require.NoError(t, err)
data, err := io.ReadAll(str)
require.NoError(t, err)
require.Equal(t, "foobar", string(data))
require.Equal(t, (<-addrChan).String(), conn.LocalMultiaddr().String())
require.NoError(t, conn.Close())
require.True(t, conn.IsClosed())
}
func TestHashVerification(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
_, err := ln.Accept()
require.Error(t, err)
}()
_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr2.(io.Closer).Close()
foobarHash := getCerthashComponent(t, []byte("foobar"))
t.Run("fails using only a wrong hash", func(t *testing.T) {
// replace the certificate hash in the multiaddr with a fake hash
addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash)
_, err := tr2.Dial(context.Background(), addr, serverID)
require.Error(t, err)
var trErr *quic.TransportError
require.ErrorAs(t, err, &trErr)
require.Equal(t, quic.TransportErrorCode(0x12a), trErr.ErrorCode)
var errMismatchHash libp2pwebtransport.ErrCertHashMismatch
require.ErrorAs(t, err, &errMismatchHash)
e := sha256.Sum256([]byte("foobar"))
require.EqualValues(t, e[:], errMismatchHash.Actual[0])
})
t.Run("fails when adding a wrong hash", func(t *testing.T) {
_, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
require.Error(t, err)
})
require.NoError(t, ln.Close())
<-done
}
func TestCanDial(t *testing.T) {
valid := []ma.Multiaddr{
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1/webtransport/certhash/" + randomMultihash(t)),
ma.StringCast("/ip6/b16b:8255:efc6:9cd5:1a54:ee86:2d7a:c2e6/udp/1234/quic-v1/webtransport/certhash/" + randomMultihash(t)),
ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic-v1/webtransport/certhash/%s/certhash/%s/certhash/%s", randomMultihash(t), randomMultihash(t), randomMultihash(t))),
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1/webtransport"), // no certificate hash
}
invalid := []ma.Multiaddr{
ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport
ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic
ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke?
}
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
for _, addr := range valid {
require.Truef(t, tr.CanDial(addr), "expected to be able to dial %s", addr)
}
for _, addr := range invalid {
require.Falsef(t, tr.CanDial(addr), "expected to not be able to dial %s", addr)
}
}
func TestListenAddrValidity(t *testing.T) {
valid := []ma.Multiaddr{
ma.StringCast("/ip6/::/udp/0/quic-v1/webtransport/"),
}
invalid := []ma.Multiaddr{
ma.StringCast("/ip4/127.0.0.1/udp/0"), // missing webtransport
ma.StringCast("/ip4/127.0.0.1/udp/0/webtransport"), // missing quic
ma.StringCast("/ip4/127.0.0.1/tcp/0/webtransport"), // WebTransport over TCP? Is this a joke?
ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport/certhash/" + randomMultihash(t)), // We can't listen on a specific certhash
}
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
for _, addr := range valid {
ln, err := tr.Listen(addr)
require.NoErrorf(t, err, "expected to be able to listen on %s", addr)
ln.Close()
}
for _, addr := range invalid {
_, err := tr.Listen(addr)
require.Errorf(t, err, "expected to not be able to listen on %s", addr)
}
}
func TestListenerAddrs(t *testing.T) {
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
hashes1 := extractCertHashes(ln1.Multiaddr())
require.Len(t, hashes1, 2)
hashes2 := extractCertHashes(ln2.Multiaddr())
require.Equal(t, hashes1, hashes2)
}
func TestResourceManagerDialing(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
addr := ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")
p := peer.ID("foobar")
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
require.NoError(t, err)
defer tr.(io.Closer).Close()
l, err := tr.Listen(addr)
require.NoError(t, err)
addr = l.Multiaddr()
scope := mocknetwork.NewMockConnManagementScope(ctrl)
rcmgr.EXPECT().OpenConnection(network.DirOutbound, false, addr).Return(scope, nil)
scope.EXPECT().SetPeer(p).Return(errors.New("denied"))
scope.EXPECT().Done()
_, err = tr.Dial(context.Background(), addr, p)
require.EqualError(t, err, "denied")
}
func TestResourceManagerListening(t *testing.T) {
clientID, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
t.Run("blocking the connection", func(t *testing.T) {
serverID, key := newIdentity(t)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
require.NoError(t, err)
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()
rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).DoAndReturn(func(_ network.Direction, _ bool, addr ma.Multiaddr) (network.ConnManagementScope, error) {
_, err := addr.ValueForProtocol(ma.P_WEBTRANSPORT)
require.NoError(t, err, "expected a WebTransport multiaddr")
_, addrStr, err := manet.DialArgs(addr)
require.NoError(t, err)
host, _, err := net.SplitHostPort(addrStr)
require.NoError(t, err)
require.Equal(t, "127.0.0.1", host)
return nil, errors.New("denied")
})
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.EqualError(t, err, "received status 503")
})
t.Run("blocking the peer", func(t *testing.T) {
serverID, key := newIdentity(t)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
require.NoError(t, err)
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()
serverDone := make(chan struct{})
scope := mocknetwork.NewMockConnManagementScope(ctrl)
rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).Return(scope, nil)
scope.EXPECT().SetPeer(clientID).Return(errors.New("denied"))
scope.EXPECT().Done().Do(func() { close(serverDone) })
// The handshake will complete, but the server will immediately close the connection.
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
clientDone := make(chan struct{})
go func() {
defer close(clientDone)
_, err = conn.AcceptStream()
require.Error(t, err)
}()
select {
case <-clientDone:
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
select {
case <-serverDone:
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
})
}
// TODO: unify somehow. We do the same in libp2pquic.
//go:generate sh -c "go run go.uber.org/mock/mockgen -package libp2pwebtransport_test -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater && go run golang.org/x/tools/cmd/goimports -w mock_connection_gater_test.go"
func TestConnectionGaterDialing(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()
connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
require.Equal(t, stripCertHashes(ln.Multiaddr()).String(), addrs.RemoteMultiaddr().String())
})
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), connGater, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.EqualError(t, err, "secured connection gated")
}
func TestConnectionGaterInterceptAccept(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, stripCertHashes(ln.Multiaddr()).String(), addrs.LocalMultiaddr().String())
require.NotEqual(t, stripCertHashes(ln.Multiaddr()).String(), addrs.RemoteMultiaddr().String())
})
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.EqualError(t, err, "received status 403")
}
func TestConnectionGaterInterceptSecured(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()
clientID, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true)
connGater.EXPECT().InterceptSecured(network.DirInbound, clientID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
require.Equal(t, stripCertHashes(ln.Multiaddr()).String(), addrs.LocalMultiaddr().String())
require.NotEqual(t, stripCertHashes(ln.Multiaddr()).String(), addrs.RemoteMultiaddr().String())
})
// The handshake will complete, but the server will immediately close the connection.
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
done := make(chan struct{})
go func() {
defer close(done)
_, err = conn.AcceptStream()
require.Error(t, err)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
}
func TestAcceptQueueFilledUp(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()
newConn := func() (tpt.CapableConn, error) {
t.Helper()
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
return cl.Dial(context.Background(), ln.Multiaddr(), serverID)
}
const num = 16 + 1 // one more than the accept queue capacity
// Dial one more connection than the accept queue can hold.
errChan := make(chan error, num)
for i := 0; i < num; i++ {
go func() {
conn, err := newConn()
if err != nil {
errChan <- err
return
}
_, err = conn.AcceptStream()
errChan <- err
}()
}
// Since the handshakes complete asynchronously, we won't know _which_ one is rejected,
// so the only thing we can test for is that exactly one connection attempt is rejected.
select {
case <-errChan:
case <-time.After(time.Second):
t.Fatal("expected one connection to be rejected")
}
select {
case <-errChan:
t.Fatal("only expected one connection to be rejected")
case <-time.After(100 * time.Millisecond):
}
// test shutdown
require.NoError(t, ln.Close())
var count int
timer := time.NewTimer(time.Second)
defer timer.Stop()
for i := 0; i < 16; i++ {
select {
case <-errChan:
count++
if count == 16 {
return
}
case <-timer.C:
t.Fatal("shutdown failed")
}
}
}
type reportingRcmgr struct {
network.NullResourceManager
report chan<- int
}
func (m *reportingRcmgr) OpenConnection(_ network.Direction, _ bool, _ ma.Multiaddr) (network.ConnManagementScope, error) {
return &reportingScope{report: m.report}, nil
}
type reportingScope struct {
network.NullScope
report chan<- int
}
func (s *reportingScope) ReserveMemory(size int, _ uint8) error {
s.report <- size
return nil
}
func newUDPConnLocalhost(t testing.TB) *net.UDPConn {
t.Helper()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
return conn
}
func TestFlowControlWindowIncrease(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("this test is flaky on Windows")
}
rtt := 10 * time.Millisecond
timeout := 5 * time.Second
if os.Getenv("CI") != "" {
rtt = 40 * time.Millisecond
timeout = 15 * time.Second
}
serverID, serverKey := newIdentity(t)
serverWindowIncreases := make(chan int, 100)
serverRcmgr := &reportingRcmgr{report: serverWindowIncreases}
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, serverRcmgr)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()
go func() {
conn, err := ln.Accept()
require.NoError(t, err)
str, err := conn.AcceptStream()
require.NoError(t, err)
_, err = io.CopyBuffer(str, str, make([]byte, 2<<10))
require.NoError(t, err)
str.CloseWrite()
}()
proxy := quicproxy.Proxy{
Conn: newUDPConnLocalhost(t),
ServerAddr: ln.Addr().(*net.UDPAddr),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
}
require.NoError(t, proxy.Start())
defer proxy.Close()
_, clientKey := newIdentity(t)
clientWindowIncreases := make(chan int, 100)
clientRcmgr := &reportingRcmgr{report: clientWindowIncreases}
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, clientRcmgr)
require.NoError(t, err)
defer tr2.(io.Closer).Close()
var addr ma.Multiaddr
for _, comp := range ln.Multiaddr() {
if _, err := comp.ValueForProtocol(ma.P_UDP); err == nil {
addr = addr.Encapsulate(ma.StringCast(fmt.Sprintf("/udp/%d", proxy.LocalAddr().(*net.UDPAddr).Port)))
continue
}
addr = append(addr, comp)
}
conn, err := tr2.Dial(context.Background(), addr, serverID)
require.NoError(t, err)
str, err := conn.OpenStream(context.Background())
require.NoError(t, err)
var increasesDone atomic.Bool
go func() {
for {
_, err := str.Write(bytes.Repeat([]byte{0x42}, 1<<10))
require.NoError(t, err)
if increasesDone.Load() {
str.CloseWrite()
return
}
}
}()
done := make(chan struct{})
go func() {
defer close(done)
_, err := io.ReadAll(str)
require.NoError(t, err)
}()
var numServerIncreases, numClientIncreases int
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case <-serverWindowIncreases:
numServerIncreases++
case <-clientWindowIncreases:
numClientIncreases++
case <-timer.C:
t.Fatalf("didn't receive enough window increases (client: %d, server: %d)", numClientIncreases, numServerIncreases)
}
if numClientIncreases >= 1 && numServerIncreases >= 1 {
increasesDone.Store(true)
break
}
}
select {
case <-done:
case <-time.After(timeout):
t.Fatal("timeout")
}
}
var errTimeout = errors.New("timeout")
func serverSendsBackValidCert(t *testing.T, timeSinceUnixEpoch time.Duration, keySeed int64, randomClientSkew time.Duration) error {
if timeSinceUnixEpoch < 0 {
timeSinceUnixEpoch = -timeSinceUnixEpoch
}
// Bound this to 100 years
timeSinceUnixEpoch = timeSinceUnixEpoch % (time.Hour * 24 * 365 * 100)
// Start a bit further in the future to avoid edge cases around epoch
timeSinceUnixEpoch += time.Hour * 24 * 365
start := time.UnixMilli(timeSinceUnixEpoch.Milliseconds())
randomClientSkew = randomClientSkew % clockSkewAllowance
cl := clock.NewMock()
cl.Set(start)
priv, _, err := test.SeededTestKeyPair(ic.Ed25519, 256, keySeed)
require.NoError(t, err)
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
require.NoError(t, err)
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer l.Close()
conn, err := quic.DialAddr(context.Background(), l.Addr().String(), &tls.Config{
NextProtos: []string{http3.NextProtoH3},
InsecureSkipVerify: true,
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
for _, c := range rawCerts {
cert, err := x509.ParseCertificate(c)
if err != nil {
return err
}
for _, clientSkew := range []time.Duration{randomClientSkew, -clockSkewAllowance, clockSkewAllowance} {
clientTime := cl.Now().Add(clientSkew)
if clientTime.After(cert.NotAfter) || clientTime.Before(cert.NotBefore) {
return fmt.Errorf("Times are not valid: server_now=%v client_now=%v certstart=%v certend=%v", cl.Now().UTC(), clientTime.UTC(), cert.NotBefore.UTC(), cert.NotAfter.UTC())
}
}
}
return nil
},
}, &quic.Config{MaxIdleTimeout: time.Second})
if err != nil {
if _, ok := err.(*quic.IdleTimeoutError); ok {
return errTimeout
}
return err
}
defer conn.CloseWithError(0, "")
return nil
}
func TestServerSendsBackValidCert(t *testing.T) {
var maxTimeoutErrors = 10
require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, keySeed int64, randomClientSkew time.Duration) bool {
err := serverSendsBackValidCert(t, timeSinceUnixEpoch, keySeed, randomClientSkew)
if err == errTimeout {
maxTimeoutErrors -= 1
if maxTimeoutErrors <= 0 {
fmt.Println("Too many timeout errors")
return false
}
// Sporadic timeout errors on macOS
return true
} else if err != nil {
fmt.Println("Err:", err)
return false
}
return true
}, nil))
}
func TestServerRotatesCertCorrectly(t *testing.T) {
require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, keySeed int64) bool {
if timeSinceUnixEpoch < 0 {
timeSinceUnixEpoch = -timeSinceUnixEpoch
}
// Bound this to 100 years
timeSinceUnixEpoch = timeSinceUnixEpoch % (time.Hour * 24 * 365 * 100)
// Start a bit further in the future to avoid edge cases around epoch
timeSinceUnixEpoch += time.Hour * 24 * 365
start := time.UnixMilli(timeSinceUnixEpoch.Milliseconds())
cl := clock.NewMock()
cl.Set(start)
priv, _, err := test.SeededTestKeyPair(ic.Ed25519, 256, keySeed)
if err != nil {
return false
}
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
if err != nil {
return false
}
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
if err != nil {
return false
}
certhashes := extractCertHashes(l.Multiaddr())
l.Close()
// These two certificates together are valid for at most certValidity - (4*clockSkewAllowance)
cl.Add(certValidity - (4 * clockSkewAllowance) - time.Second)
tr, err = libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
if err != nil {
return false
}
l, err = tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
if err != nil {
return false
}
defer l.Close()
var found bool
ma.ForEach(l.Multiaddr(), func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
for _, prevCerthash := range certhashes {
if c.Value() == prevCerthash {
found = true
return false
}
}
}
return true
})
return found
}, nil))
}
func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) {
cl := clock.NewMock()
// Move one year ahead to avoid edge cases around epoch
cl.Add(time.Hour * 24 * 365)
priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256)
require.NoError(t, err)
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
require.NoError(t, err)
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
certhashes := extractCertHashes(l.Multiaddr())
l.Close()
// Traverse various time boundaries and make sure we always keep a common certhash.
// e.g. certhash/A/certhash/B ... -> ... certhash/B/certhash/C ... -> ... certhash/C/certhash/D
for i := 0; i < 200; i++ {
cl.Add(24 * time.Hour)
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
require.NoError(t, err)
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
var found bool
ma.ForEach(l.Multiaddr(), func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
for _, prevCerthash := range certhashes {
if prevCerthash == c.Value() {
found = true
return false
}
}
}
return true
})
certhashes = extractCertHashes(l.Multiaddr())
l.Close()
require.True(t, found, "Failed after hour: %v", i)
}
}
func TestH3ConnClosed(t *testing.T) {
_, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, nil, libp2pwebtransport.WithHandshakeTimeout(1*time.Second))
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()
p, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
conn, err := quic.Dial(context.Background(), p, ln.Addr(), &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{http3.NextProtoH3},
}, nil)
require.NoError(t, err)
rt := &http3.Transport{}
rt.NewClientConn(conn)
require.Eventually(t, func() bool {
c := http.Client{
Transport: rt,
Timeout: 1 * time.Second,
}
resp, err := c.Get(fmt.Sprintf("https://%s", ln.Addr().String()))
if err != nil {
return true
}
resp.Body.Close()
return false
}, 10*time.Second, 1*time.Second)
}