QOS and user managing, bug fixes

pull/2/head
Qian Wang 6 years ago
parent 6a6b293164
commit 3534d05055

@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log"
"math/rand"
"net"
"os"
"time"
@ -60,7 +61,7 @@ func makeRemoteConn(sta *client.State) (net.Conn, error) {
// Three discarded messages: ServerHello, ChangeCipherSpec and Finished
discardBuf := make([]byte, 1024)
for c := 0; c < 3; c++ {
_, err = util.ReadTillDrain(remoteConn, discardBuf)
_, err = util.ReadTLS(remoteConn, discardBuf)
if err != nil {
log.Printf("Reading discarded message %v: %v\n", c, err)
return nil, err
@ -122,9 +123,13 @@ func main() {
log.Printf("Starting standalone mode. Listening for ss on %v:%v\n", localHost, localPort)
}
opaque := time.Now().UnixNano()
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
// sessionID is limited to its UID.
rand.Seed(time.Now().UnixNano())
sessionID := rand.Uint32()
// opaque is used to generate the padding of session ticket
sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now, opaque)
sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now, sessionID)
err := sta.ParseConfig(pluginOpts)
if err != nil {
log.Fatal(err)
@ -140,19 +145,19 @@ func main() {
log.Fatal("TicketTimeHint cannot be empty or 0")
}
obfs := util.MakeObfs(sta.SID)
deobfs := util.MakeDeobfs(sta.SID)
sesh := mux.MakeSession(0, 1e9, 1e9, obfs, deobfs, util.ReadTillDrain)
valve := mux.MakeValve(1e9, 1e9, 1e9, 1e9)
obfs := util.MakeObfs(sta.UID)
deobfs := util.MakeDeobfs(sta.UID)
sesh := mux.MakeSession(0, valve, obfs, deobfs, util.ReadTLS)
// TODO: use sync group
for i := 0; i < sta.NumConn; i++ {
go func() {
conn, err := makeRemoteConn(sta)
if err != nil {
log.Printf("Failed to establish new connections to remote: %v\n", err)
return
}
sesh.AddConnection(conn)
}()
conn, err := makeRemoteConn(sta)
if err != nil {
log.Printf("Failed to establish new connections to remote: %v\n", err)
return
}
sesh.AddConnection(conn)
}
listener, err := net.Listen("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
@ -175,8 +180,12 @@ func main() {
stream, err := sesh.OpenStream()
if err != nil {
ssConn.Close()
return
}
_, err = stream.Write(data[:i])
if err != nil {
log.Println(err)
}
stream.Write(data[:i])
go pipe(ssConn, stream)
pipe(stream, ssConn)
}()

@ -1,15 +1,16 @@
package main
import (
"encoding/hex"
"flag"
"fmt"
"io"
"log"
"net"
//"net/http"
//_ "net/http/pprof"
"net/http"
_ "net/http/pprof"
"os"
//"runtime"
"runtime"
"strings"
"time"
@ -70,14 +71,21 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
return
}
isSS, SID := server.TouchStone(ch, sta)
isSS, UID, sessionID := server.TouchStone(ch, sta)
if !isSS {
log.Printf("+1 non SS TLS traffic from %v\n", conn.RemoteAddr())
goWeb(data)
return
}
// TODO: verify SID
var arrUID [32]byte
copy(arrUID[:], UID)
user, err := sta.Userpanel.GetAndActivateUser(arrUID)
log.Printf("UID: %x\n", UID)
if err != nil {
log.Printf("+1 unauthorised user from %v, uid: %x\n", conn.RemoteAddr(), UID)
goWeb(data)
}
reply := server.ComposeReply(ch)
_, err = conn.Write(reply)
@ -90,7 +98,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
// Two discarded messages: ChangeCipherSpec and Finished
discardBuf := make([]byte, 1024)
for c := 0; c < 2; c++ {
_, err = util.ReadTillDrain(conn, discardBuf)
_, err = util.ReadTLS(conn, discardBuf)
if err != nil {
log.Printf("Reading discarded message %v: %v\n", c, err)
go conn.Close()
@ -98,45 +106,36 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
}
}
go func() {
var arrSID [32]byte
copy(arrSID[:], SID)
var sesh *mux.Session
if sesh = sta.GetSession(arrSID); sesh == nil {
sesh = mux.MakeSession(0, 1e9, 1e9, util.MakeObfs(SID), util.MakeDeobfs(SID), util.ReadTillDrain)
sta.PutSession(arrSID, sesh)
}
sesh.AddConnection(conn)
go func() {
for {
newStream, err := sesh.AcceptStream()
if err != nil {
log.Printf("Failed to get new stream: %v", err)
if err == mux.ErrBrokenSession {
sta.DelSession(arrSID)
return
} else {
continue
}
}
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
if err != nil {
log.Printf("Failed to connect to ssserver: %v", err)
continue
}
go pipe(ssConn, newStream)
go pipe(newStream, ssConn)
// FIXME: the following code should not be executed for every single remote connection
sesh := user.GetOrCreateSession(sessionID, util.MakeObfs(UID), util.MakeDeobfs(UID), util.ReadTLS)
sesh.AddConnection(conn)
for {
newStream, err := sesh.AcceptStream()
if err != nil {
log.Printf("Failed to get new stream: %v", err)
if err == mux.ErrBrokenSession {
user.DelSession(sessionID)
return
} else {
continue
}
}()
}()
}
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
if err != nil {
log.Printf("Failed to connect to ssserver: %v", err)
continue
}
go pipe(ssConn, newStream)
go pipe(newStream, ssConn)
}
}
func main() {
//runtime.SetBlockProfileRate(5)
//go func() {
// log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
//}()
runtime.SetBlockProfileRate(5)
go func() {
log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
}()
// Should be 127.0.0.1 to listen to ss-server on this machine
var localHost string
// server_port in ss config, same as remotePort in plugin mode
@ -181,7 +180,13 @@ func main() {
localPort = strings.Split(*localAddr, ":")[1]
log.Printf("Starting standalone mode, listening on %v:%v to ss at %v:%v\n", remoteHost, remotePort, localHost, localPort)
}
sta := server.InitState(localHost, localPort, remoteHost, remotePort, time.Now)
sta, _ := server.InitState(localHost, localPort, remoteHost, remotePort, time.Now, "userinfo.db")
//debug
var arrUID [32]byte
UID, _ := hex.DecodeString("50d858e0985ecc7f60418aaf0cc5ab587f42c2570a884095a9e8ccacd0f6545c")
copy(arrUID[:], UID)
sta.Userpanel.AddNewUser(arrUID, 10, 1e12, 1e12, 1e12, 1e12)
err := sta.ParseConfig(pluginOpts)
if err != nil {
log.Fatalf("Configuration file error: %v", err)

@ -21,7 +21,7 @@ func MakeRandomField(sta *State) []byte {
rdm := make([]byte, 16)
io.ReadFull(rand.Reader, rdm)
preHash := make([]byte, 56)
copy(preHash[0:32], sta.SID)
copy(preHash[0:32], sta.UID)
copy(preHash[32:40], t)
copy(preHash[40:56], rdm)
h := sha256.New()
@ -33,9 +33,9 @@ func MakeRandomField(sta *State) []byte {
}
func MakeSessionTicket(sta *State) []byte {
// sessionTicket: [marshalled ephemeral pub key 32 bytes][encrypted SID 32 bytes][padding 128 bytes]
// sessionTicket: [marshalled ephemeral pub key 32 bytes][encrypted UID+sessionID 36 bytes][padding 124 bytes]
// The first 16 bytes of the marshalled ephemeral public key is used as the IV
// for encrypting the SID
// for encrypting the UID
tthInterval := sta.Now().Unix() / int64(sta.TicketTimeHint)
ec := ecdh.NewCurve25519ECDH()
ephKP := sta.getKeyPair(tthInterval)
@ -50,8 +50,21 @@ func MakeSessionTicket(sta *State) []byte {
ticket := make([]byte, 192)
copy(ticket[0:32], ec.Marshal(ephKP.PublicKey))
key, _ := ec.GenerateSharedSecret(ephKP.PrivateKey, sta.staticPub)
cipherSID := util.AESEncrypt(ticket[0:16], key, sta.SID)
copy(ticket[32:64], cipherSID)
copy(ticket[64:192], util.PsudoRandBytes(128, tthInterval+sta.opaque))
plainUIDsID := make([]byte, 36)
copy(plainUIDsID, sta.UID)
binary.BigEndian.PutUint32(plainUIDsID[32:36], sta.sessionID)
cipherUIDsID := util.AESEncrypt(ticket[0:16], key, plainUIDsID)
copy(ticket[32:68], cipherUIDsID)
// The purpose of adding sessionID is that, the generated padding of sessionTicket needs to be unpredictable.
// As shown in auth.go, the padding is generated by a psudo random generator. The seed
// needs to be the same for each TicketTimeHint interval. However the value of epoch/TicketTimeHint
// is public knowledge, so is the psudo random algorithm used by math/rand. Therefore not only
// can the firewall tell that the padding is generated in this specific way, this padding is identical
// for all ckclients in the same TicketTimeHint interval. This will expose us.
//
// With the sessionID value generated at startup of ckclient and used as a part of the seed, the
// sessionTicket is still identical for each TicketTimeHint interval, but others won't be able to know
// how it was generated. It will also be different for each client.
copy(ticket[68:192], util.PsudoRandBytes(124, tthInterval+int64(sta.sessionID)))
return ticket
}

@ -29,8 +29,8 @@ type State struct {
SS_REMOTE_PORT string
Now func() time.Time
opaque int64
SID []byte
sessionID uint32
UID []byte
staticPub crypto.PublicKey
keyPairsM sync.RWMutex
keyPairs map[int64]*keyPair
@ -41,14 +41,14 @@ type State struct {
NumConn int
}
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, opaque int64) *State {
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, sessionID uint32) *State {
ret := &State{
SS_LOCAL_HOST: localHost,
SS_LOCAL_PORT: localPort,
SS_REMOTE_HOST: remoteHost,
SS_REMOTE_PORT: remotePort,
Now: nowFunc,
opaque: opaque,
sessionID: sessionID,
}
ret.keyPairs = make(map[int64]*keyPair)
return ret
@ -56,6 +56,7 @@ func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func
// semi-colon separated value. This is for Android plugin options
func ssvToJson(ssv string) (ret []byte) {
// TODO: base64 encoded data has =. How to escape?
unescape := func(s string) string {
r := strings.Replace(s, "\\\\", "\\", -1)
r = strings.Replace(r, "\\=", "=", -1)
@ -104,16 +105,16 @@ func (sta *State) ParseConfig(conf string) (err error) {
sta.TicketTimeHint = preParse.TicketTimeHint
sta.MaskBrowser = preParse.MaskBrowser
sta.NumConn = preParse.NumConn
sid, pub, err := parseKey(preParse.Key)
uid, pub, err := parseKey(preParse.Key)
if err != nil {
return errors.New("Failed to parse Key: " + err.Error())
}
sta.SID = sid
sta.UID = uid
sta.staticPub = pub
return nil
}
// Structure: [SID 32 bytes][marshalled public key 32 bytes]
// Structure: [UID 32 bytes][marshalled public key 32 bytes]
func parseKey(b64 string) ([]byte, crypto.PublicKey, error) {
b, err := base64.StdEncoding.DecodeString(b64)
if err != nil {

@ -15,8 +15,7 @@ import (
// make sure packets arrive in order.
//
// Cloak packets will have a 32-bit sequence number on them, so we know in which order
// they should be sent to shadowsocks. In the case that the packets arrive out-of-order,
// the code in this file provides buffering and sorting.
// they should be sent to shadowsocks. The code in this file provides buffering and sorting.
//
// Similar to TCP, the next seq number after 2^32-1 is 0. This is called wrap around.
//
@ -54,6 +53,12 @@ func (sh *sorterHeap) Pop() interface{} {
return x
}
func (s *Stream) writeNewFrame(f *Frame) {
s.newFrameCh <- f
}
// recvNewFrame is a forever running loop which receives frames unordered,
// cache and order them and send them into sortedBufCh
func (s *Stream) recvNewFrame() {
for {
var f *Frame
@ -69,7 +74,7 @@ func (s *Stream) recvNewFrame() {
if len(s.sh) == 0 && f.Seq == s.nextRecvSeq {
if f.Closing == 1 {
s.passiveClose()
s.sortedBufCh <- []byte{}
return
}
@ -115,7 +120,7 @@ func (s *Stream) recvNewFrame() {
frame := heap.Pop(&s.sh).(*frameNode).frame
if frame.Closing == 1 {
s.passiveClose()
s.sortedBufCh <- []byte{}
return
}
payload := frame.Payload

@ -0,0 +1,58 @@
package multiplex
import (
"sync/atomic"
"github.com/juju/ratelimit"
)
// Valve needs to be universal, across all sessions that belong to a user
// gabe please don't sue
type Valve struct {
// traffic directions from the server's perspective are refered
// exclusively as rx and tx.
// rx is from client to server, tx is from server to client
// DO NOT use terms up or down as this is used in usermanager
// for bandwidth limiting
rxtb atomic.Value // *ratelimit.Bucket
txtb atomic.Value // *ratelimit.Bucket
rxCredit int64
txCredit int64
}
func MakeValve(rxRate, txRate, rxCredit, txCredit int64) *Valve {
v := &Valve{
rxCredit: rxCredit,
txCredit: txCredit,
}
v.SetRxRate(rxRate)
v.SetTxRate(txRate)
return v
}
func (v *Valve) SetRxRate(rate int64) {
v.rxtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate))
}
func (v *Valve) SetTxRate(rate int64) {
v.txtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate))
}
func (v *Valve) rxWait(n int) {
v.rxtb.Load().(*ratelimit.Bucket).Wait(int64(n))
}
func (v *Valve) txWait(n int) {
v.txtb.Load().(*ratelimit.Bucket).Wait(int64(n))
}
// n can be negative
func (v *Valve) AddRxCredit(n int64) int64 {
return atomic.AddInt64(&v.rxCredit, n)
}
// n can be negative
func (v *Valve) AddTxCredit(n int64) int64 {
return atomic.AddInt64(&v.txCredit, n)
}

@ -2,6 +2,7 @@ package multiplex
import (
"errors"
"log"
"net"
"sync"
"sync/atomic"
@ -16,14 +17,14 @@ var ErrBrokenSession = errors.New("broken session")
var errRepeatSessionClosing = errors.New("trying to close a closed session")
type Session struct {
id int
id uint32 // This field isn't acutally used
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header
obfs func(*Frame) []byte
// Remove TLS header, decrypt and unmarshall multiplexing headers
deobfs func([]byte) *Frame
// This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
obfsedReader func(net.Conn, []byte) (int, error)
obfsedRead func(net.Conn, []byte) (int, error)
// atomic
nextStreamID uint32
@ -37,24 +38,25 @@ type Session struct {
// For accepting new streams
acceptCh chan *Stream
// TODO: use sync.Once for this
closingM sync.Mutex
die chan struct{}
closing bool
}
// 1 conn is needed to make a session
func MakeSession(id int, uprate, downrate float64, obfs func(*Frame) []byte, deobfs func([]byte) *Frame, obfsedReader func(net.Conn, []byte) (int, error)) *Session {
func MakeSession(id uint32, valve *Valve, obfs func(*Frame) []byte, deobfs func([]byte) *Frame, obfsedRead func(net.Conn, []byte) (int, error)) *Session {
sesh := &Session{
id: id,
obfs: obfs,
deobfs: deobfs,
obfsedReader: obfsedReader,
obfsedRead: obfsedRead,
nextStreamID: 1,
streams: make(map[uint32]*Stream),
acceptCh: make(chan *Stream, acceptBacklog),
die: make(chan struct{}),
}
sesh.sb = makeSwitchboard(sesh, uprate, downrate)
sesh.sb = makeSwitchboard(sesh, valve)
return sesh
}
@ -63,12 +65,18 @@ func (sesh *Session) AddConnection(conn net.Conn) {
}
func (sesh *Session) OpenStream() (*Stream, error) {
id := atomic.AddUint32(&sesh.nextStreamID, 1)
id -= 1 // Because atomic.AddUint32 returns the value after incrementation
select {
case <-sesh.die:
return nil, ErrBrokenSession
default:
}
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
// Because atomic.AddUint32 returns the value after incrementation
stream := makeStream(id, sesh)
sesh.streamsM.Lock()
sesh.streams[id] = stream
sesh.streamsM.Unlock()
log.Printf("Opening stream %v\n", id)
return stream, nil
}
@ -108,6 +116,7 @@ func (sesh *Session) addStream(id uint32) *Stream {
sesh.streams[id] = stream
sesh.streamsM.Unlock()
sesh.acceptCh <- stream
log.Printf("Adding stream %v\n", id)
return stream
}

@ -31,7 +31,7 @@ type Stream struct {
// atomic
nextSendSeq uint32
closingM sync.Mutex
closingM sync.RWMutex
// close(die) is used to notify different goroutines that this stream is closing
die chan struct{}
// to prevent closing a closed channel
@ -45,7 +45,7 @@ func makeStream(id uint32, sesh *Session) *Stream {
die: make(chan struct{}),
sh: []*frameNode{},
newFrameCh: make(chan *Frame, 1024),
sortedBufCh: make(chan []byte, 4096),
sortedBufCh: make(chan []byte, 1024),
}
go stream.recvNewFrame()
return stream
@ -64,6 +64,10 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
case <-stream.die:
return 0, errBrokenStream
case data := <-stream.sortedBufCh:
if len(data) == 0 {
stream.passiveClose()
return 0, errBrokenStream
}
if len(buf) < len(data) {
log.Println(len(data))
return 0, errors.New("buf too small")
@ -75,6 +79,13 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
}
func (stream *Stream) Write(in []byte) (n int, err error) {
// RWMutex used here isn't really for RW.
// we use it to exploit the fact that RLock doesn't create contention.
// The use of RWMutex is so that the stream will not actively close
// in the middle of the execution of Write. This may cause the closing frame
// to be sent before the data frame and cause loss of packet.
stream.closingM.RLock()
defer stream.closingM.RUnlock()
select {
case <-stream.die:
return 0, errBrokenStream
@ -83,13 +94,11 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
f := &Frame{
StreamID: stream.id,
Seq: atomic.LoadUint32(&stream.nextSendSeq),
Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1,
Closing: 0,
Payload: in,
}
atomic.AddUint32(&stream.nextSendSeq, 1)
tlsRecord := stream.session.obfs(f)
n, err = stream.session.sb.send(tlsRecord)
@ -97,9 +106,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
}
// only close locally. Used when the stream close is notified by the remote
func (stream *Stream) passiveClose() error {
func (stream *Stream) shutdown() error {
// Lock here because closing a closed channel causes panic
stream.closingM.Lock()
defer stream.closingM.Unlock()
@ -108,29 +115,36 @@ func (stream *Stream) passiveClose() error {
}
stream.closing = true
close(stream.die)
return nil
}
// only close locally. Used when the stream close is notified by the remote
func (stream *Stream) passiveClose() error {
err := stream.shutdown()
if err != nil {
return err
}
stream.session.delStream(stream.id)
log.Printf("%v passive closing\n", stream.id)
return nil
}
// active close. Close locally and tell the remote that this stream is being closed
func (stream *Stream) Close() error {
// Lock here because closing a closed channel causes panic
stream.closingM.Lock()
defer stream.closingM.Unlock()
if stream.closing {
return errRepeatStreamClosing
err := stream.shutdown()
if err != nil {
return err
}
stream.closing = true
close(stream.die)
// Notify remote that this stream is closed
prand.Seed(int64(stream.id))
padLen := int(math.Floor(prand.Float64()*200 + 300))
pad := make([]byte, padLen)
prand.Read(pad)
f := &Frame{
StreamID: stream.id,
Seq: atomic.LoadUint32(&stream.nextSendSeq),
Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1,
Closing: 1,
Payload: pad,
}
@ -138,20 +152,12 @@ func (stream *Stream) Close() error {
stream.session.sb.send(tlsRecord)
stream.session.delStream(stream.id)
log.Printf("%v actively closed\n", stream.id)
return nil
}
// Same as Close() but no call to session.delStream.
// This is called in session.Close() to avoid mutex deadlock
func (stream *Stream) closeNoDelMap() error {
// Lock here because closing a closed channel causes panic
stream.closingM.Lock()
defer stream.closingM.Unlock()
if stream.closing {
return errRepeatStreamClosing
}
stream.closing = true
close(stream.die)
return nil
return stream.shutdown()
}

@ -6,20 +6,34 @@ import (
"net"
"sync"
"sync/atomic"
"github.com/juju/ratelimit"
)
// switchboard is responsible for keeping the reference of TLS connections between client and server
type switchboard struct {
session *Session
wtb *ratelimit.Bucket
rtb *ratelimit.Bucket
*Valve
optimum atomic.Value
// optimum is the connEnclave with the smallest sendQueue
optimum atomic.Value // *connEnclave
cesM sync.RWMutex
ces []*connEnclave
//debug
hM sync.Mutex
used map[uint32]bool
}
func (sb *switchboard) getOptimum() *connEnclave {
if i := sb.optimum.Load(); i == nil {
return nil
} else {
return i.(*connEnclave)
}
}
func (sb *switchboard) setOptimum(ce *connEnclave) {
sb.optimum.Store(ce)
}
// Some data comes from a Stream to be sent through one of the many
@ -27,45 +41,51 @@ type switchboard struct {
//
// In this case, we pick the remoteConn that has about the smallest sendQueue.
type connEnclave struct {
sb *switchboard
remoteConn net.Conn
sendQueue uint32
}
// It takes at least 1 conn to start a switchboard
// TODO: does it really?
func makeSwitchboard(sesh *Session, uprate, downrate float64) *switchboard {
func makeSwitchboard(sesh *Session, valve *Valve) *switchboard {
// rates are uint64 because in the usermanager we want the bandwidth to be atomically
// operated (so that the bandwidth can change on the fly).
sb := &switchboard{
session: sesh,
wtb: ratelimit.NewBucketWithRate(uprate, int64(uprate)),
rtb: ratelimit.NewBucketWithRate(downrate, int64(downrate)),
Valve: valve,
ces: []*connEnclave{},
used: make(map[uint32]bool),
}
return sb
}
var errNilOptimum error = errors.New("The optimal connection is nil")
var ErrNoRxCredit error = errors.New("No Rx credit is left")
var ErrNoTxCredit error = errors.New("No Tx credit is left")
func (sb *switchboard) send(data []byte) (int, error) {
ce := sb.optimum.Load().(*connEnclave)
ce := sb.getOptimum()
if ce == nil {
return 0, errNilOptimum
}
sb.wtb.Wait(int64(len(data)))
atomic.AddUint32(&ce.sendQueue, uint32(len(data)))
go sb.updateOptimum()
n, err := ce.remoteConn.Write(data)
if err != nil {
return 0, err
return n, err
// TODO
}
if sb.AddTxCredit(-int64(n)) < 0 {
log.Println(ErrNoTxCredit)
defer sb.session.Close()
return n, ErrNoTxCredit
}
atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
go sb.updateOptimum()
return n, nil
}
func (sb *switchboard) updateOptimum() {
currentOpti := sb.optimum.Load().(*connEnclave)
currentOpti := sb.getOptimum()
currentOptiQ := atomic.LoadUint32(&currentOpti.sendQueue)
sb.cesM.RLock()
for _, ce := range sb.ces {
@ -76,20 +96,18 @@ func (sb *switchboard) updateOptimum() {
}
}
sb.cesM.RUnlock()
sb.optimum.Store(currentOpti)
sb.setOptimum(currentOpti)
}
func (sb *switchboard) addConn(conn net.Conn) {
newCe := &connEnclave{
sb: sb,
remoteConn: conn,
sendQueue: 0,
}
sb.cesM.Lock()
sb.ces = append(sb.ces, newCe)
sb.cesM.Unlock()
sb.optimum.Store(newCe)
sb.setOptimum(newCe)
go sb.deplex(newCe)
}
@ -101,10 +119,10 @@ func (sb *switchboard) removeConn(closing *connEnclave) {
break
}
}
sb.cesM.Unlock()
if len(sb.ces) == 0 {
sb.session.Close()
}
sb.cesM.Unlock()
}
func (sb *switchboard) shutdown() {
@ -118,19 +136,40 @@ func (sb *switchboard) shutdown() {
func (sb *switchboard) deplex(ce *connEnclave) {
buf := make([]byte, 20480)
for {
i, err := sb.session.obfsedReader(ce.remoteConn, buf)
sb.rtb.Wait(int64(i))
n, err := sb.session.obfsedRead(ce.remoteConn, buf)
sb.rxWait(n)
if err != nil {
log.Println(err)
go ce.remoteConn.Close()
sb.removeConn(ce)
return
}
frame := sb.session.deobfs(buf[:i])
if sb.AddRxCredit(-int64(n)) < 0 {
log.Println(ErrNoRxCredit)
sb.session.Close()
return
}
frame := sb.session.deobfs(buf[:n])
//debug
var stream *Stream
if stream = sb.session.getStream(frame.StreamID); stream == nil {
if frame.Closing == 1 {
// if the frame is telling us to close a closed stream
// (this happens when ss-server and ss-local closes the stream
// simutaneously), we don't do anything
continue
}
//debug
sb.hM.Lock()
if sb.used[frame.StreamID] {
log.Printf("%v lost!\n", frame.StreamID)
}
sb.used[frame.StreamID] = true
sb.hM.Unlock()
stream = sb.session.addStream(frame.StreamID)
}
stream.newFrameCh <- frame
stream.writeNewFrame(frame)
}
}

@ -11,54 +11,54 @@ import (
ecdh "github.com/cbeuw/go-ecdh"
)
// input ticket, return SID
func decryptSessionTicket(staticPv crypto.PrivateKey, ticket []byte) ([]byte, error) {
// input ticket, return UID
func decryptSessionTicket(staticPv crypto.PrivateKey, ticket []byte) ([]byte, uint32, error) {
ec := ecdh.NewCurve25519ECDH()
ephPub, _ := ec.Unmarshal(ticket[0:32])
key, err := ec.GenerateSharedSecret(staticPv, ephPub)
if err != nil {
return nil, err
return nil, 0, err
}
SID := util.AESDecrypt(ticket[0:16], key, ticket[32:64])
return SID, nil
UIDsID := util.AESDecrypt(ticket[0:16], key, ticket[32:68])
sessionID := binary.BigEndian.Uint32(UIDsID[32:36])
return UIDsID[0:32], sessionID, nil
}
func validateRandom(random []byte, SID []byte, time int64) bool {
func validateRandom(random []byte, UID []byte, time int64) bool {
t := make([]byte, 8)
binary.BigEndian.PutUint64(t, uint64(time/(12*60*60)))
rdm := random[0:16]
preHash := make([]byte, 56)
copy(preHash[0:32], SID)
copy(preHash[0:32], UID)
copy(preHash[32:40], t)
copy(preHash[40:56], rdm)
h := sha256.New()
h.Write(preHash)
return bytes.Equal(h.Sum(nil)[0:16], random[16:32])
}
func TouchStone(ch *ClientHello, sta *State) (bool, []byte) {
func TouchStone(ch *ClientHello, sta *State) (isSS bool, UID []byte, sessionID uint32) {
var random [32]byte
copy(random[:], ch.random)
used := sta.getUsedRandom(random)
if used != 0 {
log.Println("Replay! Duplicate random")
return false, nil
return false, nil, 0
}
sta.putUsedRandom(random)
ticket := ch.extensions[[2]byte{0x00, 0x23}]
if len(ticket) < 64 {
return false, nil
return false, nil, 0
}
SID, err := decryptSessionTicket(sta.staticPv, ticket)
UID, sessionID, err := decryptSessionTicket(sta.staticPv, ticket)
if err != nil {
log.Printf("ts: %v\n", err)
return false, nil
return false, nil, 0
}
log.Printf("SID: %x\n", SID)
isSS := validateRandom(ch.random, SID, sta.Now().Unix())
isSS = validateRandom(ch.random, UID, sta.Now().Unix())
if !isSS {
return false, nil
return false, nil, 0
}
return true, SID
return
}

@ -9,7 +9,7 @@ import (
"sync"
"time"
mux "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/cbeuw/Cloak/internal/server/usermanager"
)
type rawConfig struct {
@ -31,25 +31,28 @@ type State struct {
Now func() time.Time
staticPv crypto.PrivateKey
Userpanel *usermanager.Userpanel
usedRandomM sync.RWMutex
usedRandom map[[32]byte]int
sessionsM sync.RWMutex
sessions map[[32]byte]*mux.Session
WebServerAddr string
}
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time) *State {
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, dbPath string) (*State, error) {
up, err := usermanager.MakeUserpanel(dbPath)
if err != nil {
return nil, err
}
ret := &State{
SS_LOCAL_HOST: localHost,
SS_LOCAL_PORT: localPort,
SS_REMOTE_HOST: remoteHost,
SS_REMOTE_PORT: remotePort,
Now: nowFunc,
Userpanel: up,
}
ret.usedRandom = make(map[[32]byte]int)
ret.sessions = make(map[[32]byte]*mux.Session)
return ret
return ret, nil
}
// semi-colon separated value.
@ -115,28 +118,6 @@ func (sta *State) ParseConfig(conf string) (err error) {
return nil
}
func (sta *State) GetSession(SID [32]byte) *mux.Session {
sta.sessionsM.RLock()
defer sta.sessionsM.RUnlock()
if sesh, ok := sta.sessions[SID]; ok {
return sesh
} else {
return nil
}
}
func (sta *State) PutSession(SID [32]byte, sesh *mux.Session) {
sta.sessionsM.Lock()
sta.sessions[SID] = sesh
sta.sessionsM.Unlock()
}
func (sta *State) DelSession(SID [32]byte) {
sta.sessionsM.Lock()
delete(sta.sessions, SID)
sta.sessionsM.Unlock()
}
func (sta *State) getUsedRandom(random [32]byte) int {
sta.usedRandomM.Lock()
defer sta.usedRandomM.Unlock()

@ -0,0 +1,86 @@
package usermanager
import (
mux "github.com/cbeuw/Cloak/internal/multiplex"
"log"
"net"
"sync"
"sync/atomic"
)
/*
type userParams struct {
sessionsCap uint32
upRate int64
downRate int64
upCredit int64
downCredit int64
}
*/
type user struct {
up *Userpanel
uid [32]byte
sessionsCap uint32 //userParams
valve *mux.Valve
sessionsM sync.RWMutex
sessions map[uint32]*mux.Session
}
func MakeUser(up *Userpanel, uid [32]byte, sessionsCap uint32, upRate, downRate, upCredit, downCredit int64) *user {
valve := mux.MakeValve(upRate, downRate, upCredit, downCredit)
u := &user{
up: up,
uid: uid,
valve: valve,
sessionsCap: sessionsCap,
sessions: make(map[uint32]*mux.Session),
}
return u
}
func (u *user) setSessionsCap(cap uint32) {
atomic.StoreUint32(&u.sessionsCap, cap)
}
func (u *user) GetSession(sessionID uint32) *mux.Session {
u.sessionsM.RLock()
defer u.sessionsM.RUnlock()
if sesh, ok := u.sessions[sessionID]; ok {
return sesh
} else {
return nil
}
}
func (u *user) PutSession(sessionID uint32, sesh *mux.Session) {
u.sessionsM.Lock()
u.sessions[sessionID] = sesh
u.sessionsM.Unlock()
}
func (u *user) DelSession(sessionID uint32) {
u.sessionsM.Lock()
delete(u.sessions, sessionID)
if len(u.sessions) == 0 {
u.sessionsM.Unlock()
u.up.delActiveUser(u.uid)
return
}
u.sessionsM.Unlock()
}
func (u *user) GetOrCreateSession(sessionID uint32, obfs func(*mux.Frame) []byte, deobfs func([]byte) *mux.Frame, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session) {
log.Printf("getting sessionID %v\n", sessionID)
if sesh = u.GetSession(sessionID); sesh != nil {
return
} else {
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
u.PutSession(sessionID, sesh)
return
}
}

@ -0,0 +1,151 @@
package usermanager
import (
"encoding/binary"
"errors"
"github.com/boltdb/bolt"
"sync"
)
type Userpanel struct {
db *bolt.DB
activeUsersM sync.RWMutex
activeUsers map[[32]byte]*user
}
func MakeUserpanel(dbPath string) (*Userpanel, error) {
db, err := bolt.Open(dbPath, 0600, nil)
if err != nil {
return nil, err
}
up := &Userpanel{
db: db,
activeUsers: make(map[[32]byte]*user),
}
return up, nil
}
var ErrUserNotFound = errors.New("User does not exist in memory or db")
// GetUser is used to retrieve a user if s/he is active, or to retrieve the user's infor
// from the db and mark it as an active user
func (up *Userpanel) GetAndActivateUser(UID [32]byte) (*user, error) {
up.activeUsersM.RLock()
if user, ok := up.activeUsers[UID]; ok {
up.activeUsersM.RUnlock()
return user, nil
}
up.activeUsersM.RUnlock()
var sessionsCap uint32
var upRate, downRate, upCredit, downCredit int64
err := up.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(UID[:])
if b == nil {
return ErrUserNotFound
}
sessionsCap = binary.BigEndian.Uint32(b.Get([]byte("sessionsCap")))
upRate = int64(binary.BigEndian.Uint64(b.Get([]byte("upRate"))))
downRate = int64(binary.BigEndian.Uint64(b.Get([]byte("downRate"))))
upCredit = int64(binary.BigEndian.Uint64(b.Get([]byte("upCredit")))) // reee brackets
downCredit = int64(binary.BigEndian.Uint64(b.Get([]byte("downCredit"))))
return nil
})
if err != nil {
return nil, err
}
// TODO: put all of these parameters in a struct instead
u := MakeUser(up, UID, sessionsCap, upRate, downRate, upCredit, downCredit)
up.activeUsersM.Lock()
up.activeUsers[UID] = u
up.activeUsersM.Unlock()
return u, nil
}
func (up *Userpanel) AddNewUser(UID [32]byte, sessionsCap uint32, upRate, downRate, upCredit, downCredit int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b, err := tx.CreateBucket(UID[:])
if err != nil {
return err
}
// FIXME: obnoxious code
quad := make([]byte, 4)
binary.BigEndian.PutUint32(quad, sessionsCap)
if err = b.Put([]byte("sessionsCap"), quad); err != nil {
return err
}
oct := make([]byte, 8)
binary.BigEndian.PutUint64(oct, uint64(upRate))
if err = b.Put([]byte("upRate"), oct); err != nil {
return err
}
binary.BigEndian.PutUint64(oct, uint64(downRate))
if err = b.Put([]byte("downRate"), oct); err != nil {
return err
}
binary.BigEndian.PutUint64(oct, uint64(upCredit))
if err = b.Put([]byte("upCredit"), oct); err != nil {
return err
}
binary.BigEndian.PutUint64(oct, uint64(downCredit))
if err = b.Put([]byte("downCredit"), oct); err != nil {
return err
}
return nil
})
return err
}
func (up *Userpanel) updateDBEntryUint32(UID [32]byte, key string, value uint32) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID[:])
if b == nil {
return ErrUserNotFound
}
quad := make([]byte, 4)
binary.BigEndian.PutUint32(quad, value)
if err := b.Put([]byte(key), quad); err != nil {
return err
}
return nil
})
return err
}
func (up *Userpanel) updateDBEntryInt64(UID [32]byte, key string, value int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID[:])
if b == nil {
return ErrUserNotFound
}
oct := make([]byte, 8)
binary.BigEndian.PutUint64(oct, uint64(value))
if err := b.Put([]byte(key), oct); err != nil {
return err
}
return nil
})
return err
}
// This is used when all sessions of a user close
func (up *Userpanel) delActiveUser(UID [32]byte) {
up.activeUsersM.Lock()
delete(up.activeUsers, UID)
up.activeUsersM.Unlock()
}
func (up *Userpanel) getActiveUser(UID [32]byte) *user {
up.activeUsersM.RLock()
defer up.activeUsersM.RUnlock()
return up.activeUsers[UID]
}
func (up *Userpanel) SetSessionsCap(UID [32]byte, newSessionsCap uint32) error {
if u := up.getActiveUser(UID); u != nil {
u.setSessionsCap(newSessionsCap)
}
err := up.updateDBEntryUint32(UID, "sessionsCap", newSessionsCap)
return err
}

@ -9,12 +9,13 @@ import (
// For each frame, the three parts of the header is xored with three keys.
// The keys are generated from the SID and the payload of the frame.
func genXorKeys(SID []byte, data []byte) (i uint32, ii uint32, iii uint32) {
// FIXME: this code will panic if len(data)<18.
func genXorKeys(secret []byte, data []byte) (i uint32, ii uint32, iii uint32) {
h := xxhash.New32()
ret := make([]uint32, 3)
preHash := make([]byte, 16)
for j := 0; j < 3; j++ {
copy(preHash[0:10], SID[j*10:j*10+10])
copy(preHash[0:10], secret[j*10:j*10+10])
copy(preHash[10:16], data[j*6:j*6+6])
h.Write(preHash)
ret[j] = h.Sum32()

@ -43,14 +43,14 @@ func BtoInt(b []byte) int {
// PsudoRandBytes returns a byte slice filled with psudorandom bytes generated by the seed
func PsudoRandBytes(length int, seed int64) []byte {
prand.Seed(seed)
r := prand.New(prand.NewSource(seed))
ret := make([]byte, length)
prand.Read(ret)
r.Read(ret)
return ret
}
// ReadTillDrain reads TLS data according to its record layer
func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
// ReadTLS reads TLS data according to its record layer
func ReadTLS(conn net.Conn, buffer []byte) (n int, err error) {
// TCP is a stream. Multiple TLS messages can arrive at the same time,
// a single message can also be segmented due to MTU of the IP layer.
// This function guareentees a single TLS message to be read and everything

Loading…
Cancel
Save