General improvements

pull/2/head
Qian Wang 6 years ago
parent 3f7eef98e3
commit 0db52a8a26

@ -124,7 +124,9 @@ func main() {
log.Printf("Starting standalone mode. Listening for ss on %v:%v\n", localHost, localPort)
}
sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now)
opaque := time.Now().UnixNano()
// opaque is used to generate the padding of session ticket
sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now, opaque)
err := sta.ParseConfig(pluginOpts)
if err != nil {
log.Fatal(err)

@ -52,6 +52,6 @@ func MakeSessionTicket(sta *State) []byte {
key, _ := ec.GenerateSharedSecret(ephKP.PrivateKey, sta.staticPub)
cipherSID := util.AESEncrypt(ticket[0:16], key, sta.SID)
copy(ticket[32:64], cipherSID)
io.ReadFull(rand.Reader, ticket[64:192])
copy(ticket[64:192], util.PsudoRandBytes(128, tthInterval+sta.opaque))
return ticket
}

@ -29,6 +29,7 @@ type State struct {
SS_REMOTE_PORT string
Now func() time.Time
opaque int64
SID []byte
staticPub crypto.PublicKey
keyPairsM sync.RWMutex
@ -40,13 +41,14 @@ type State struct {
NumConn int
}
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time) *State {
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, opaque int64) *State {
ret := &State{
SS_LOCAL_HOST: localHost,
SS_LOCAL_PORT: localPort,
SS_REMOTE_HOST: remoteHost,
SS_REMOTE_PORT: remotePort,
Now: nowFunc,
opaque: opaque,
}
ret.keyPairs = make(map[int64]*keyPair)
return ret

@ -91,11 +91,11 @@ func (s *Stream) recvNewFrame() {
// wrapMode is true when the latest seq is wrapped but nextN is not
s.wrapMode = true
}
fs.trueSeq = uint64(2<<16*(s.rev+1)) + uint64(fs.seq) + 1
fs.trueSeq = uint64(1<<16*(s.rev+1)) + uint64(fs.seq) + 1
// +1 because wrapped 0 should have trueSeq of 256 instead of 255
// when this bit was run on 1, the trueSeq of 1 would become 256
} else {
fs.trueSeq = uint64(2<<16*s.rev) + uint64(fs.seq)
fs.trueSeq = uint64(1<<16*s.rev) + uint64(fs.seq)
// when this bit was run on 255, the trueSeq of 255 would be 255
}
heap.Push(&s.sh, fs)

@ -1,16 +1,18 @@
package multiplex
import (
"errors"
"log"
"net"
"sync"
"sync/atomic"
)
const (
errBrokenSession = "broken session"
errRepeatSessionClosing = "trying to close a closed session"
// Copied from smux
errBrokenPipe = "broken stream"
errRepeatStreamClosing = "trying to close a closed stream"
acceptBacklog = 1024
acceptBacklog = 1024
closeBacklog = 512
)
@ -25,8 +27,7 @@ type Session struct {
// This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
obfsedReader func(net.Conn, []byte) (int, error)
nextStreamIDM sync.Mutex
nextStreamID uint32
nextStreamID uint32
streamsM sync.RWMutex
streams map[uint32]*Stream
@ -40,6 +41,10 @@ type Session struct {
// to be read by another stream to send the streamID to notify the remote
// that this stream is closed
closeQCh chan uint32
closingM sync.Mutex
die chan struct{}
closing bool
}
// 1 conn is needed to make a session
@ -63,13 +68,9 @@ func (sesh *Session) AddConnection(conn net.Conn) {
}
func (sesh *Session) OpenStream() (*Stream, error) {
sesh.nextStreamIDM.Lock()
id := sesh.nextStreamID
sesh.nextStreamID += 1
sesh.nextStreamIDM.Unlock()
id := atomic.AddUint32(&sesh.nextStreamID, 1)
id -= 1 // Because atomic.AddUint32 returns the value after incrementation
stream := makeStream(id, sesh)
sesh.streamsM.Lock()
sesh.streams[id] = stream
sesh.streamsM.Unlock()
@ -77,8 +78,12 @@ func (sesh *Session) OpenStream() (*Stream, error) {
}
func (sesh *Session) AcceptStream() (*Stream, error) {
stream := <-sesh.acceptCh
return stream, nil
select {
case <-sesh.die:
return nil, errors.New(errBrokenSession)
case stream := <-sesh.acceptCh:
return stream, nil
}
}
@ -89,15 +94,15 @@ func (sesh *Session) delStream(id uint32) {
}
func (sesh *Session) isStream(id uint32) bool {
sesh.streamsM.Lock()
sesh.streamsM.RLock()
_, ok := sesh.streams[id]
sesh.streamsM.Unlock()
sesh.streamsM.RUnlock()
return ok
}
func (sesh *Session) getStream(id uint32) *Stream {
sesh.streamsM.Lock()
defer sesh.streamsM.Unlock()
sesh.streamsM.RLock()
defer sesh.streamsM.RUnlock()
return sesh.streams[id]
}
@ -111,3 +116,28 @@ func (sesh *Session) addStream(id uint32) *Stream {
sesh.acceptCh <- stream
return stream
}
func (sesh *Session) Close() error {
// Because closing a closed channel causes panic
sesh.closingM.Lock()
defer sesh.closingM.Unlock()
if sesh.closing {
return errors.New(errRepeatSessionClosing)
}
sesh.closing = true
close(sesh.die)
sesh.streamsM.Lock()
for id, stream := range sesh.streams {
// If we call stream.Close() here, streamsM will result in a deadlock
// because stream.Close calls sesh.delStream, which locks the mutex.
// so we need to implement a method of stream that closes the stream without calling
// sesh.delStream
// This can also be seen in smux
go stream.closeNoDelMap()
delete(sesh.streams, id)
}
sesh.streamsM.Unlock()
return nil
}

@ -4,10 +4,12 @@ import (
"errors"
"log"
"sync"
"sync/atomic"
)
const (
readBuffer = 20480
errBrokenStream = "broken stream"
errRepeatStreamClosing = "trying to close a closed stream"
)
type Stream struct {
@ -15,10 +17,6 @@ type Stream struct {
session *Session
// Copied from smux
dieM sync.Mutex
die chan struct{}
// Explanations of the following 4 fields can be found in frameSorter.go
nextRecvSeq uint32
rev int
@ -28,10 +26,10 @@ type Stream struct {
newFrameCh chan *Frame
sortedBufCh chan []byte
nextSendSeqM sync.Mutex
nextSendSeq uint32
nextSendSeq uint32
closingM sync.Mutex
die chan struct{}
closing bool
}
@ -53,7 +51,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenPipe)
return 0, errors.New(errBrokenStream)
default:
return 0, nil
}
@ -61,7 +59,7 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenPipe)
return 0, errors.New(errBrokenStream)
case data := <-stream.sortedBufCh:
if len(buf) < len(data) {
log.Println(len(data))
@ -77,7 +75,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenPipe)
return 0, errors.New(errBrokenStream)
default:
}
@ -95,9 +93,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
Payload: in,
}
stream.nextSendSeqM.Lock()
stream.nextSendSeq += 1
stream.nextSendSeqM.Unlock()
atomic.AddUint32(&stream.nextSendSeq, 1)
tlsRecord := stream.session.obfs(f)
stream.session.sb.dispatCh <- tlsRecord
@ -109,7 +105,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
func (stream *Stream) Close() error {
log.Printf("ID: %v closing\n", stream.id)
// Because closing a closed channel causes panic
// Lock here because closing a closed channel causes panic
stream.closingM.Lock()
defer stream.closingM.Unlock()
if stream.closing {
@ -121,3 +117,20 @@ func (stream *Stream) Close() error {
stream.session.closeQCh <- stream.id
return nil
}
// Same as Close() but no call to session.delStream.
// This is called in session.Close() to avoid mutex deadlock
func (stream *Stream) closeNoDelMap() error {
log.Printf("ID: %v closing\n", stream.id)
// Lock here because closing a closed channel causes panic
stream.closingM.Lock()
defer stream.closingM.Unlock()
if stream.closing {
return errors.New(errRepeatStreamClosing)
}
stream.closing = true
close(stream.die)
stream.session.closeQCh <- stream.id
return nil
}

@ -97,14 +97,12 @@ func (ce *connEnclave) send(data []byte) {
// Dispatcher sends data coming from a stream to a remote connection
// I used channels here because I didn't want to use mutex
func (sb *switchboard) dispatch() {
var nextCE int
for {
select {
// dispatCh receives data from stream.Write
case data := <-sb.dispatCh:
go sb.ces[nextCE%len(sb.ces)].send(data)
go sb.ces[0].send(data)
sb.ces[0].sendQueue += len(data)
nextCE += 1
case notified := <-sb.sentNotifyCh:
notified.ce.sendQueue -= notified.sent
sort.Sort(byQ(sb.ces))
@ -117,7 +115,6 @@ func (sb *switchboard) dispatch() {
}
sb.ces = append(sb.ces, newCe)
go sb.deplex(newCe)
//sort.Sort(byQ(sb.ces))
case closing := <-sb.closingCECh:
log.Println("Closing conn")
for i, ce := range sb.ces {
@ -126,15 +123,15 @@ func (sb *switchboard) dispatch() {
break
}
}
// TODO: when all connections closed
}
}
}
// deplex function costantly reads from a TLS connection
// deplex function costantly reads from a TCP connection
// it is responsible to act in response to the deobfsed header
// i.e. should a new stream be added? which existing stream should be closed?
func (sb *switchboard) deplex(ce *connEnclave) {
var highestStream uint32
buf := make([]byte, 20480)
for {
i, err := sb.session.obfsedReader(ce.remoteConn, buf)
@ -149,12 +146,15 @@ func (sb *switchboard) deplex(ce *connEnclave) {
log.Printf("HeaderClosing: %v\n", frame.ClosingStreamID)
closing.Close()
}
sb.session.nextStreamIDM.Lock()
nextID := sb.session.nextStreamID
sb.session.nextStreamIDM.Unlock()
var stream *Stream
if stream = sb.session.getStream(frame.StreamID); nextID <= frame.StreamID && stream == nil {
// If we want to open a new stream, we need to make sure that the newStreamID is indeed new
// i.e. it is not a stream that existed before but has been closed
// we don't allow streamID reuse.
// So here we do a check that the new stream has a higher ID than the highest ID we have got
if stream = sb.session.getStream(frame.StreamID); highestStream < frame.StreamID && stream == nil {
stream = sb.session.addStream(frame.StreamID)
highestStream = frame.StreamID
}
if stream != nil {
stream.newFrameCh <- frame

@ -7,6 +7,8 @@ import (
mux "github.com/cbeuw/Cloak/internal/multiplex"
)
// For each frame, the three parts of the header is xored with three keys.
// The keys are generated from the SID and the payload of the frame.
func genXorKeys(SID []byte, data []byte) (i uint32, ii uint32, iii uint32) {
h := xxhash.New32()
ret := make([]uint32, 3)

@ -46,10 +46,7 @@ func BtoInt(b []byte) int {
func PsudoRandBytes(length int, seed int64) []byte {
prand.Seed(seed)
ret := make([]byte, length)
for i := 0; i < length; i++ {
randByte := byte(prand.Intn(256))
ret[i] = randByte
}
prand.Read(ret)
return ret
}

Loading…
Cancel
Save