diff --git a/listener/inbound/anytls_test.go b/listener/inbound/anytls_test.go index b28ec233..5d9ac8ba 100644 --- a/listener/inbound/anytls_test.go +++ b/listener/inbound/anytls_test.go @@ -41,7 +41,7 @@ func testInboundAnyTLS(t *testing.T, inboundOptions inbound.AnyTLSOption, outbou outboundOptions.Server = addrPort.Addr().String() outboundOptions.Port = int(addrPort.Port()) outboundOptions.Password = userUUID - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewAnyTLS(outboundOptions) if !assert.NoError(t, err) { diff --git a/listener/inbound/common_test.go b/listener/inbound/common_test.go index ace5f637..74df75a2 100644 --- a/listener/inbound/common_test.go +++ b/listener/inbound/common_test.go @@ -58,12 +58,15 @@ func init() { realityPublickey = base64.RawURLEncoding.EncodeToString(privateKey.PublicKey().Bytes()) } -type TestDialer struct{ dialer C.Dialer } +type TestDialer struct { + dialer C.Dialer + ctx context.Context +} func (t *TestDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { start: conn, err := t.dialer.DialContext(ctx, network, address) - if err != nil && ctx.Err() == nil { + if err != nil && ctx.Err() == nil && t.ctx.Err() == nil { // We are conducting tests locally, and they shouldn't fail. // However, a large number of requests in a short period during concurrent testing can exhaust system ports. // This can lead to various errors such as WSAECONNREFUSED and WSAENOBUFS. @@ -77,10 +80,6 @@ func (t *TestDialer) ListenPacket(ctx context.Context, network, address string, return t.dialer.ListenPacket(ctx, network, address, rAddrPort) } -func NewTestDialer() *TestDialer { - return &TestDialer{dialer: dialer.NewDialer()} -} - var _ C.Dialer = (*TestDialer)(nil) type TestTunnel struct { @@ -90,6 +89,7 @@ type TestTunnel struct { CloseFn func() error DoSequentialTestFn func(t *testing.T, proxy C.ProxyAdapter) DoConcurrentTestFn func(t *testing.T, proxy C.ProxyAdapter) + NewDialerFn func() C.Dialer } func (tt *TestTunnel) HandleTCPConn(conn net.Conn, metadata *C.Metadata) { @@ -121,6 +121,10 @@ func (tt *TestTunnel) DoConcurrentTest(t *testing.T, proxy C.ProxyAdapter) { tt.DoConcurrentTestFn(t, proxy) } +func (tt *TestTunnel) NewDialer() C.Dialer { + return tt.NewDialerFn() +} + type TestTunnelListener struct { ch chan net.Conn ctx context.Context @@ -353,6 +357,7 @@ func NewHttpTestTunnel() *TestTunnel { CloseFn: ln.Close, DoSequentialTestFn: sequentialTestFn, DoConcurrentTestFn: concurrentTestFn, + NewDialerFn: func() C.Dialer { return &TestDialer{dialer: dialer.NewDialer(), ctx: ctx} }, } return tunnel } diff --git a/listener/inbound/hysteria2_test.go b/listener/inbound/hysteria2_test.go index d866792e..b0132fa9 100644 --- a/listener/inbound/hysteria2_test.go +++ b/listener/inbound/hysteria2_test.go @@ -41,7 +41,7 @@ func testInboundHysteria2(t *testing.T, inboundOptions inbound.Hysteria2Option, outboundOptions.Server = addrPort.Addr().String() outboundOptions.Port = int(addrPort.Port()) outboundOptions.Password = userUUID - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewHysteria2(outboundOptions) if !assert.NoError(t, err) { diff --git a/listener/inbound/mieru_test.go b/listener/inbound/mieru_test.go index ceee2df5..97742186 100644 --- a/listener/inbound/mieru_test.go +++ b/listener/inbound/mieru_test.go @@ -237,7 +237,7 @@ func testInboundMieruTCP(t *testing.T, handshakeMode string) { Password: "password", HandshakeMode: handshakeMode, } - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewMieru(outboundOptions) if !assert.NoError(t, err) { return @@ -293,7 +293,7 @@ func testInboundMieruUDP(t *testing.T, handshakeMode string) { Password: "password", HandshakeMode: handshakeMode, } - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewMieru(outboundOptions) if !assert.NoError(t, err) { return diff --git a/listener/inbound/shadowsocks_test.go b/listener/inbound/shadowsocks_test.go index 5f2c5359..755db2cf 100644 --- a/listener/inbound/shadowsocks_test.go +++ b/listener/inbound/shadowsocks_test.go @@ -85,7 +85,7 @@ func testInboundShadowSocks0(t *testing.T, inboundOptions inbound.ShadowSocksOpt outboundOptions.Server = addrPort.Addr().String() outboundOptions.Port = int(addrPort.Port()) outboundOptions.Password = password - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewShadowSocks(outboundOptions) if !assert.NoError(t, err) { diff --git a/listener/inbound/sudoku_test.go b/listener/inbound/sudoku_test.go index e8101a56..41348008 100644 --- a/listener/inbound/sudoku_test.go +++ b/listener/inbound/sudoku_test.go @@ -43,7 +43,7 @@ func testInboundSudoku(t *testing.T, inboundOptions inbound.SudokuOption, outbou outboundOptions.Name = "sudoku_outbound" outboundOptions.Server = addrPort.Addr().String() outboundOptions.Port = int(addrPort.Port()) - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewSudoku(outboundOptions) if !assert.NoError(t, err) { diff --git a/listener/inbound/trojan_test.go b/listener/inbound/trojan_test.go index d29a2977..14c36e82 100644 --- a/listener/inbound/trojan_test.go +++ b/listener/inbound/trojan_test.go @@ -43,7 +43,7 @@ func testInboundTrojan(t *testing.T, inboundOptions inbound.TrojanOption, outbou outboundOptions.Server = addrPort.Addr().String() outboundOptions.Port = int(addrPort.Port()) outboundOptions.Password = userUUID - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewTrojan(outboundOptions) if !assert.NoError(t, err) { diff --git a/listener/inbound/trusttunnel_test.go b/listener/inbound/trusttunnel_test.go index fb939c4b..4c6c3cdd 100644 --- a/listener/inbound/trusttunnel_test.go +++ b/listener/inbound/trusttunnel_test.go @@ -42,7 +42,7 @@ func testInboundTrustTunnel(t *testing.T, inboundOptions inbound.TrustTunnelOpti outboundOptions.Port = int(addrPort.Port()) outboundOptions.UserName = "test" outboundOptions.Password = userUUID - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewTrustTunnel(outboundOptions) if !assert.NoError(t, err) { diff --git a/listener/inbound/tuic_test.go b/listener/inbound/tuic_test.go index 094f706e..34bf8e4d 100644 --- a/listener/inbound/tuic_test.go +++ b/listener/inbound/tuic_test.go @@ -69,7 +69,7 @@ func testInboundTuic0(t *testing.T, inboundOptions inbound.TuicOption, outboundO outboundOptions.Name = "tuic_outbound" outboundOptions.Server = addrPort.Addr().String() outboundOptions.Port = int(addrPort.Port()) - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewTuic(outboundOptions) if !assert.NoError(t, err) { diff --git a/listener/inbound/vless_test.go b/listener/inbound/vless_test.go index f506beb0..a483c341 100644 --- a/listener/inbound/vless_test.go +++ b/listener/inbound/vless_test.go @@ -44,7 +44,7 @@ func testInboundVless(t *testing.T, inboundOptions inbound.VlessOption, outbound outboundOptions.Server = addrPort.Addr().String() outboundOptions.Port = int(addrPort.Port()) outboundOptions.UUID = userUUID - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewVless(outboundOptions) if !assert.NoError(t, err) { diff --git a/listener/inbound/vmess_test.go b/listener/inbound/vmess_test.go index 848add47..175d9803 100644 --- a/listener/inbound/vmess_test.go +++ b/listener/inbound/vmess_test.go @@ -45,7 +45,7 @@ func testInboundVMess(t *testing.T, inboundOptions inbound.VmessOption, outbound outboundOptions.UUID = userUUID outboundOptions.AlterID = 0 outboundOptions.Cipher = "auto" - outboundOptions.DialerForAPI = NewTestDialer() + outboundOptions.DialerForAPI = tunnel.NewDialer() out, err := outbound.NewVmess(outboundOptions) if !assert.NoError(t, err) {