diff --git a/server_play_test.go b/server_play_test.go index d8fa1f07..3a6b124e 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -640,7 +640,14 @@ func TestServerPlay(t *testing.T) { close(nconnClosed) }, - onSessionOpen: func(_ *ServerHandlerOnSessionOpenCtx) { + onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { + require.NotEmpty(t, ctx.Session.Conns()) + + // test that properties can be accessed in parallel + go func() { + ctx.Session.Conns() + }() + close(sessionOpened) }, onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { diff --git a/server_session.go b/server_session.go index 0aa6f853..542057b8 100644 --- a/server_session.go +++ b/server_session.go @@ -389,6 +389,8 @@ func (ss *ServerSession) initialize() { ss.chAsyncStartWriter = make(chan struct{}) ss.chWriterError = make(chan error) + ss.conns[ss.author] = struct{}{} + ss.s.wg.Add(1) go ss.run() } @@ -460,6 +462,21 @@ func (ss *ServerSession) UserData() any { return ss.userData } +// Conns returns connections associated with the session. +func (ss *ServerSession) Conns() []*ServerConn { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + + ret := make([]*ServerConn, len(ss.conns)) + i := 0 + for sc := range ss.conns { + ret[i] = sc + i++ + } + + return ret +} + // Transport returns transport details. // This is non-nil only if SETUP has been called at least once. func (ss *ServerSession) Transport() *SessionTransport { @@ -785,7 +802,9 @@ func (ss *ServerSession) runInner() error { ss.lastRequestTime = ss.s.timeNow() if _, ok := ss.conns[req.sc]; !ok { + ss.propsMutex.Lock() ss.conns[req.sc] = struct{}{} + ss.propsMutex.Unlock() } res, err := ss.handleRequestInner(req.sc, req.req) @@ -811,9 +830,11 @@ func (ss *ServerSession) runInner() error { }.Marshal() } - // after a TEARDOWN, session must be unpaired with the connection + // after a TEARDOWN, session must be unpaired from the connection if req.req.Method == base.Teardown { + ss.propsMutex.Lock() delete(ss.conns, req.sc) + ss.propsMutex.Unlock() returnedSession = nil } } @@ -831,7 +852,9 @@ func (ss *ServerSession) runInner() error { } case sc := <-ss.chRemoveConn: + ss.propsMutex.Lock() delete(ss.conns, sc) + ss.propsMutex.Unlock() // if session is not in state RECORD or PLAY, or transport is TCP, // and there are no associated connections,