Use sync.Pool for obfuscation buffer

pull/158/head
Andy Wang 3 years ago
parent 5933ad8781
commit 881f6e6f9d
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374

@ -70,6 +70,8 @@ type Session struct {
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
recvFramePool sync.Pool
streamObfsBufPool sync.Pool
// Switchboard manages all connections to remote
sb *switchboard
@ -117,6 +119,11 @@ func MakeSession(id uint32, config SessionConfig) *Session {
// todo: validation. this must be smaller than StreamSendBufferSize
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
b := make([]byte, sesh.StreamSendBufferSize)
return &b
}}
sesh.sb = makeSwitchboard(sesh)
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
return sesh
@ -180,25 +187,20 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
if active {
tmpBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead)
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
// Notify remote that this stream is closed
common.CryptoRandRead(tmpBuf[:1])
padLen := int(tmpBuf[0]) + 1
payload := tmpBuf[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead((*tmpBuf)[:1])
padLen := int((*tmpBuf)[0]) + 1
payload := (*tmpBuf)[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload)
// must be holding s.wirtingM on entry
s.writingFrame.Closing = closingStream
s.writingFrame.Payload = payload
cipherTextLen, err := sesh.Obfs(&s.writingFrame, tmpBuf, frameHeaderLength)
s.writingFrame.Seq++
if err != nil {
return err
}
_, err = sesh.sb.send(tmpBuf[:cipherTextLen], &s.assignedConnId)
err := s.obfuscateAndSend(*tmpBuf, frameHeaderLength)
sesh.streamObfsBufPool.Put(tmpBuf)
if err != nil {
return err
}

@ -34,11 +34,6 @@ type Stream struct {
// atomic
closed uint32
// obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from
// the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste
// memory
obfsBuf []byte
// When we want order guarantee (i.e. session.Unordered is false),
// we assign each stream a fixed underlying connection.
// If the underlying connections the session uses provide ordering guarantee (most likely TCP),
@ -117,13 +112,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) {
return n, nil
}
func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error {
cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf)
func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
cipherTextLen, err := s.session.Obfs(&s.writingFrame, buf, payloadOffsetInBuf)
s.writingFrame.Seq++
if err != nil {
return err
}
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId)
if err != nil {
if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error())
@ -142,9 +138,6 @@ func (s *Stream) Write(in []byte) (n int, err error) {
return 0, ErrBrokenStream
}
if s.obfsBuf == nil {
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
}
for n < len(in) {
var framePayload []byte
if len(in)-n <= s.session.maxStreamUnitWrite {
@ -160,8 +153,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
framePayload = in[n : s.session.maxStreamUnitWrite+n]
}
s.writingFrame.Payload = framePayload
err = s.obfuscateAndSend(0)
s.writingFrame.Seq++
buf := s.session.streamObfsBufPool.Get().(*[]byte)
err = s.obfuscateAndSend(*buf, 0)
s.session.streamObfsBufPool.Put(buf)
if err != nil {
return
}
@ -173,9 +167,6 @@ func (s *Stream) Write(in []byte) (n int, err error) {
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
// for readFromTimeout amount of time
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
if s.obfsBuf == nil {
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
}
for {
if s.readFromTimeout != 0 {
if rder, ok := r.(net.Conn); !ok {
@ -184,7 +175,8 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
rder.SetReadDeadline(time.Now().Add(s.readFromTimeout))
}
}
read, er := r.Read(s.obfsBuf[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
buf := s.session.streamObfsBufPool.Get().(*[]byte)
read, er := r.Read((*buf)[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
if er != nil {
return n, er
}
@ -196,10 +188,10 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
}
s.writingM.Lock()
s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read]
err = s.obfuscateAndSend(frameHeaderLength)
s.writingFrame.Seq++
s.writingFrame.Payload = (*buf)[frameHeaderLength : frameHeaderLength+read]
err = s.obfuscateAndSend(*buf, frameHeaderLength)
s.writingM.Unlock()
s.session.streamObfsBufPool.Put(buf)
if err != nil {
return

Loading…
Cancel
Save