Implement WriteTo and ReadFrom timeouts

pull/115/head
Andy Wang 4 years ago
parent 4a81683e44
commit e202d8d03b

@ -128,13 +128,14 @@ func RouteTCP(listener net.Listener, streamTimeout time.Duration, newSeshFunc fu
stream.Close()
return
}
stream.SetReadFromTimeout(streamTimeout) // if localConn hasn't sent anything to stream to a period of time, stream closes
go func() {
if _, err := common.Copy(localConn, stream, 0); err != nil {
if _, err := common.Copy(localConn, stream); err != nil {
log.Tracef("copying stream to proxy client: %v", err)
}
}()
//util.Pipe(stream, localConn, localConfig.Timeout)
if _, err = common.Copy(stream, localConn, streamTimeout); err != nil {
if _, err = common.Copy(stream, localConn); err != nil {
log.Tracef("copying proxy client to stream: %v", err)
}
}()

@ -35,12 +35,9 @@ package common
import (
"io"
"net"
"time"
)
// copyBuffer is the actual implementation of Copy and CopyBuffer.
// if buf is nil, one is allocated.
func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int64, err error) {
func Copy(dst net.Conn, src net.Conn) (written int64, err error) {
defer func() { src.Close(); dst.Close() }()
// If the reader has a WriteTo method, use it to do the copy.
@ -56,13 +53,6 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int
size := 32 * 1024
buf := make([]byte, size)
for {
if srcReadTimeout != 0 {
// TODO: don't rely on setreaddeadline
err = src.SetReadDeadline(time.Now().Add(srcReadTimeout))
if err != nil {
break
}
}
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])

@ -22,6 +22,7 @@ type bufferedPipe struct {
closed bool
rwCond *sync.Cond
rDeadline time.Time
wtTimeout time.Duration
}
func NewBufferedPipe() *bufferedPipe {
@ -74,7 +75,14 @@ func (p *bufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
if d <= 0 {
return 0, ErrTimeout
}
time.AfterFunc(d, p.rwCond.Broadcast)
if p.wtTimeout == 0 {
// if there hasn't been a scheduled broadcast
time.AfterFunc(d, p.rwCond.Broadcast)
}
}
if p.wtTimeout != 0 {
p.rDeadline = time.Now().Add(p.wtTimeout)
time.AfterFunc(p.wtTimeout, p.rwCond.Broadcast)
}
if p.buf.Len() > 0 {
written, er := p.buf.WriteTo(w)
@ -127,3 +135,11 @@ func (p *bufferedPipe) SetReadDeadline(t time.Time) {
p.rDeadline = t
p.rwCond.Broadcast()
}
func (p *bufferedPipe) SetWriteToTimeout(d time.Duration) {
p.rwCond.L.Lock()
defer p.rwCond.L.Unlock()
p.wtTimeout = d
p.rwCond.Broadcast()
}

@ -17,6 +17,7 @@ type datagramBuffer struct {
buf *bytes.Buffer
closed bool
rwCond *sync.Cond
wtTimeout time.Duration
rDeadline time.Time
}
@ -72,13 +73,19 @@ func (d *datagramBuffer) WriteTo(w io.Writer) (n int64, err error) {
if d.closed && len(d.pLens) == 0 {
return 0, io.EOF
}
if !d.rDeadline.IsZero() {
delta := time.Until(d.rDeadline)
if delta <= 0 {
return 0, ErrTimeout
}
time.AfterFunc(delta, d.rwCond.Broadcast)
if d.wtTimeout == 0 {
// if there hasn't been a scheduled broadcast
time.AfterFunc(delta, d.rwCond.Broadcast)
}
}
if d.wtTimeout != 0 {
d.rDeadline = time.Now().Add(d.wtTimeout)
time.AfterFunc(d.wtTimeout, d.rwCond.Broadcast)
}
if len(d.pLens) > 0 {
@ -143,3 +150,11 @@ func (d *datagramBuffer) SetReadDeadline(t time.Time) {
d.rDeadline = t
d.rwCond.Broadcast()
}
func (d *datagramBuffer) SetWriteToTimeout(t time.Duration) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
d.wtTimeout = t
d.rwCond.Broadcast()
}

@ -11,4 +11,5 @@ type recvBuffer interface {
io.WriterTo
Write(Frame) (toBeClosed bool, err error)
SetReadDeadline(time time.Time)
SetWriteToTimeout(d time.Duration)
}

@ -36,6 +36,8 @@ type Stream struct {
// overall the streams in a session should be uniformly distributed across all connections
// This is not used in unordered connection mode
assignedConnId uint32
rfTimeout time.Duration
}
func makeStream(sesh *Session, id uint32) *Stream {
@ -152,6 +154,13 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
s.obfsBuf = make([]byte, s.session.SendBufferSize)
}
for {
if s.rfTimeout != 0 {
if rder, ok := r.(net.Conn); !ok {
log.Warn("ReadFrom timeout is set but reader doesn't implement SetReadDeadline")
} else {
rder.SetReadDeadline(time.Now().Add(s.rfTimeout))
}
}
read, er := r.Read(s.obfsBuf[HEADER_LEN : HEADER_LEN+s.session.maxStreamUnitWrite])
if er != nil {
return n, er
@ -199,5 +208,7 @@ func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Ad
// TODO: implement the following
func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented }
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) }
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.rfTimeout = d }
func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented }

@ -106,4 +106,5 @@ func (sb *streamBuffer) Close() error {
return sb.buf.Close()
}
func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) }
func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) }
func (sb *streamBuffer) SetWriteToTimeout(d time.Duration) { sb.buf.SetWriteToTimeout(d) }

@ -4,6 +4,7 @@ import (
"bytes"
"github.com/cbeuw/Cloak/internal/common"
"io"
"io/ioutil"
"math/rand"
"testing"
"time"
@ -13,6 +14,8 @@ import (
const payloadLen = 1000
var emptyKey [32]byte
func setupSesh(unordered bool, key [32]byte, encryptionMethod byte) *Session {
obfuscator, _ := MakeObfuscator(encryptionMethod, key)
@ -433,5 +436,53 @@ func TestStream_UnorderedRead(t *testing.T) {
"got nil error")
}
})
}
func TestStream_SetWriteToTimeout(t *testing.T) {
seshes := map[string]*Session{
"ordered": setupSesh(false, emptyKey, E_METHOD_PLAIN),
"unordered": setupSesh(true, emptyKey, E_METHOD_PLAIN),
}
for name, sesh := range seshes {
t.Run(name, func(t *testing.T) {
stream, _ := sesh.OpenStream()
stream.SetWriteToTimeout(100 * time.Millisecond)
done := make(chan struct{})
go func() {
stream.WriteTo(ioutil.Discard)
done <- struct{}{}
}()
select {
case <-done:
return
case <-time.After(500 * time.Millisecond):
t.Error("didn't timeout")
}
})
}
}
func TestStream_SetReadFromTimeout(t *testing.T) {
seshes := map[string]*Session{
"ordered": setupSesh(false, emptyKey, E_METHOD_PLAIN),
"unordered": setupSesh(true, emptyKey, E_METHOD_PLAIN),
}
for name, sesh := range seshes {
t.Run(name, func(t *testing.T) {
stream, _ := sesh.OpenStream()
stream.SetReadFromTimeout(100 * time.Millisecond)
done := make(chan struct{})
go func() {
stream.ReadFrom(connutil.Discard())
done <- struct{}{}
}()
select {
case <-done:
return
case <-time.After(500 * time.Millisecond):
t.Error("didn't timeout")
}
})
}
}

@ -184,13 +184,16 @@ func dispatchConnection(conn net.Conn, sta *State) {
}
log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod)
// if stream has nothing to send to proxy server for sta.Timeout period of time, stream will return error
newStream.(*mux.Stream).SetWriteToTimeout(sta.Timeout)
go func() {
if _, err := common.Copy(localConn, newStream, sta.Timeout); err != nil {
if _, err := common.Copy(localConn, newStream); err != nil {
log.Tracef("copying stream to proxy server: %v", err)
}
}()
go func() {
if _, err := common.Copy(newStream, localConn, 0); err != nil {
if _, err := common.Copy(newStream, localConn); err != nil {
log.Tracef("copying proxy server to stream: %v", err)
}
}()

Loading…
Cancel
Save