From eca5f13936eb771d6adbbd5eece3b110763cf5d4 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 12 Nov 2023 20:45:17 +0000 Subject: [PATCH] Remove WriteTo from recvBuffer to prevent blocking on external Writer. Fixes #229 --- internal/multiplex/datagramBufferedPipe.go | 48 ---------------------- internal/multiplex/recvBuffer.go | 4 -- internal/multiplex/session_test.go | 2 +- internal/multiplex/stream.go | 12 ------ internal/multiplex/streamBuffer.go | 8 +--- internal/multiplex/streamBufferedPipe.go | 45 -------------------- internal/multiplex/stream_test.go | 26 ------------ 7 files changed, 2 insertions(+), 143 deletions(-) diff --git a/internal/multiplex/datagramBufferedPipe.go b/internal/multiplex/datagramBufferedPipe.go index 7082264..f1df206 100644 --- a/internal/multiplex/datagramBufferedPipe.go +++ b/internal/multiplex/datagramBufferedPipe.go @@ -66,46 +66,6 @@ func (d *datagramBufferedPipe) Read(target []byte) (int, error) { return dataLen, nil } -func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { - d.rwCond.L.Lock() - defer d.rwCond.L.Unlock() - for { - if d.closed && len(d.pLens) == 0 { - return 0, io.EOF - } - - hasRDeadline := !d.rDeadline.IsZero() - if hasRDeadline { - if time.Until(d.rDeadline) <= 0 { - return 0, ErrTimeout - } - } - - if len(d.pLens) > 0 { - var dataLen int - dataLen, d.pLens = d.pLens[0], d.pLens[1:] - written, er := w.Write(d.buf.Next(dataLen)) - n += int64(written) - if er != nil { - d.rwCond.Broadcast() - return n, er - } - d.rwCond.Broadcast() - } else { - if d.wtTimeout == 0 { - if hasRDeadline { - d.broadcastAfter(time.Until(d.rDeadline)) - } - } else { - d.rDeadline = time.Now().Add(d.wtTimeout) - d.broadcastAfter(d.wtTimeout) - } - - d.rwCond.Wait() - } - } -} - func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) { d.rwCond.L.Lock() defer d.rwCond.L.Unlock() @@ -151,14 +111,6 @@ func (d *datagramBufferedPipe) SetReadDeadline(t time.Time) { d.rwCond.Broadcast() } -func (d *datagramBufferedPipe) SetWriteToTimeout(t time.Duration) { - d.rwCond.L.Lock() - defer d.rwCond.L.Unlock() - - d.wtTimeout = t - d.rwCond.Broadcast() -} - func (d *datagramBufferedPipe) broadcastAfter(t time.Duration) { if d.timeoutTimer != nil { d.timeoutTimer.Stop() diff --git a/internal/multiplex/recvBuffer.go b/internal/multiplex/recvBuffer.go index 91af149..d8924d2 100644 --- a/internal/multiplex/recvBuffer.go +++ b/internal/multiplex/recvBuffer.go @@ -14,12 +14,8 @@ type recvBuffer interface { // Instead, it should behave as if it hasn't been closed. Closure is only relevant // when the buffer is empty. io.ReadCloser - io.WriterTo Write(*Frame) (toBeClosed bool, err error) SetReadDeadline(time time.Time) - // SetWriteToTimeout sets the duration a recvBuffer waits in a WriteTo call when nothing - // has been written for a while. After that duration it should return ErrTimeout - SetWriteToTimeout(d time.Duration) } // size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks. diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index eac7482..0c8c7be 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -557,7 +557,7 @@ func BenchmarkRecvDataFromRemote(b *testing.B) { go func() { stream, _ := sesh.Accept() - stream.(*Stream).WriteTo(ioutil.Discard) + io.Copy(ioutil.Discard, stream) }() binaryFrames := [maxIter][]byte{} diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index f7e0376..4dd1b71 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -96,17 +96,6 @@ func (s *Stream) Read(buf []byte) (n int, err error) { return } -// WriteTo continuously write data Stream has received into the writer w. -func (s *Stream) WriteTo(w io.Writer) (int64, error) { - // will keep writing until the underlying buffer is closed - n, err := s.recvBuf.WriteTo(w) - log.Tracef("%v read from stream %v with err %v", n, s.id, err) - if err == io.EOF { - return n, ErrBrokenStream - } - return n, nil -} - func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error { cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf) s.writingFrame.Seq++ @@ -210,7 +199,6 @@ func (s *Stream) Close() error { func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } -func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) } func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil } func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d } diff --git a/internal/multiplex/streamBuffer.go b/internal/multiplex/streamBuffer.go index a9fd0fd..293df6e 100644 --- a/internal/multiplex/streamBuffer.go +++ b/internal/multiplex/streamBuffer.go @@ -13,7 +13,6 @@ package multiplex import ( "container/heap" "fmt" - "io" "sync" "time" ) @@ -102,10 +101,6 @@ func (sb *streamBuffer) Read(buf []byte) (int, error) { return sb.buf.Read(buf) } -func (sb *streamBuffer) WriteTo(w io.Writer) (int64, error) { - return sb.buf.WriteTo(w) -} - func (sb *streamBuffer) Close() error { sb.recvM.Lock() defer sb.recvM.Unlock() @@ -113,5 +108,4 @@ func (sb *streamBuffer) Close() error { return sb.buf.Close() } -func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) } -func (sb *streamBuffer) SetWriteToTimeout(d time.Duration) { sb.buf.SetWriteToTimeout(d) } +func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) } diff --git a/internal/multiplex/streamBufferedPipe.go b/internal/multiplex/streamBufferedPipe.go index 0dd3e46..6b6ff5b 100644 --- a/internal/multiplex/streamBufferedPipe.go +++ b/internal/multiplex/streamBufferedPipe.go @@ -58,43 +58,6 @@ func (p *streamBufferedPipe) Read(target []byte) (int, error) { return n, err } -func (p *streamBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { - p.rwCond.L.Lock() - defer p.rwCond.L.Unlock() - for { - if p.closed && p.buf.Len() == 0 { - return 0, io.EOF - } - - hasRDeadline := !p.rDeadline.IsZero() - if hasRDeadline { - if time.Until(p.rDeadline) <= 0 { - return 0, ErrTimeout - } - } - if p.buf.Len() > 0 { - written, er := p.buf.WriteTo(w) - n += written - if er != nil { - p.rwCond.Broadcast() - return n, er - } - p.rwCond.Broadcast() - } else { - if p.wtTimeout == 0 { - if hasRDeadline { - p.broadcastAfter(time.Until(p.rDeadline)) - } - } else { - p.rDeadline = time.Now().Add(p.wtTimeout) - p.broadcastAfter(p.wtTimeout) - } - - p.rwCond.Wait() - } - } -} - func (p *streamBufferedPipe) Write(input []byte) (int, error) { p.rwCond.L.Lock() defer p.rwCond.L.Unlock() @@ -131,14 +94,6 @@ func (p *streamBufferedPipe) SetReadDeadline(t time.Time) { p.rwCond.Broadcast() } -func (p *streamBufferedPipe) SetWriteToTimeout(d time.Duration) { - p.rwCond.L.Lock() - defer p.rwCond.L.Unlock() - - p.wtTimeout = d - p.rwCond.Broadcast() -} - func (p *streamBufferedPipe) broadcastAfter(d time.Duration) { if p.timeoutTimer != nil { p.timeoutTimer.Stop() diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index ceb0835..35a6684 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -3,7 +3,6 @@ package multiplex import ( "bytes" "io" - "io/ioutil" "math/rand" "testing" "time" @@ -364,31 +363,6 @@ func TestStream_Read(t *testing.T) { } } -func TestStream_SetWriteToTimeout(t *testing.T) { - seshes := map[string]*Session{ - "ordered": setupSesh(false, emptyKey, EncryptionMethodPlain), - "unordered": setupSesh(true, emptyKey, EncryptionMethodPlain), - } - for name, sesh := range seshes { - t.Run(name, func(t *testing.T) { - stream, _ := sesh.OpenStream() - stream.SetWriteToTimeout(100 * time.Millisecond) - done := make(chan struct{}) - go func() { - stream.WriteTo(ioutil.Discard) - done <- struct{}{} - }() - - select { - case <-done: - return - case <-time.After(500 * time.Millisecond): - t.Error("didn't timeout") - } - }) - } -} - func TestStream_SetReadFromTimeout(t *testing.T) { seshes := map[string]*Session{ "ordered": setupSesh(false, emptyKey, EncryptionMethodPlain),