|
|
|
@ -51,8 +51,7 @@ type Session struct {
|
|
|
|
|
// atomic
|
|
|
|
|
nextStreamID uint32
|
|
|
|
|
|
|
|
|
|
streamsM sync.Mutex
|
|
|
|
|
streams map[uint32]*Stream
|
|
|
|
|
streams sync.Map
|
|
|
|
|
|
|
|
|
|
// Switchboard manages all connections to remote
|
|
|
|
|
sb *switchboard
|
|
|
|
@ -73,7 +72,6 @@ func MakeSession(id uint32, config *SessionConfig) *Session {
|
|
|
|
|
id: id,
|
|
|
|
|
SessionConfig: config,
|
|
|
|
|
nextStreamID: 1,
|
|
|
|
|
streams: make(map[uint32]*Stream),
|
|
|
|
|
acceptCh: make(chan *Stream, acceptBacklog),
|
|
|
|
|
}
|
|
|
|
|
sesh.addrs.Store([]net.Addr{nil, nil})
|
|
|
|
@ -108,14 +106,12 @@ func (sesh *Session) OpenStream() (*Stream, error) {
|
|
|
|
|
}
|
|
|
|
|
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
|
|
|
|
|
// Because atomic.AddUint32 returns the value after incrementation
|
|
|
|
|
connId, err := sesh.sb.assignRandomConn()
|
|
|
|
|
connId, _, err := sesh.sb.pickRandConn()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
stream := makeStream(sesh, id, connId)
|
|
|
|
|
sesh.streamsM.Lock()
|
|
|
|
|
sesh.streams[id] = stream
|
|
|
|
|
sesh.streamsM.Unlock()
|
|
|
|
|
sesh.streams.Store(id, stream)
|
|
|
|
|
log.Tracef("stream %v of session %v opened", id, sesh.id)
|
|
|
|
|
return stream, nil
|
|
|
|
|
}
|
|
|
|
@ -128,9 +124,7 @@ func (sesh *Session) Accept() (net.Conn, error) {
|
|
|
|
|
if stream == nil {
|
|
|
|
|
return nil, ErrBrokenSession
|
|
|
|
|
}
|
|
|
|
|
sesh.streamsM.Lock()
|
|
|
|
|
sesh.streams[stream.id] = stream
|
|
|
|
|
sesh.streamsM.Unlock()
|
|
|
|
|
sesh.streams.Store(stream.id, stream)
|
|
|
|
|
log.Tracef("stream %v of session %v accepted", stream.id, sesh.id)
|
|
|
|
|
return stream, nil
|
|
|
|
|
}
|
|
|
|
@ -166,26 +160,29 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
|
|
|
|
log.Tracef("stream %v passively closed", s.id)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sesh.streamsM.Lock()
|
|
|
|
|
delete(sesh.streams, s.id)
|
|
|
|
|
if len(sesh.streams) == 0 {
|
|
|
|
|
sesh.streams.Delete(s.id)
|
|
|
|
|
var count int
|
|
|
|
|
sesh.streams.Range(func(_, _ interface{}) bool {
|
|
|
|
|
count += 1
|
|
|
|
|
return true
|
|
|
|
|
})
|
|
|
|
|
if count == 0 {
|
|
|
|
|
log.Tracef("session %v has no active stream left", sesh.id)
|
|
|
|
|
go sesh.timeoutAfter(30 * time.Second)
|
|
|
|
|
}
|
|
|
|
|
sesh.streamsM.Unlock()
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// recvDataFromRemote deobfuscate the frame and send it to the appropriate stream buffer
|
|
|
|
|
func (sesh *Session) recvDataFromRemote(data []byte) error {
|
|
|
|
|
frame, err := sesh.Deobfs(data)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sesh.streamsM.Lock()
|
|
|
|
|
stream, existing := sesh.streams[frame.StreamID]
|
|
|
|
|
sesh.streamsM.Unlock()
|
|
|
|
|
streamI, existing := sesh.streams.Load(frame.StreamID)
|
|
|
|
|
if existing {
|
|
|
|
|
stream := streamI.(*Stream)
|
|
|
|
|
return stream.writeFrame(*frame)
|
|
|
|
|
} else {
|
|
|
|
|
if frame.Closing == 1 {
|
|
|
|
@ -200,9 +197,9 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
|
|
|
|
// any difference because we only care to send the data from the same stream through the same
|
|
|
|
|
// TCP connection. The remote may use a different connection to send the same stream than the one the client
|
|
|
|
|
// use to send.
|
|
|
|
|
connId, _ := sesh.sb.assignRandomConn()
|
|
|
|
|
connId, _, _ := sesh.sb.pickRandConn()
|
|
|
|
|
// we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write
|
|
|
|
|
stream = makeStream(sesh, frame.StreamID, connId)
|
|
|
|
|
stream := makeStream(sesh, frame.StreamID, connId)
|
|
|
|
|
sesh.acceptCh <- stream
|
|
|
|
|
return stream.writeFrame(*frame)
|
|
|
|
|
}
|
|
|
|
@ -230,13 +227,13 @@ func (sesh *Session) passiveClose() error {
|
|
|
|
|
}
|
|
|
|
|
sesh.acceptCh <- nil
|
|
|
|
|
|
|
|
|
|
sesh.streamsM.Lock()
|
|
|
|
|
for id, stream := range sesh.streams {
|
|
|
|
|
sesh.streams.Range(func(key, streamI interface{}) bool {
|
|
|
|
|
stream := streamI.(*Stream)
|
|
|
|
|
atomic.StoreUint32(&stream.closed, 1)
|
|
|
|
|
_ = stream.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close()
|
|
|
|
|
delete(sesh.streams, id)
|
|
|
|
|
}
|
|
|
|
|
sesh.streamsM.Unlock()
|
|
|
|
|
_ = stream.recvBuf.Close() // will not block
|
|
|
|
|
sesh.streams.Delete(key)
|
|
|
|
|
return true
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
sesh.sb.closeAll()
|
|
|
|
|
log.Debugf("session %v closed gracefully", sesh.id)
|
|
|
|
@ -259,13 +256,13 @@ func (sesh *Session) Close() error {
|
|
|
|
|
}
|
|
|
|
|
sesh.acceptCh <- nil
|
|
|
|
|
|
|
|
|
|
sesh.streamsM.Lock()
|
|
|
|
|
for id, stream := range sesh.streams {
|
|
|
|
|
sesh.streams.Range(func(key, streamI interface{}) bool {
|
|
|
|
|
stream := streamI.(*Stream)
|
|
|
|
|
atomic.StoreUint32(&stream.closed, 1)
|
|
|
|
|
_ = stream.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close()
|
|
|
|
|
delete(sesh.streams, id)
|
|
|
|
|
}
|
|
|
|
|
sesh.streamsM.Unlock()
|
|
|
|
|
_ = stream.recvBuf.Close() // will not block
|
|
|
|
|
sesh.streams.Delete(key)
|
|
|
|
|
return true
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
pad := genRandomPadding()
|
|
|
|
|
f := &Frame{
|
|
|
|
@ -295,13 +292,14 @@ func (sesh *Session) IsClosed() bool {
|
|
|
|
|
|
|
|
|
|
func (sesh *Session) timeoutAfter(to time.Duration) {
|
|
|
|
|
time.Sleep(to)
|
|
|
|
|
sesh.streamsM.Lock()
|
|
|
|
|
if len(sesh.streams) == 0 && !sesh.IsClosed() {
|
|
|
|
|
sesh.streamsM.Unlock()
|
|
|
|
|
var count int
|
|
|
|
|
sesh.streams.Range(func(_, _ interface{}) bool {
|
|
|
|
|
count += 1
|
|
|
|
|
return true
|
|
|
|
|
})
|
|
|
|
|
if count == 0 && !sesh.IsClosed() {
|
|
|
|
|
sesh.SetTerminalMsg("timeout")
|
|
|
|
|
sesh.Close()
|
|
|
|
|
} else {
|
|
|
|
|
sesh.streamsM.Unlock()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|