Use the default hashmap to store streams. Avoid allocating a stream object for every single frame received

Branch: pull/148/head
Author: Andy Wang, 3 years ago
Parent: fd5005db0a
Commit: 35f41424c9

@@ -63,7 +63,9 @@ type Session struct {
 	// atomic
 	activeStreamCount uint32
-	streams sync.Map
+	streamsM sync.Mutex
+	streams  map[uint32]*Stream
 	// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
 	recvFramePool sync.Pool
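
The new pair of fields replaces the old sync.Map with a plain map guarded by a mutex; every access elsewhere in this commit follows the same lock–touch–unlock shape. A minimal sketch of that pattern as helper methods, purely for illustration (getStream and putStream are hypothetical names, not part of this commit, which inlines the locking at each call site):

// Hypothetical helpers sketching the locking discipline used throughout this
// commit; the real code inlines sesh.streamsM.Lock()/Unlock() at each access.
func (sesh *Session) getStream(id uint32) (*Stream, bool) {
	sesh.streamsM.Lock()
	defer sesh.streamsM.Unlock()
	s, ok := sesh.streams[id]
	return s, ok
}

func (sesh *Session) putStream(id uint32, s *Stream) {
	sesh.streamsM.Lock()
	sesh.streams[id] = s
	sesh.streamsM.Unlock()
}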
@@ -93,6 +95,7 @@ func MakeSession(id uint32, config SessionConfig) *Session {
 		nextStreamID: 1,
 		acceptCh: make(chan *Stream, acceptBacklog),
 		recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
+		streams: map[uint32]*Stream{},
 	}
 	sesh.addrs.Store([]net.Addr{nil, nil})
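
recvFramePool is a standard sync.Pool whose New function hands out zeroed *Frame values. The exact Get/Put call sites sit outside this hunk, so the snippet below is only an assumed sketch of how such a pool is typically used on the receive path:

// Assumed usage sketch (Get/Put sites are not shown in this diff): borrow a
// *Frame, fill it from the wire, then return it for reuse so a fresh Frame is
// not heap-allocated for every received frame.
frame := sesh.recvFramePool.Get().(*Frame)
// ... deserialise the incoming bytes into *frame and dispatch it ...
sesh.recvFramePool.Put(frame)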
@@ -149,7 +152,9 @@ func (sesh *Session) OpenStream() (*Stream, error) {
 		return nil, errNoMultiplex
 	}
 	stream := makeStream(sesh, id)
-	sesh.streams.Store(id, stream)
+	sesh.streamsM.Lock()
+	sesh.streams[id] = stream
+	sesh.streamsM.Unlock()
 	sesh.streamCountIncr()
 	log.Tracef("stream %v of session %v opened", id, sesh.id)
 	return stream, nil
@@ -200,7 +205,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
 	// We set it as nil to signify that the stream id had existed before.
 	// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
 	// if the frame it received was from a new stream or a dying stream whose frame arrived late
-	sesh.streams.Store(s.id, nil)
+	sesh.streamsM.Lock()
+	sesh.streams[s.id] = nil
+	sesh.streamsM.Unlock()
 	if sesh.streamCountDecr() == 0 {
 		if sesh.Singleplex {
 			return sesh.Close()
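
As the comment above says, a closed stream leaves a nil entry behind instead of being deleted, so a late frame can be told apart from a frame opening a brand-new stream. A lookup therefore has to distinguish three states; sketched here with illustrative naming, matching the logic recvDataFromRemote uses below:

// Sketch of the three states a StreamID can be in under the new map.
sesh.streamsM.Lock()
s, ok := sesh.streams[id]
sesh.streamsM.Unlock()
switch {
case !ok:
	// never seen before: this is a genuinely new stream
case s == nil:
	// existed but has been closed: a late frame, silently dropped
default:
	// live stream: the frame is delivered to s
}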
@@ -229,15 +236,19 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
 		return sesh.passiveClose()
 	}
-	newStream := makeStream(sesh, frame.StreamID)
-	existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream)
+	sesh.streamsM.Lock()
+	existingStream, existing := sesh.streams[frame.StreamID]
 	if existing {
-		if existingStreamI == nil {
+		sesh.streamsM.Unlock()
+		if existingStream == nil {
 			// this is when the stream existed before but has since been closed. We do nothing
 			return nil
 		}
-		return existingStreamI.(*Stream).recvFrame(frame)
+		return existingStream.recvFrame(frame)
 	} else {
+		newStream := makeStream(sesh, frame.StreamID)
+		sesh.streams[frame.StreamID] = newStream
+		sesh.streamsM.Unlock()
 		// new stream
 		sesh.streamCountIncr()
 		sesh.acceptCh <- newStream
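
This branch is where the second half of the commit message comes from: sync.Map.LoadOrStore needs the candidate value before it can tell whether the key exists, so the removed code called makeStream for every incoming frame even when the stream was already known. The old, allocation-heavy shape, repeated only for contrast (and simplified, with the nil-tombstone check omitted):

// Old shape (removed above): newStream is heap-allocated on every frame and
// simply discarded whenever LoadOrStore finds an existing entry, leaving the
// garbage collector to clean up. The plain map defers makeStream to the
// genuinely-new branch.
newStream := makeStream(sesh, frame.StreamID)
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream)
if existing {
	return existingStreamI.(*Stream).recvFrame(frame)
}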
@@ -265,17 +276,17 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error {
 	}
 	sesh.acceptCh <- nil
-	sesh.streams.Range(func(key, streamI interface{}) bool {
-		if streamI == nil {
-			return true
+	sesh.streamsM.Lock()
+	for id, stream := range sesh.streams {
+		if stream == nil {
+			continue
 		}
-		stream := streamI.(*Stream)
 		atomic.StoreUint32(&stream.closed, 1)
 		_ = stream.getRecvBuf().Close() // will not block
-		sesh.streams.Delete(key)
+		delete(sesh.streams, id)
 		sesh.streamCountDecr()
-		return true
-	})
+	}
+	sesh.streamsM.Unlock()
 	if closeSwitchboard {
 		sesh.sb.closeAll()
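
The Range callback becomes an ordinary for-range under the lock. Deleting keys from a Go map while ranging over that same map is well defined, so the whole teardown stays inside one critical section. A standalone illustration of the idiom (not code from this commit):

// Illustration only: Go permits delete() on the map being ranged over; a
// deleted entry simply won't be produced later in the same iteration.
m := map[uint32]*Stream{}
for id, s := range m {
	if s == nil {
		continue // tombstone left by a closed stream
	}
	// ... close s here ...
	delete(m, id)
}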

@@ -112,7 +112,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
 	if err != nil {
 		t.Fatalf("receiving normal frame for stream 1: %v", err)
 	}
-	_, ok := sesh.streams.Load(f1.StreamID)
+	sesh.streamsM.Lock()
+	_, ok := sesh.streams[f1.StreamID]
+	sesh.streamsM.Unlock()
 	if !ok {
 		t.Fatal("failed to fetch stream 1 after receiving it")
 	}
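
The tests now wrap every map inspection in the same Lock/read/Unlock triple. If that repetition grows, a small test-side helper could fold it away; loadStream below is a hypothetical name and is not part of this commit:

// Hypothetical test helper (not in this commit): read a single entry while
// holding the streams mutex.
func loadStream(sesh *Session, id uint32) (*Stream, bool) {
	sesh.streamsM.Lock()
	defer sesh.streamsM.Unlock()
	s, ok := sesh.streams[id]
	return s, ok
}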
@@ -132,8 +134,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
 	if err != nil {
 		t.Fatalf("receiving normal frame for stream 2: %v", err)
 	}
-	s2I, ok := sesh.streams.Load(f2.StreamID)
-	if s2I == nil || !ok {
+	sesh.streamsM.Lock()
+	s2M, ok := sesh.streams[f2.StreamID]
+	sesh.streamsM.Unlock()
+	if s2M == nil || !ok {
 		t.Fatal("failed to fetch stream 2 after receiving it")
 	}
 	if sesh.streamCount() != 2 {
@@ -152,8 +156,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
 	if err != nil {
 		t.Fatalf("receiving stream closing frame for stream 1: %v", err)
 	}
-	s1I, _ := sesh.streams.Load(f1.StreamID)
-	if s1I != nil {
+	sesh.streamsM.Lock()
+	s1M, _ := sesh.streams[f1.StreamID]
+	sesh.streamsM.Unlock()
+	if s1M != nil {
 		t.Fatal("stream 1 still exist after receiving stream close")
 	}
 	s1, _ := sesh.Accept()
@@ -179,8 +185,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
 	if err != nil {
 		t.Fatalf("receiving stream closing frame for stream 1 %v", err)
 	}
-	s1I, _ = sesh.streams.Load(f1.StreamID)
-	if s1I != nil {
+	sesh.streamsM.Lock()
+	s1M, _ = sesh.streams[f1.StreamID]
+	sesh.streamsM.Unlock()
+	if s1M != nil {
 		t.Error("stream 1 exists after receiving stream close for the second time")
 	}
 	streamCount := sesh.streamCount()
@@ -243,7 +251,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
 	if err != nil {
 		t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
 	}
-	_, ok := sesh.streams.Load(f1CloseStream.StreamID)
+	sesh.streamsM.Lock()
+	_, ok := sesh.streams[f1CloseStream.StreamID]
+	sesh.streamsM.Unlock()
 	if !ok {
 		t.Fatal("stream 1 doesn't exist")
 	}
@@ -334,12 +344,13 @@ func TestParallelStreams(t *testing.T) {
 	wg.Wait()
 	sc := int(sesh.streamCount())
 	var count int
-	sesh.streams.Range(func(_, s interface{}) bool {
+	sesh.streamsM.Lock()
+	for _, s := range sesh.streams {
 		if s != nil {
 			count++
 		}
-		return true
-	})
+	}
+	sesh.streamsM.Unlock()
 	if sc != count {
 		t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
 	}
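
TestParallelStreams cross-checks the atomic activeStreamCount against the number of live (non-nil) map entries. The same reference count could be factored out into a helper such as the hypothetical liveStreams below, which is not part of this commit:

// Hypothetical helper (not in this commit): count live, non-nil entries so the
// result can be compared with sesh.streamCount().
func liveStreams(sesh *Session) int {
	sesh.streamsM.Lock()
	defer sesh.streamsM.Unlock()
	var count int
	for _, s := range sesh.streams {
		if s != nil {
			count++
		}
	}
	return count
}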

@@ -167,10 +167,13 @@ func TestStream_Close(t *testing.T) {
 		return
 	}
-	if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil {
+	sesh.streamsM.Lock()
+	if s, _ := sesh.streams[stream.(*Stream).id]; s != nil {
+		sesh.streamsM.Unlock()
 		t.Error("stream still exists")
 		return
 	}
+	sesh.streamsM.Unlock()
 	_, err = io.ReadFull(stream, readBuf[1:])
 	if err != nil {
@@ -242,8 +245,10 @@ func TestStream_Close(t *testing.T) {
 	}
 	assert.Eventually(t, func() bool {
-		sI, _ := sesh.streams.Load(stream.(*Stream).id)
-		return sI == nil
+		sesh.streamsM.Lock()
+		s, _ := sesh.streams[stream.(*Stream).id]
+		sesh.streamsM.Unlock()
+		return s == nil
 	}, time.Second, 10*time.Millisecond, "streams still exists")
 })
