Prevent unnecessary allocation in stream closing

pull/158/head
Andy Wang 3 years ago
parent 3e737717bd
commit 5c5e9f8c14
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374

@ -174,26 +174,27 @@ func (sesh *Session) Accept() (net.Conn, error) {
}
func (sesh *Session) closeStream(s *Stream, active bool) error {
// must be holding s.wirtingM on entry
if atomic.SwapUint32(&s.closed, 1) == 1 {
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
}
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
if active {
// must be holding s.wirtingM on entry
if len(s.obfsBuf) < 256+frameHeaderLength+sesh.Obfuscator.maxOverhead {
s.obfsBuf = make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead)
}
// Notify remote that this stream is closed
padding := genRandomPadding()
s.writingFrame.Closing = closingStream
s.writingFrame.Payload = padding
common.CryptoRandRead(s.obfsBuf[:1])
padLen := int(s.obfsBuf[0]) + 1
payload := s.obfsBuf[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload)
obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead)
s.writingFrame.Closing = closingStream
s.writingFrame.Payload = payload
i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0)
s.writingFrame.Seq++
if err != nil {
return err
}
_, err = sesh.sb.send(obfsBuf[:i], &s.assignedConnId)
err := s.obfuscateAndSend(frameHeaderLength)
if err != nil {
return err
}
@ -304,14 +305,6 @@ func (sesh *Session) passiveClose() error {
return nil
}
func genRandomPadding() []byte {
lenB := make([]byte, 1)
common.CryptoRandRead(lenB)
pad := make([]byte, int(lenB[0])+1)
common.CryptoRandRead(pad)
return pad
}
func (sesh *Session) Close() error {
log.Debugf("attempting to actively close session %v", sesh.id)
err := sesh.closeSession(false)
@ -319,19 +312,24 @@ func (sesh *Session) Close() error {
return err
}
// we send a notice frame telling remote to close the session
pad := genRandomPadding()
padBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead)
common.CryptoRandRead(padBuf[:1])
padLen := int(padBuf[0]) + 1
payload := padBuf[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload)
f := &Frame{
StreamID: 0xffffffff,
Seq: 0,
Closing: closingSession,
Payload: pad,
Payload: payload,
}
obfsBuf := make([]byte, len(pad)+frameHeaderLength+sesh.Obfuscator.maxOverhead)
i, err := sesh.Obfs(f, obfsBuf, 0)
i, err := sesh.Obfs(f, padBuf, frameHeaderLength)
if err != nil {
return err
}
_, err = sesh.sb.send(obfsBuf[:i], new(uint32))
_, err = sesh.sb.send(padBuf[:i], new(uint32))
if err != nil {
return err
}

Loading…
Cancel
Save