Refactor Transport and add tests

pull/97/head
Andy Wang 4 years ago
parent e7e4cd5726
commit 74a70a3113

@ -1,219 +1,56 @@
package server
import (
"bytes"
"crypto"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"github.com/cbeuw/Cloak/internal/ecdh"
"github.com/cbeuw/Cloak/internal/util"
)
// ClientHello contains every field in a ClientHello message
type ClientHello struct {
handshakeType byte
length int
clientVersion []byte
random []byte
sessionIdLen int
sessionId []byte
cipherSuitesLen int
cipherSuites []byte
compressionMethodsLen int
compressionMethods []byte
extensionsLen int
extensions map[[2]byte][]byte
}
"net"
var u16 = binary.BigEndian.Uint16
var u32 = binary.BigEndian.Uint32
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed Extensions")
}
}()
pointer := 0
totalLen := len(input)
ret = make(map[[2]byte][]byte)
for pointer < totalLen {
var typ [2]byte
copy(typ[:], input[pointer:pointer+2])
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
data := input[pointer : pointer+length]
pointer += length
ret[typ] = data
}
return ret, err
}
func parseKeyShare(input []byte) (ret []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("malformed key_share")
}
}()
totalLen := int(u16(input[0:2]))
// 2 bytes "client key share length"
pointer := 2
for pointer < totalLen {
if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) {
// skip "key exchange length"
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
if length != 32 {
return nil, fmt.Errorf("key share length should be 32, instead of %v", length)
}
return input[pointer : pointer+length], nil
}
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
_ = input[pointer : pointer+length]
pointer += length
}
return nil, errors.New("x25519 does not exist")
}
log "github.com/sirupsen/logrus"
)
// addRecordLayer adds record layer to data
func addRecordLayer(input []byte, typ []byte, ver []byte) []byte {
length := make([]byte, 2)
binary.BigEndian.PutUint16(length, uint16(len(input)))
ret := make([]byte, 5+len(input))
copy(ret[0:1], typ)
copy(ret[1:3], ver)
copy(ret[3:5], length)
copy(ret[5:], input)
return ret
}
type TLS struct{}
// parseClientHello parses everything on top of the TLS layer
// (including the record layer) into ClientHello type
func parseClientHello(data []byte) (ret *ClientHello, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed ClientHello")
}
}()
var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) {
return ret, errors.New("wrong TLS1.3 handshake magic bytes")
}
func (TLS) String() string { return "TLS" }
func (TLS) HasRecordLayer() bool { return true }
func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
peeled := make([]byte, len(data)-5)
copy(peeled, data[5:])
pointer := 0
// Handshake Type
handshakeType := peeled[pointer]
if handshakeType != 0x01 {
return ret, errors.New("Not a ClientHello")
}
pointer += 1
// Length
length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...)))
pointer += 3
if length != len(peeled[pointer:]) {
return ret, errors.New("Hello length doesn't match")
}
// Client Version
clientVersion := peeled[pointer : pointer+2]
pointer += 2
// Random
random := peeled[pointer : pointer+32]
pointer += 32
// Session ID
sessionIdLen := int(peeled[pointer])
pointer += 1
sessionId := peeled[pointer : pointer+sessionIdLen]
pointer += sessionIdLen
// Cipher Suites
cipherSuitesLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
cipherSuites := peeled[pointer : pointer+cipherSuitesLen]
pointer += cipherSuitesLen
// Compression Methods
compressionMethodsLen := int(peeled[pointer])
pointer += 1
compressionMethods := peeled[pointer : pointer+compressionMethodsLen]
pointer += compressionMethodsLen
// Extensions
extensionsLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
extensions, err := parseExtensions(peeled[pointer:])
ret = &ClientHello{
handshakeType,
length,
clientVersion,
random,
sessionIdLen,
sessionId,
cipherSuitesLen,
cipherSuites,
compressionMethodsLen,
compressionMethods,
extensionsLen,
extensions,
func (TLS) handshake(clientHello []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (ai authenticationInfo, finisher func([]byte) (net.Conn, error), err error) {
var ch *ClientHello
ch, err = parseClientHello(clientHello)
if err != nil {
log.Debug(err)
err = ErrBadClientHello
return
}
return
}
func composeServerHello(sessionId []byte, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
nonce := make([]byte, 12)
rand.Read(nonce)
encryptedKey, err := util.AESGCMEncrypt(nonce, sharedSecret, sessionKey) // 32 + 16 = 48 bytes
ai, err = unmarshalClientHello(ch, privateKey)
if err != nil {
return nil, err
err = fmt.Errorf("failed to unmarshal ClientHello into authenticationInfo: %v", err)
return
}
var serverHello [11][]byte
serverHello[0] = []byte{0x02} // handshake type
serverHello[1] = []byte{0x00, 0x00, 0x76} // length 77
serverHello[2] = []byte{0x03, 0x03} // server version
serverHello[3] = append(nonce[0:12], encryptedKey[0:20]...) // random 32 bytes
serverHello[4] = []byte{0x20} // session id length 32
serverHello[5] = sessionId // session id
serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
serverHello[7] = []byte{0x00} // compression method null
serverHello[8] = []byte{0x00, 0x2e} // extensions length 46
keyShare, _ := hex.DecodeString("00330024001d0020")
keyExchange := make([]byte, 32)
copy(keyExchange, encryptedKey[20:48])
rand.Read(keyExchange[28:32])
serverHello[9] = append(keyShare, keyExchange...)
serverHello[10], _ = hex.DecodeString("002b00020304")
var ret []byte
for _, s := range serverHello {
ret = append(ret, s...)
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
preparedConn = originalConn
reply, err := composeReply(ch, ai.sharedSecret, sessionKey)
if err != nil {
err = fmt.Errorf("failed to compose TLS reply: %v", err)
return
}
_, err = preparedConn.Write(reply)
if err != nil {
err = fmt.Errorf("failed to write TLS reply: %v", err)
go preparedConn.Close()
return
}
return
}
return ret, nil
}
// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages
// together with their respective record layers into one byte slice.
func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
TLS12 := []byte{0x03, 0x03}
sh, err := composeServerHello(ch.sessionId, sharedSecret, sessionKey)
if err != nil {
return nil, err
}
shBytes := addRecordLayer(sh, []byte{0x16}, TLS12)
ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
cert := make([]byte, 68) // TODO: add some different lengths maybe?
rand.Read(cert)
encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12)
ret := append(shBytes, ccsBytes...)
ret = append(ret, encryptedCertBytes...)
return ret, nil
return
}
func unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (ai authenticationInfo, err error) {

@ -0,0 +1,215 @@
package server
import (
"bytes"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"github.com/cbeuw/Cloak/internal/util"
)
// ClientHello contains every field in a ClientHello message
type ClientHello struct {
handshakeType byte
length int
clientVersion []byte
random []byte
sessionIdLen int
sessionId []byte
cipherSuitesLen int
cipherSuites []byte
compressionMethodsLen int
compressionMethods []byte
extensionsLen int
extensions map[[2]byte][]byte
}
var u16 = binary.BigEndian.Uint16
var u32 = binary.BigEndian.Uint32
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed Extensions")
}
}()
pointer := 0
totalLen := len(input)
ret = make(map[[2]byte][]byte)
for pointer < totalLen {
var typ [2]byte
copy(typ[:], input[pointer:pointer+2])
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
data := input[pointer : pointer+length]
pointer += length
ret[typ] = data
}
return ret, err
}
func parseKeyShare(input []byte) (ret []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("malformed key_share")
}
}()
totalLen := int(u16(input[0:2]))
// 2 bytes "client key share length"
pointer := 2
for pointer < totalLen {
if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) {
// skip "key exchange length"
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
if length != 32 {
return nil, fmt.Errorf("key share length should be 32, instead of %v", length)
}
return input[pointer : pointer+length], nil
}
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
_ = input[pointer : pointer+length]
pointer += length
}
return nil, errors.New("x25519 does not exist")
}
// addRecordLayer adds record layer to data
func addRecordLayer(input []byte, typ []byte, ver []byte) []byte {
length := make([]byte, 2)
binary.BigEndian.PutUint16(length, uint16(len(input)))
ret := make([]byte, 5+len(input))
copy(ret[0:1], typ)
copy(ret[1:3], ver)
copy(ret[3:5], length)
copy(ret[5:], input)
return ret
}
// parseClientHello parses everything on top of the TLS layer
// (including the record layer) into ClientHello type
func parseClientHello(data []byte) (ret *ClientHello, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed ClientHello")
}
}()
if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) {
return ret, errors.New("wrong TLS1.3 handshake magic bytes")
}
peeled := make([]byte, len(data)-5)
copy(peeled, data[5:])
pointer := 0
// Handshake Type
handshakeType := peeled[pointer]
if handshakeType != 0x01 {
return ret, errors.New("Not a ClientHello")
}
pointer += 1
// Length
length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...)))
pointer += 3
if length != len(peeled[pointer:]) {
return ret, errors.New("Hello length doesn't match")
}
// Client Version
clientVersion := peeled[pointer : pointer+2]
pointer += 2
// Random
random := peeled[pointer : pointer+32]
pointer += 32
// Session ID
sessionIdLen := int(peeled[pointer])
pointer += 1
sessionId := peeled[pointer : pointer+sessionIdLen]
pointer += sessionIdLen
// Cipher Suites
cipherSuitesLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
cipherSuites := peeled[pointer : pointer+cipherSuitesLen]
pointer += cipherSuitesLen
// Compression Methods
compressionMethodsLen := int(peeled[pointer])
pointer += 1
compressionMethods := peeled[pointer : pointer+compressionMethodsLen]
pointer += compressionMethodsLen
// Extensions
extensionsLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
extensions, err := parseExtensions(peeled[pointer:])
ret = &ClientHello{
handshakeType,
length,
clientVersion,
random,
sessionIdLen,
sessionId,
cipherSuitesLen,
cipherSuites,
compressionMethodsLen,
compressionMethods,
extensionsLen,
extensions,
}
return
}
func composeServerHello(sessionId []byte, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
nonce := make([]byte, 12)
rand.Read(nonce)
encryptedKey, err := util.AESGCMEncrypt(nonce, sharedSecret, sessionKey) // 32 + 16 = 48 bytes
if err != nil {
return nil, err
}
var serverHello [11][]byte
serverHello[0] = []byte{0x02} // handshake type
serverHello[1] = []byte{0x00, 0x00, 0x76} // length 77
serverHello[2] = []byte{0x03, 0x03} // server version
serverHello[3] = append(nonce[0:12], encryptedKey[0:20]...) // random 32 bytes
serverHello[4] = []byte{0x20} // session id length 32
serverHello[5] = sessionId // session id
serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
serverHello[7] = []byte{0x00} // compression method null
serverHello[8] = []byte{0x00, 0x2e} // extensions length 46
keyShare, _ := hex.DecodeString("00330024001d0020")
keyExchange := make([]byte, 32)
copy(keyExchange, encryptedKey[20:48])
rand.Read(keyExchange[28:32])
serverHello[9] = append(keyShare, keyExchange...)
serverHello[10], _ = hex.DecodeString("002b00020304")
var ret []byte
for _, s := range serverHello {
ret = append(ret, s...)
}
return ret, nil
}
// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages
// together with their respective record layers into one byte slice.
func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
TLS12 := []byte{0x03, 0x03}
sh, err := composeServerHello(ch.sessionId, sharedSecret, sessionKey)
if err != nil {
return nil, err
}
shBytes := addRecordLayer(sh, []byte{0x16}, TLS12)
ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
cert := make([]byte, 68) // TODO: add some different lengths maybe?
rand.Read(cert)
encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12)
ret := append(shBytes, ccsBytes...)
ret = append(ret, encryptedCertBytes...)
return ret, nil
}

@ -1,16 +1,12 @@
package server
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"github.com/cbeuw/Cloak/internal/util"
"net"
"net/http"
"time"
log "github.com/sirupsen/logrus"
@ -35,8 +31,6 @@ const (
UNORDERED_FLAG = 0x01 // 0000 0001
)
var ErrInvalidPubKey = errors.New("public key has invalid format")
var ErrCiphertextLength = errors.New("ciphertext has the wrong length")
var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window")
var ErrUnreconisedProtocol = errors.New("unreconised protocol")
@ -67,7 +61,6 @@ func touchStone(ai authenticationInfo, now func() time.Time) (info ClientInfo, e
return
}
var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
var ErrReplay = errors.New("duplicate random")
var ErrBadProxyMethod = errors.New("invalid proxy method")
@ -76,100 +69,34 @@ var ErrBadProxyMethod = errors.New("invalid proxy method")
// is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with
// the handshake
func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) {
var transport Transport
var ai authenticationInfo
switch firstPacket[0] {
case 0x47:
transport = WebSocket{}
var req *http.Request
req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(firstPacket)))
if err != nil {
err = fmt.Errorf("failed to parse first HTTP GET: %v", err)
return
}
var hiddenData []byte
hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden"))
ai, err = unmarshalHidden(hiddenData, sta.staticPv)
if err != nil {
err = fmt.Errorf("failed to unmarshal hidden data from WS into authenticationInfo: %v", err)
return
}
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
handler := newWsHandshakeHandler()
// For an explanation of the following 3 lines, see the comments in websocket.go
http.Serve(newWsAcceptor(conn, firstPacket), handler)
<-handler.finished
preparedConn = handler.conn
nonce := make([]byte, 12)
rand.Read(nonce)
// reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag]
encryptedKey, err := util.AESGCMEncrypt(nonce, ai.sharedSecret, sessionKey) // 32 + 16 = 48 bytes
if err != nil {
err = fmt.Errorf("failed to encrypt reply: %v", err)
return
}
reply := append(nonce, encryptedKey...)
_, err = preparedConn.Write(reply)
if err != nil {
err = fmt.Errorf("failed to write reply: %v", err)
go preparedConn.Close()
return
}
return
}
info.Transport = WebSocket{}
case 0x16:
transport = TLS{}
var ch *ClientHello
ch, err = parseClientHello(firstPacket)
if err != nil {
log.Debug(err)
err = ErrBadClientHello
return
}
if sta.registerRandom(ch.random) {
err = ErrReplay
return
}
ai, err = unmarshalClientHello(ch, sta.staticPv)
if err != nil {
err = fmt.Errorf("failed to unmarshal ClientHello into authenticationInfo: %v", err)
return
}
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
preparedConn = conn
reply, err := composeReply(ch, ai.sharedSecret, sessionKey)
if err != nil {
err = fmt.Errorf("failed to compose TLS reply: %v", err)
return
}
_, err = preparedConn.Write(reply)
if err != nil {
err = fmt.Errorf("failed to write TLS reply: %v", err)
go preparedConn.Close()
return
}
return
}
info.Transport = TLS{}
default:
err = ErrUnreconisedProtocol
return
}
var ai authenticationInfo
ai, finisher, err = info.Transport.handshake(firstPacket, sta.staticPv, conn)
if err != nil {
return
}
if sta.registerRandom(ai.nonce) {
err = ErrReplay
return
}
info, err = touchStone(ai, sta.Now)
if err != nil {
log.Debug(err)
err = fmt.Errorf("transport %v in correct format but not Cloak: %v", transport, err)
err = fmt.Errorf("transport %v in correct format but not Cloak: %v", info.Transport, err)
return
}
info.Transport = transport
if _, ok := sta.ProxyBook[info.ProxyMethod]; !ok {
err = ErrBadProxyMethod
return

@ -98,3 +98,38 @@ func TestTouchStone(t *testing.T) {
})
}
func TestPrepareConnection(t *testing.T) {
nineSixSix := func() time.Time { return time.Unix(1565998966, 0) }
sta, _ := InitState(nineSixSix)
pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547")
p, _ := ecdh.Unmarshal(pvBytes)
sta.staticPv = p.(crypto.PrivateKey)
sta.ProxyBook["shadowsocks"] = nil
t.Run("TLS correct", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
info, _, err := PrepareConnection(chBytes, sta, nil)
if err != nil {
t.Errorf("failed to get client info: %v", err)
return
}
if info.SessionId != 3710878841 {
t.Error("failed to get correct session id")
return
}
})
t.Run("TLS correct but replay", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
_, _, err := PrepareConnection(chBytes, sta, nil)
if err != nil {
t.Error("failed to prepare for the first time")
return
}
_, _, err = PrepareConnection(chBytes, sta, nil)
if err != ErrReplay {
t.Errorf("failed to return ErrReplay, got %v instead", err)
return
}
})
}

@ -1,23 +1,16 @@
package server
import (
"github.com/cbeuw/Cloak/internal/util"
"crypto"
"errors"
"net"
)
type Transport interface {
HasRecordLayer() bool
UnitReadFunc() func(net.Conn, []byte) (int, error)
handshake(reqPacket []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (authenticationInfo, func([]byte) (net.Conn, error), error)
}
type TLS struct{}
func (TLS) String() string { return "TLS" }
func (TLS) HasRecordLayer() bool { return true }
func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
type WebSocket struct{}
func (WebSocket) String() string { return "WebSocket" }
func (WebSocket) HasRecordLayer() bool { return false }
func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }
var ErrInvalidPubKey = errors.New("public key has invalid format")
var ErrCiphertextLength = errors.New("ciphertext has the wrong length")

@ -1,142 +1,69 @@
package server
import (
"bufio"
"bytes"
"crypto"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"github.com/cbeuw/Cloak/internal/ecdh"
"github.com/cbeuw/Cloak/internal/util"
"github.com/gorilla/websocket"
"net"
"net/http"
log "github.com/sirupsen/logrus"
)
// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous
// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http
//
// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format
// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a
// websocket and eventually wrap the remote Conn as util.WebSocketConn,
//
// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method
// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by
// net/http package upon receiving a request from a Conn.
//
// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should
// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a
// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet
// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that
// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the
// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP
// function.
//
// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then
// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn
// accepted.
//
// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface.
// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to
// Accept will return error (so that the caller won't call again)
//
// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the
// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request
// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do
// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we
// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn.
//
// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a
// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop.
// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then
// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a
// websocket.Conn
//
// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it
// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler
// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside
// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a
// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of
// WsHandshakeHandler can get the reference to the established util.WebSocketConn.
//
// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when
// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel.
// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once
// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished.
// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the
// execution will block until the reference to util.WebSocketConn is ready.
// since we need to read the first packet from the client to identify its protocol, the first packet will no longer
// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must
// fake a conn that returns the first packet on first read
type firstBuffedConn struct {
net.Conn
firstRead bool
firstPacket []byte
}
func (c *firstBuffedConn) Read(buf []byte) (int, error) {
if !c.firstRead {
c.firstRead = true
copy(buf, c.firstPacket)
n := len(c.firstPacket)
c.firstPacket = []byte{}
return n, nil
}
return c.Conn.Read(buf)
}
type WebSocket struct{}
type wsAcceptor struct {
done bool
c *firstBuffedConn
}
func (WebSocket) String() string { return "WebSocket" }
func (WebSocket) HasRecordLayer() bool { return false }
func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }
// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an
// http.Server. This is an acceptor that accepts only one Conn
func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor {
f := make([]byte, len(first))
copy(f, first)
return &wsAcceptor{
c: &firstBuffedConn{Conn: conn, firstPacket: f},
func (WebSocket) handshake(reqPacket []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (ai authenticationInfo, finisher func([]byte) (net.Conn, error), err error) {
var req *http.Request
req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(reqPacket)))
if err != nil {
err = fmt.Errorf("failed to parse first HTTP GET: %v", err)
return
}
}
var hiddenData []byte
hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden"))
func (w *wsAcceptor) Accept() (net.Conn, error) {
if w.done {
return nil, errors.New("already accepted")
ai, err = unmarshalHidden(hiddenData, privateKey)
if err != nil {
err = fmt.Errorf("failed to unmarshal hidden data from WS into authenticationInfo: %v", err)
return
}
w.done = true
return w.c, nil
}
func (w *wsAcceptor) Close() error {
w.done = true
return nil
}
func (w *wsAcceptor) Addr() net.Addr {
return w.c.LocalAddr()
}
type wsHandshakeHandler struct {
conn net.Conn
finished chan struct{}
}
// the handler to turn a net.Conn into a websocket.Conn
func newWsHandshakeHandler() *wsHandshakeHandler {
return &wsHandshakeHandler{finished: make(chan struct{})}
}
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorf("failed to upgrade connection to ws: %v", err)
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
handler := newWsHandshakeHandler()
// For an explanation of the following 3 lines, see the comments in websocketAux.go
http.Serve(newWsAcceptor(originalConn, reqPacket), handler)
<-handler.finished
preparedConn = handler.conn
nonce := make([]byte, 12)
rand.Read(nonce)
// reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag]
encryptedKey, err := util.AESGCMEncrypt(nonce, ai.sharedSecret, sessionKey) // 32 + 16 = 48 bytes
if err != nil {
err = fmt.Errorf("failed to encrypt reply: %v", err)
return
}
reply := append(nonce, encryptedKey...)
_, err = preparedConn.Write(reply)
if err != nil {
err = fmt.Errorf("failed to write reply: %v", err)
go preparedConn.Close()
return
}
return
}
ws.conn = &util.WebSocketConn{Conn: c}
ws.finished <- struct{}{}
return
}
var ErrBadGET = errors.New("non (or malformed) HTTP GET")

@ -0,0 +1,137 @@
package server
import (
"errors"
"github.com/cbeuw/Cloak/internal/util"
"github.com/gorilla/websocket"
"net"
"net/http"
log "github.com/sirupsen/logrus"
)
// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous
// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http
//
// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format
// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a
// websocket and eventually wrap the remote Conn as util.WebSocketConn,
//
// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method
// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by
// net/http package upon receiving a request from a Conn.
//
// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should
// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a
// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet
// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that
// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the
// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP
// function.
//
// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then
// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn
// accepted.
//
// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface.
// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to
// Accept will return error (so that the caller won't call again)
//
// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the
// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request
// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do
// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we
// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn.
//
// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a
// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop.
// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then
// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a
// websocket.Conn
//
// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it
// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler
// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside
// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a
// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of
// WsHandshakeHandler can get the reference to the established util.WebSocketConn.
//
// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when
// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel.
// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once
// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished.
// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the
// execution will block until the reference to util.WebSocketConn is ready.
// since we need to read the first packet from the client to identify its protocol, the first packet will no longer
// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must
// fake a conn that returns the first packet on first read
type firstBuffedConn struct {
net.Conn
firstRead bool
firstPacket []byte
}
func (c *firstBuffedConn) Read(buf []byte) (int, error) {
if !c.firstRead {
c.firstRead = true
copy(buf, c.firstPacket)
n := len(c.firstPacket)
c.firstPacket = []byte{}
return n, nil
}
return c.Conn.Read(buf)
}
type wsAcceptor struct {
done bool
c *firstBuffedConn
}
// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an
// http.Server. This is an acceptor that accepts only one Conn
func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor {
f := make([]byte, len(first))
copy(f, first)
return &wsAcceptor{
c: &firstBuffedConn{Conn: conn, firstPacket: f},
}
}
func (w *wsAcceptor) Accept() (net.Conn, error) {
if w.done {
return nil, errors.New("already accepted")
}
w.done = true
return w.c, nil
}
func (w *wsAcceptor) Close() error {
w.done = true
return nil
}
func (w *wsAcceptor) Addr() net.Addr {
return w.c.LocalAddr()
}
type wsHandshakeHandler struct {
conn net.Conn
finished chan struct{}
}
// the handler to turn a net.Conn into a websocket.Conn
func newWsHandshakeHandler() *wsHandshakeHandler {
return &wsHandshakeHandler{finished: make(chan struct{})}
}
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorf("failed to upgrade connection to ws: %v", err)
return
}
ws.conn = &util.WebSocketConn{Conn: c}
ws.finished <- struct{}{}
}
Loading…
Cancel
Save