diff --git a/net/net.go b/net/net.go index 3bba86e..20f1e50 100644 --- a/net/net.go +++ b/net/net.go @@ -50,3 +50,11 @@ func (d *Deadline) Deadline() <-chan struct{} { }) return d.deadline } + +func (d *Deadline) Close() error { + d.init.Do(func() { + d.deadline = make(chan struct{}) + }) + close(d.deadline) + return nil +} diff --git a/p2p/conn.go b/p2p/conn.go index 6b6d4f7..bfd2f84 100644 --- a/p2p/conn.go +++ b/p2p/conn.go @@ -8,13 +8,13 @@ import ( "io" "log/slog" "net" - "os" "sync" "time" "github.com/rkonfj/peerguard/disco" "github.com/rkonfj/peerguard/disco/tp" "github.com/rkonfj/peerguard/lru" + N "github.com/rkonfj/peerguard/net" "github.com/rkonfj/peerguard/netlink" "storj.io/common/base58" ) @@ -31,11 +31,12 @@ var ( type PeerPacketConn struct { cfg Config closedSig chan struct{} - readTimeout chan struct{} udpConn *tp.UDPConn wsConn *tp.WSConn discoCooling *lru.Cache[disco.PeerID, time.Time] discoCoolingMutex sync.Mutex + + deadlineRead N.Deadline } // ReadFrom reads a packet from the connection, @@ -52,8 +53,8 @@ func (c *PeerPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { case <-c.closedSig: err = net.ErrClosed return - case <-c.readTimeout: - err = os.ErrDeadlineExceeded + case <-c.deadlineRead.Deadline(): + err = N.ErrDeadline return case datagram := <-c.wsConn.Datagrams(): addr = datagram.PeerID @@ -91,7 +92,7 @@ func (c *PeerPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { // Any blocked ReadFrom or WriteTo operations will be unblocked and return errors. func (c *PeerPacketConn) Close() error { close(c.closedSig) - close(c.readTimeout) + c.deadlineRead.Close() var errs []error if err := c.wsConn.Close(); err != nil { errs = append(errs, err) @@ -136,10 +137,7 @@ func (c *PeerPacketConn) SetDeadline(t time.Time) error { // and any currently-blocked ReadFrom call. // A zero value for t means ReadFrom will not time out. func (c *PeerPacketConn) SetReadDeadline(t time.Time) error { - timeout := time.Until(t) - if timeout > 0 { - time.AfterFunc(timeout, func() { c.readTimeout <- struct{}{} }) - } + c.deadlineRead.SetDeadline(t) return nil } @@ -333,7 +331,6 @@ func ListenPacketContext(ctx context.Context, peermap *disco.Peermap, opts ...Op packetConn := PeerPacketConn{ cfg: cfg, closedSig: make(chan struct{}), - readTimeout: make(chan struct{}), udpConn: udpConn, wsConn: wsConn, discoCooling: lru.New[disco.PeerID, time.Time](1024), diff --git a/rdt/rdt.go b/rdt/rdt.go index a060003..1de310a 100644 --- a/rdt/rdt.go +++ b/rdt/rdt.go @@ -159,6 +159,7 @@ func (c *rdtConn) Close() error { close(c.nck) close(c.finack) close(c.sendEvent) + c.deadlineRead.Close() c.inboundBuf = nil c.sentNO = 0 c.sendMutex.Lock()