Merge pull request #1185 from libp2p/circuit-shutdown

don't use a context to shut down the circuitv2
This commit is contained in:
Marten Seemann
2021-09-17 18:23:41 +02:00
committed by GitHub
6 changed files with 46 additions and 35 deletions
+21 -7
View File
@@ -2,12 +2,14 @@ package client
import (
"context"
"io"
"sync"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/transport"
logging "github.com/ipfs/go-log/v2"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
@@ -24,9 +26,10 @@ var log = logging.Logger("p2p-circuit")
// This allows us to use the v2 code as drop in replacement for v1 in a host without breaking
// existing code and interoperability with older nodes.
type Client struct {
ctx context.Context
host host.Host
upgrader *tptu.Upgrader
ctx context.Context
ctxCancel context.CancelFunc
host host.Host
upgrader *tptu.Upgrader
incoming chan accept
@@ -35,6 +38,9 @@ type Client struct {
hopCount map[peer.ID]int
}
var _ io.Closer = &Client{}
var _ transport.Transport = &Client{}
type accept struct {
conn *Conn
writeResponse func() error
@@ -48,15 +54,16 @@ type completion struct {
// New constructs a new p2p-circuit/v2 client, attached to the given host and using the given
// upgrader to perform connection upgrades.
func New(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) (*Client, error) {
return &Client{
ctx: ctx,
func New(h host.Host, upgrader *tptu.Upgrader) (*Client, error) {
cl := &Client{
host: h,
upgrader: upgrader,
incoming: make(chan accept),
activeDials: make(map[peer.ID]*completion),
hopCount: make(map[peer.ID]int),
}, nil
}
cl.ctx, cl.ctxCancel = context.WithCancel(context.Background())
return cl, nil
}
// Start registers the circuit (client) protocol stream handlers
@@ -64,3 +71,10 @@ func (c *Client) Start() {
c.host.SetStreamHandler(proto.ProtoIDv1, c.handleStreamV1)
c.host.SetStreamHandler(proto.ProtoIDv2Stop, c.handleStreamV2)
}
func (c *Client) Close() error {
c.ctxCancel()
c.host.RemoveStreamHandler(proto.ProtoIDv1)
c.host.RemoveStreamHandler(proto.ProtoIDv2Stop)
return nil
}
+3 -3
View File
@@ -1,6 +1,7 @@
package client
import (
"errors"
"net"
ma "github.com/multiformats/go-multiaddr"
@@ -32,7 +33,7 @@ func (l *Listener) Accept() (manet.Conn, error) {
return evt.conn, nil
case <-l.ctx.Done():
return nil, l.ctx.Err()
return nil, errors.New("circuit v2 client closed")
}
}
}
@@ -49,6 +50,5 @@ func (l *Listener) Multiaddr() ma.Multiaddr {
}
func (l *Listener) Close() error {
// noop for now
return nil
return (*Client)(l).Close()
}
+4 -2
View File
@@ -3,6 +3,7 @@ package client
import (
"context"
"fmt"
"io"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
@@ -17,13 +18,13 @@ var circuitAddr = ma.Cast(circuitProtocol.VCode)
// AddTransport constructs a new p2p-circuit/v2 client and adds it as a transport to the
// host network
func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) error {
func AddTransport(h host.Host, upgrader *tptu.Upgrader) error {
n, ok := h.Network().(transport.TransportNetwork)
if !ok {
return fmt.Errorf("%v is not a transport network", h.Network())
}
c, err := New(ctx, h, upgrader)
c, err := New(h, upgrader)
if err != nil {
return fmt.Errorf("error constructing circuit client: %w", err)
}
@@ -45,6 +46,7 @@ func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) err
// Transport interface
var _ transport.Transport = (*Client)(nil)
var _ io.Closer = (*Client)(nil)
func (c *Client) Dial(ctx context.Context, a ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
conn, err := c.dial(ctx, a, p)
+2 -2
View File
@@ -30,7 +30,7 @@ func TestRelayCompatV2DialV1(t *testing.T) {
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransportV1(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[2], upgraders[2])
rch := make(chan []byte, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
@@ -105,7 +105,7 @@ func TestRelayCompatV1DialV2(t *testing.T) {
defer cancel()
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, hosts[0], upgraders[0])
addTransportV1(t, ctx, hosts[2], upgraders[2])
rch := make(chan []byte, 1)
+11 -12
View File
@@ -20,12 +20,12 @@ import (
logging "github.com/ipfs/go-log/v2"
bhost "github.com/libp2p/go-libp2p-blankhost"
metrics "github.com/libp2p/go-libp2p-core/metrics"
pstoremem "github.com/libp2p/go-libp2p-peerstore/pstoremem"
"github.com/libp2p/go-libp2p-core/metrics"
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
swarm "github.com/libp2p/go-libp2p-swarm"
swarmt "github.com/libp2p/go-libp2p-swarm/testing"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
tcp "github.com/libp2p/go-tcp-transport"
"github.com/libp2p/go-tcp-transport"
ma "github.com/multiformats/go-multiaddr"
)
@@ -85,9 +85,8 @@ func connect(t *testing.T, a, b host.Host) {
}
}
func addTransport(t *testing.T, ctx context.Context, h host.Host, upgrader *tptu.Upgrader) {
err := client.AddTransport(ctx, h, upgrader)
if err != nil {
func addTransport(t *testing.T, h host.Host, upgrader *tptu.Upgrader) {
if err := client.AddTransport(h, upgrader); err != nil {
t.Fatal(err)
}
}
@@ -97,8 +96,8 @@ func TestBasicRelay(t *testing.T) {
defer cancel()
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])
rch := make(chan []byte, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
@@ -184,8 +183,8 @@ func TestRelayLimitTime(t *testing.T) {
defer cancel()
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])
rch := make(chan error, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
@@ -258,8 +257,8 @@ func TestRelayLimitData(t *testing.T) {
defer cancel()
hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])
rch := make(chan int, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {