Compare commits

...

2 Commits

Author SHA1 Message Date
Andy Wang dc2e83f75f
Move to common.RandInt 1 month ago
Andy Wang 5988b4337d
Stop using fixedConnMapping 1 month ago

@ -265,6 +265,7 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
} }
func (sesh *Session) SetTerminalMsg(msg string) { func (sesh *Session) SetTerminalMsg(msg string) {
log.Debug("terminal message set to " + msg)
sesh.terminalMsgSetter.Do(func() { sesh.terminalMsgSetter.Do(func() {
sesh.terminalMsg = msg sesh.terminalMsg = msg
}) })

@ -2,13 +2,12 @@ package multiplex
import ( import (
"errors" "errors"
"math/rand" "github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus"
"math/rand/v2"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
log "github.com/sirupsen/logrus"
) )
type switchboardStrategy int type switchboardStrategy int
@ -39,19 +38,14 @@ type switchboard struct {
} }
func makeSwitchboard(sesh *Session) *switchboard { func makeSwitchboard(sesh *Session) *switchboard {
var strategy switchboardStrategy
if sesh.Unordered {
log.Debug("Connection is unordered")
strategy = uniformSpread
} else {
strategy = fixedConnMapping
}
sb := &switchboard{ sb := &switchboard{
session: sesh, session: sesh,
strategy: strategy, strategy: uniformSpread,
valve: sesh.Valve, valve: sesh.Valve,
randPool: sync.Pool{New: func() interface{} { randPool: sync.Pool{New: func() interface{} {
return rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) var state [32]byte
common.CryptoRandRead(state[:])
return rand.New(rand.NewChaCha8(state))
}}, }},
} }
return sb return sb
@ -60,8 +54,8 @@ func makeSwitchboard(sesh *Session) *switchboard {
var errBrokenSwitchboard = errors.New("the switchboard is broken") var errBrokenSwitchboard = errors.New("the switchboard is broken")
func (sb *switchboard) addConn(conn net.Conn) { func (sb *switchboard) addConn(conn net.Conn) {
atomic.AddUint32(&sb.connsCount, 1) connId := atomic.AddUint32(&sb.connsCount, 1) - 1
sb.conns.Store(conn, conn) sb.conns.Store(connId, conn)
go sb.deplex(conn) go sb.deplex(conn)
} }
@ -86,6 +80,9 @@ func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err err
return n, err return n, err
} }
case fixedConnMapping: case fixedConnMapping:
// FIXME: this strategy has a tendency to cause a TLS conn socket buffer to fill up,
// which is a problem when multiple streams are mapped to the same conn, resulting
// in all such streams being blocked.
conn = *assignedConn conn = *assignedConn
if conn == nil { if conn == nil {
conn, err = sb.pickRandConn() conn, err = sb.pickRandConn()
@ -110,7 +107,7 @@ func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err err
return n, nil return n, nil
} }
// returns a random connId // returns a random conn. This function can be called concurrently.
func (sb *switchboard) pickRandConn() (net.Conn, error) { func (sb *switchboard) pickRandConn() (net.Conn, error) {
if atomic.LoadUint32(&sb.broken) == 1 { if atomic.LoadUint32(&sb.broken) == 1 {
return nil, errBrokenSwitchboard return nil, errBrokenSwitchboard
@ -122,22 +119,15 @@ func (sb *switchboard) pickRandConn() (net.Conn, error) {
} }
randReader := sb.randPool.Get().(*rand.Rand) randReader := sb.randPool.Get().(*rand.Rand)
connId := randReader.Uint32N(connsCount)
r := randReader.Intn(int(connsCount))
sb.randPool.Put(randReader) sb.randPool.Put(randReader)
var c int ret, ok := sb.conns.Load(connId)
var ret net.Conn if !ok {
sb.conns.Range(func(_, conn interface{}) bool { log.Errorf("failed to get conn %d", connId)
if r == c { return nil, errBrokenSwitchboard
ret = conn.(net.Conn) }
return false return ret.(net.Conn), nil
}
c++
return true
})
return ret, nil
} }
// actively triggered by session.Close() // actively triggered by session.Close()
@ -145,10 +135,10 @@ func (sb *switchboard) closeAll() {
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) { if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
return return
} }
atomic.StoreUint32(&sb.connsCount, 0)
sb.conns.Range(func(_, conn interface{}) bool { sb.conns.Range(func(_, conn interface{}) bool {
conn.(net.Conn).Close() conn.(net.Conn).Close()
sb.conns.Delete(conn) sb.conns.Delete(conn)
atomic.AddUint32(&sb.connsCount, ^uint32(0))
return true return true
}) })
} }

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/rand"
"net" "net"
"github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/common"
@ -46,8 +45,7 @@ func (TLS) makeResponder(clientHelloSessionId []byte, sharedSecret [32]byte) Res
// the cert length needs to be the same for all handshakes belonging to the same session // the cert length needs to be the same for all handshakes belonging to the same session
// we can use sessionKey as a seed here to ensure consistency // we can use sessionKey as a seed here to ensure consistency
possibleCertLengths := []int{42, 27, 68, 59, 36, 44, 46} possibleCertLengths := []int{42, 27, 68, 59, 36, 44, 46}
rand.Seed(int64(sessionKey[0])) cert := make([]byte, possibleCertLengths[common.RandInt(len(possibleCertLengths))])
cert := make([]byte, possibleCertLengths[rand.Intn(len(possibleCertLengths))])
common.RandRead(randSource, cert) common.RandRead(randSource, cert)
var nonce [12]byte var nonce [12]byte

Loading…
Cancel
Save