Fix goroutine leak

pull/2/head
Qian Wang 6 years ago
parent 0db52a8a26
commit 077eb16dba

@ -9,7 +9,7 @@ import (
"net/http"
_ "net/http/pprof"
"os"
//"runtime"
"runtime"
"strings"
"time"
@ -115,7 +115,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
newStream, err := sesh.AcceptStream()
if err != nil {
log.Printf("Failed to get new stream: %v", err)
continue
if err == mux.ErrBrokenSession {
sta.DelSession(arrSID)
return
} else {
continue
}
}
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
if err != nil {
@ -131,6 +136,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
}
func main() {
runtime.SetBlockProfileRate(5)
go func() {
log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
}()

@ -56,7 +56,12 @@ func (sh *sorterHeap) Pop() interface{} {
func (s *Stream) recvNewFrame() {
for {
f := <-s.newFrameCh
var f *Frame
select {
case <-s.die:
return
case f = <-s.newFrameCh:
}
if f == nil {
log.Println("nil frame")
continue

@ -9,14 +9,15 @@ import (
)
const (
errBrokenSession = "broken session"
errRepeatSessionClosing = "trying to close a closed session"
// Copied from smux
acceptBacklog = 1024
closeBacklog = 512
)
var ErrBrokenSession = errors.New("broken session")
var errRepeatSessionClosing = errors.New("trying to close a closed session")
type Session struct {
id int
@ -58,6 +59,7 @@ func MakeSession(id int, conn net.Conn, obfs func(*Frame) []byte, deobfs func([]
streams: make(map[uint32]*Stream),
acceptCh: make(chan *Stream, acceptBacklog),
closeQCh: make(chan uint32, closeBacklog),
die: make(chan struct{}),
}
sesh.sb = makeSwitchboard(conn, sesh)
return sesh
@ -80,7 +82,7 @@ func (sesh *Session) OpenStream() (*Stream, error) {
func (sesh *Session) AcceptStream() (*Stream, error) {
select {
case <-sesh.die:
return nil, errors.New(errBrokenSession)
return nil, ErrBrokenSession
case stream := <-sesh.acceptCh:
return stream, nil
}
@ -122,7 +124,7 @@ func (sesh *Session) Close() error {
sesh.closingM.Lock()
defer sesh.closingM.Unlock()
if sesh.closing {
return errors.New(errRepeatSessionClosing)
return errRepeatSessionClosing
}
sesh.closing = true
close(sesh.die)
@ -138,6 +140,7 @@ func (sesh *Session) Close() error {
}
sesh.streamsM.Unlock()
close(sesh.sb.die)
return nil
}

@ -7,10 +7,8 @@ import (
"sync/atomic"
)
const (
errBrokenStream = "broken stream"
errRepeatStreamClosing = "trying to close a closed stream"
)
var errBrokenStream = errors.New("broken stream")
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
type Stream struct {
id uint32
@ -23,14 +21,18 @@ type Stream struct {
sh sorterHeap
wrapMode bool
newFrameCh chan *Frame
// New frames are received through newFrameCh by frameSorter
newFrameCh chan *Frame
// sortedBufCh are order-sorted data ready to be read raw
sortedBufCh chan []byte
nextSendSeq uint32
closingM sync.Mutex
die chan struct{}
closing bool
// close(die) is used to notify different goroutines that this stream is closing
die chan struct{}
// to prevent closing a closed channel
closing bool
}
func makeStream(id uint32, sesh *Session) *Stream {
@ -51,7 +53,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenStream)
return 0, errBrokenStream
default:
return 0, nil
}
@ -59,7 +61,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenStream)
return 0, errBrokenStream
case data := <-stream.sortedBufCh:
if len(buf) < len(data) {
log.Println(len(data))
@ -75,7 +77,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenStream)
return 0, errBrokenStream
default:
}
@ -109,7 +111,7 @@ func (stream *Stream) Close() error {
stream.closingM.Lock()
defer stream.closingM.Unlock()
if stream.closing {
return errors.New(errRepeatStreamClosing)
return errRepeatStreamClosing
}
stream.closing = true
close(stream.die)
@ -127,7 +129,7 @@ func (stream *Stream) closeNoDelMap() error {
stream.closingM.Lock()
defer stream.closingM.Unlock()
if stream.closing {
return errors.New(errRepeatStreamClosing)
return errRepeatStreamClosing
}
stream.closing = true
close(stream.die)

@ -20,9 +20,12 @@ type switchboard struct {
// For telling dispatcher how many bytes have been sent after Connection.send.
sentNotifyCh chan *sentNotifier
dispatCh chan []byte
newConnCh chan net.Conn
closingCECh chan *connEnclave
// dispatCh is used by streams to send new data to remote
dispatCh chan []byte
newConnCh chan net.Conn
closingCECh chan *connEnclave
die chan struct{}
closing bool
}
// Some data comes from a Stream to be sent through one of the many
@ -57,6 +60,7 @@ func makeSwitchboard(conn net.Conn, sesh *Session) *switchboard {
dispatCh: make(chan []byte, dispatchBacklog),
newConnCh: make(chan net.Conn, newConnBacklog),
closingCECh: make(chan *connEnclave, 5),
die: make(chan struct{}),
}
ce := &connEnclave{
sb: sb,
@ -97,6 +101,7 @@ func (ce *connEnclave) send(data []byte) {
// Dispatcher sends data coming from a stream to a remote connection
// I used channels here because I didn't want to use mutex
func (sb *switchboard) dispatch() {
var dying bool
for {
select {
// dispatCh receives data from stream.Write
@ -123,6 +128,15 @@ func (sb *switchboard) dispatch() {
break
}
}
if len(sb.ces) == 0 && !dying {
sb.session.Close()
}
case <-sb.die:
dying = true
for _, ce := range sb.ces {
ce.remoteConn.Close()
}
return
}
}
}

@ -116,8 +116,8 @@ func (sta *State) ParseConfig(conf string) (err error) {
}
func (sta *State) GetSession(SID [32]byte) *mux.Session {
sta.sessionsM.Lock()
defer sta.sessionsM.Unlock()
sta.sessionsM.RLock()
defer sta.sessionsM.RUnlock()
if sesh, ok := sta.sessions[SID]; ok {
return sesh
} else {
@ -131,6 +131,12 @@ func (sta *State) PutSession(SID [32]byte, sesh *mux.Session) {
sta.sessionsM.Unlock()
}
func (sta *State) DelSession(SID [32]byte) {
sta.sessionsM.Lock()
delete(sta.sessions, SID)
sta.sessionsM.Unlock()
}
func (sta *State) getUsedRandom(random [32]byte) int {
sta.usedRandomM.Lock()
defer sta.usedRandomM.Unlock()

Loading…
Cancel
Save